import datetime
import logging
from collections import defaultdict
from typing import List, Dict, Any, Union, Optional

from sqlalchemy import Integer, BigInteger, SmallInteger, String, Float, Boolean, JSON, ARRAY, DateTime, cast, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ARRAY as PG_ARRAY, JSON as PG_JSON, JSONB
from sqlalchemy.sql.elements import BinaryExpression

from ..entity_helpers.analyze import analyze_entity
from ..entity_helpers.entity_name_to_class import get_entity_by_name
from ..entity_helpers.sqlalchemy_base import Values
from ..exceptions import WrongArgumentsException
from ..utils.singleton import Singleton


# Module-level logger named after this module; configure it through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)


class SQLAlchemyCRUD(metaclass=Singleton):
    """Generic async CRUD helpers for SQLAlchemy declarative entities.

    Every method is a classmethod that receives the mapped entity class and an
    ``AsyncSession``. The methods only ``execute`` statements and ``flush``;
    they never commit, so transaction control stays with the caller.
    """

    @classmethod
    async def add(cls,
                  entity,
                  column_name_to_value_list: Union[List[Dict[str, Any]], Dict[str, Any]],
                  db_session: AsyncSession,
                  ignore_relation: bool = False
                  ) -> List[int]:
        """Insert one or more rows, optionally with nested related rows.

        :param entity: mapped entity class to instantiate.
        :param column_name_to_value_list: a single ``{column_name: value}`` dict or a list of them.
        :param db_session: session the new instances are added to and flushed on.
        :param ignore_relation: when True, relation keys in the dicts are skipped
            instead of being turned into nested entities.
        :return: primary keys of the created rows, in input order.
        :raises WrongArgumentsException: on values that cannot be coerced (see ``recursive_add``).
        """
        entity_description = analyze_entity(entity)

        if not isinstance(column_name_to_value_list, list):
            column_name_to_value_list = [column_name_to_value_list]

        entity_instances = [
            await cls.recursive_add(entity, column_name_to_value, db_session, ignore_relation)
            for column_name_to_value in column_name_to_value_list
        ]
        db_session.add_all(entity_instances)
        # Flush so the database assigns primary keys before they are read back.
        await db_session.flush()

        p_key_attr = entity_description.p_key.key
        return [getattr(entity_instance, p_key_attr) for entity_instance in entity_instances]

    @classmethod
    async def get(cls, *args, **kwargs):
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError("Get method  hasn't been implemented yet.")

    @classmethod
    async def update(cls,
                     entity,
                     column_name_to_value_list: Union[List[Dict[str, Any]], Dict[str, Any]],
                     db_session: AsyncSession,
                     ) -> bool:
        """Update rows identified by their primary key, synchronising nested relations.

        Every dict must contain the primary-key column (a ``KeyError`` is raised
        otherwise). Dicts whose key matches no existing row are inserted instead,
        as done by ``recursive_update``.

        :return: True if anything was inserted, updated or deleted.
        """
        entity_description = analyze_entity(entity)

        if not isinstance(column_name_to_value_list, list):
            column_name_to_value_list = [column_name_to_value_list]

        # TODO composite primary keys not working
        p_key_name = entity_description.p_key.name
        consideration_objects_p_keys = [state[p_key_name] for state in column_name_to_value_list]
        filters = [entity_description.p_key.in_(consideration_objects_p_keys)]

        return await cls.recursive_update(entity, column_name_to_value_list, filters, db_session)

    @classmethod
    async def delete(cls,
                     entity,
                     filters: List[BinaryExpression],
                     db_session: AsyncSession
                     ) -> bool:
        """Delete rows of ``entity`` matching all ``filters`` (AND-ed together).

        Relies on ``RETURNING`` support in the backend (presumably PostgreSQL,
        given this module's dialect imports).

        :return: True if at least one row was deleted.
        """
        entity_description = analyze_entity(entity)

        stmt = entity.__table__.delete().where(*filters).returning(entity_description.p_key)
        answer = await db_session.execute(stmt)

        deleted_primary_keys = [row[0] for row in answer.fetchall()]
        return bool(deleted_primary_keys)

    @classmethod
    async def plain_update(cls,
                           entity,
                           column_name_to_value_list: List[Dict[str, Any]],
                           db_session: AsyncSession
                           ) -> int:
        """Bulk-update columns of existing rows, matched by primary key.

        Keys outside ``simple_column_names()`` are ignored. Dicts are grouped by
        their exact set of column names. Each group becomes one UPDATE joined
        against a ``Values`` alias named ``states`` on the primary-key columns.

        :raises WrongArgumentsException: if a dict lacks any primary-key column.
        :return: total ``rowcount`` reported by the executed UPDATE statements.
        """
        entity_description = analyze_entity(entity)
        simple_column_names = entity_description.simple_column_names()

        # Group rows by the sorted tuple of column names they set, so every group shares one VALUES shape.
        column_names_tuple_to_values_tuple = defaultdict(list)
        for column_name_to_value in column_name_to_value_list:
            needed_column_name_to_value = {
                c_name: val for c_name, val in column_name_to_value.items() if c_name in simple_column_names
            }
            column_names = tuple(sorted(needed_column_name_to_value))
            values = tuple(needed_column_name_to_value[name] for name in column_names)
            column_names_tuple_to_values_tuple[column_names].append(values)

        columns_to_map = set(entity_description.primary_key)
        column_names_to_map = {column.name for column in columns_to_map}
        updated_rows_amount = 0
        for column_name_tuple, values in column_names_tuple_to_values_tuple.items():
            column_name_set = set(column_name_tuple)
            # Every primary-key column must be present, otherwise the row cannot be located.
            if not column_names_to_map.issubset(column_name_set):
                raise WrongArgumentsException(f'got {entity.__name__} to update '
                                              f'without necessary column_names: {column_names_to_map}, '
                                              f'got only {column_name_set}')
            column_names_to_update = column_name_set - column_names_to_map

            if not column_names_to_update:
                continue

            columns = [getattr(entity, column_name) for column_name in column_name_tuple]
            sql_values = Values(values, columns).alias('states')
            # The VALUES alias exposes positional columns column1, column2, ... in tuple order.
            key_name_to_column = {
                key: getattr(sql_values.c, f'column{i + 1}') for i, key in enumerate(column_name_tuple)
            }

            set_values = {
                name: cast(key_name_to_column[name], entity_description.column_name_to_column[name].type)
                for name in column_names_to_update
            }
            wheres = [column == key_name_to_column[column.name] for column in columns_to_map]
            stmt = entity.__table__.update().values(**set_values).where(*wheres)

            answer = await db_session.execute(stmt)
            updated_rows_amount += answer.rowcount

        return updated_rows_amount

    @staticmethod
    def _coerce_column_value(column, value: Any) -> Any:
        """Convert ``value`` into the Python type expected by ``column``'s SQL type.

        Matching is on the exact type class, so subclasses (e.g. ``Text``, a
        ``String`` subclass) are reported as unexpected.

        :raises WrongArgumentsException: for a column type not handled here.
        """
        column_type = type(column.type)
        if column.nullable and value is None:
            return None
        if column_type in (Integer, BigInteger, SmallInteger):
            return int(value)
        if column_type == String:
            return str(value)
        if column_type in (Float, DOUBLE_PRECISION):
            return float(value)
        if column_type == Boolean:
            # NOTE: plain truthiness, so a string such as "false" becomes True.
            return bool(value)
        if column_type == DateTime:
            if isinstance(value, (int, float)):
                # Heuristic: numbers above 2**32 are treated as millisecond timestamps.
                if value > 2 ** 32:
                    value = value / 1000
                # fromtimestamp without tz yields a naive datetime in local time.
                value = datetime.datetime.fromtimestamp(value)
            return value
        if column_type in (JSON, JSONB, PG_JSON):
            return value
        if column_type in (ARRAY, PG_ARRAY):
            return list(value)
        raise WrongArgumentsException(f'Unexpected type {column_type.__name__}')

    @classmethod
    async def recursive_add(cls,
                            entity,
                            column_name_to_value: Dict[str, Any],
                            db_session: AsyncSession,
                            ignore_relation: bool = False
                            ) -> Any:
        """Build a new, unsaved ``entity`` instance from a ``{column_name: value}`` dict.

        Scalar columns are coerced to the Python type matching their SQL type.
        Relation keys become nested instances built recursively: a list for
        ``uselist`` relations, a dict otherwise. Keys that are neither relations
        nor known columns are ignored. ``db_session`` is only passed down to
        nested calls; nothing is added to it here.

        :param ignore_relation: skip relation keys entirely (``recursive_update``
            persists relations itself after inserting the parent rows).
        :raises WrongArgumentsException: on a value that cannot be coerced, a
            relation value of the wrong shape, or an unsupported column type.
        """
        entity_description = analyze_entity(entity)
        entity_args = {}
        logger.debug("Recursive add, column_name_to_value: %s, ignore_relation: %s",
                     column_name_to_value, ignore_relation)
        for column_name, value in column_name_to_value.items():
            relation = entity_description.relations.get(column_name)
            if relation is not None:
                if ignore_relation or value is None:
                    continue
                remote_entity = get_entity_by_name(relation.table_name)
                column = entity_description.column_name_to_column[column_name]
                if isinstance(value, list):
                    if not column.property.uselist:
                        raise WrongArgumentsException(f"{column_name} is OneToOne! Use dict")
                    entity_args[column_name] = [await cls.recursive_add(remote_entity, v, db_session) for v in value]
                elif isinstance(value, dict):
                    if column.property.uselist:
                        raise WrongArgumentsException(f"{column_name} is OneToMany! Use list")
                    entity_args[column_name] = await cls.recursive_add(remote_entity, value, db_session)
                else:
                    raise WrongArgumentsException(f"Check relation field type ({column_name}). "
                                                  f"Allowed relation type: list or dict.")
                continue

            column = entity_description.column_name_to_column.get(column_name)
            if column is None:
                continue
            try:
                entity_args[column_name] = cls._coerce_column_value(column, value)
            except (TypeError, ValueError, OverflowError) as e:
                # e.g. int(None) on a non-nullable column raises TypeError, not ValueError.
                raise WrongArgumentsException(entity.__name__, column_name, value, e) from e

        return entity(**entity_args)

    @classmethod
    async def recursive_update(cls,
                               entity,
                               column_name_to_value_list: List[Dict[str, Any]],
                               filters: List[BinaryExpression],
                               db_session: AsyncSession
                               ) -> bool:
        """Make the rows of ``entity`` selected by ``filters`` match ``column_name_to_value_list``.

        * Dicts whose primary key matches a selected row are updated.
        * All other dicts are inserted.
        * Selected rows that no dict mentions are deleted.
        * For each relation key present (and not None) in a dict, the related
          table is synchronised the same way, scoped to the children of the
          parents that supplied that relation. Omitting a relation key leaves
          its rows untouched.

        :return: True if anything was inserted, updated or deleted.
        :raises WrongArgumentsException: if two dicts carry the same existing primary key.
        """
        entity_description = analyze_entity(entity)
        p_key_name = entity_description.p_key.name

        answer = await db_session.execute(select(entity_description.p_key).where(*filters))
        existent_p_keys = [row[0] for row in answer.all()]
        existent_p_key_set = set(existent_p_keys)

        objects_to_create, objects_to_update = [], []
        matched_p_keys = set()
        for column_name_to_value in column_name_to_value_list:
            p_key = column_name_to_value.get(p_key_name)
            # `is not None` so a falsy-but-valid key such as 0 is still matched.
            if p_key is not None and p_key in existent_p_key_set:
                if p_key in matched_p_keys:
                    raise WrongArgumentsException(f'got {entity.__name__} with duplicated primary key {p_key}')
                matched_p_keys.add(p_key)
                objects_to_update.append(column_name_to_value)
            else:
                objects_to_create.append(column_name_to_value)

        updated_status = False
        p_keys_for_delete = [p_key for p_key in existent_p_keys if p_key not in matched_p_keys]
        if p_keys_for_delete:
            updated_status |= await cls.delete(entity, [entity_description.p_key.in_(p_keys_for_delete)], db_session)

        logger.debug("objects_to_update: %s", objects_to_update)
        updated_rows_amount = await cls.plain_update(entity, objects_to_update, db_session)
        logger.debug("objects_to_create: %s", objects_to_create)
        new_object_ids = await cls.add(entity, objects_to_create, db_session, ignore_relation=True)

        updated_status |= bool(updated_rows_amount) or bool(new_object_ids)

        # Pair every object with its primary key: new ids come from the insert, updated ones from the dict.
        object_id_and_states = list(zip(new_object_ids, objects_to_create))
        object_id_and_states += [(state[p_key_name], state) for state in objects_to_update]

        # Group nested states per relation (not per table, so two relations to one table stay separate).
        relation_name_to_states = defaultdict(list)
        relation_name_to_parent_ids = defaultdict(list)
        for object_id, column_name_to_value in object_id_and_states:
            for relation_column_name, relation in entity_description.relations.items():
                relational_values = column_name_to_value.get(relation_column_name)
                if relational_values is None:
                    continue
                if not relational_values:
                    relational_values = []
                elif not isinstance(relational_values, list):
                    relational_values = [relational_values]
                distant_key = relation.distant_column.name
                relation_name_to_parent_ids[relation_column_name].append(object_id)
                for value in relational_values:
                    logger.debug('Set %s to %s in %s', distant_key, object_id, value)
                    # Copy rather than mutate the caller's nested dict.
                    relation_name_to_states[relation_column_name].append({**value, distant_key: object_id})

        for relation_column_name, parent_ids in relation_name_to_parent_ids.items():
            relation = entity_description.relations[relation_column_name]
            distant_entity = get_entity_by_name(relation.table_name)
            child_filters = [relation.distant_column.in_(parent_ids)]
            updated_status |= await cls.recursive_update(
                distant_entity, relation_name_to_states[relation_column_name], child_filters, db_session
            )

        return bool(updated_status)
