#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import asyncio
import dataclasses
import functools
from dataclasses import dataclass
from logging import getLogger
from types import MappingProxyType
from typing import Mapping, Any, ClassVar, Callable, Awaitable

import sqlalchemy.orm
from async_tools import acall
from init_helpers.dict_to_dataclass import NoValue
from sqlalchemy import Column

from .entity_field import get_sqlalchemy_metadata
from .utils.no_default import NoDefault

logger = getLogger(__name__)


@dataclass
class Entity:
    """
    Entity - a business-logic entity. It:
    - must be identifiable in some way
    - may be stored in the database or transmitted over the network

    Subclasses are expected to be dataclasses whose fields carry SQLAlchemy
    metadata (read through ``get_sqlalchemy_metadata``). The classmethods below
    derive columns, relationships and primary keys from that metadata and cache
    the results per class via ``functools.cache``. Cached mappings are returned
    read-only so callers cannot corrupt the shared cache.
    """
    __table__: ClassVar[sqlalchemy.Table | None] = None
    __tablename__: ClassVar[str | None] = None
    __table_schema_name__: ClassVar[str | None] = None

    @classmethod
    def get_entity_by_table(cls, table: sqlalchemy.Table) -> type['Entity'] | None:
        """Return the mapped class whose mapper persists exactly ``table``.

        Searches every mapper of the SQLAlchemy registry this class is registered
        in. Returns ``None`` when no mapper is backed by ``table`` alone.
        """
        for mapper in cls.get_class_manager().registry.mappers:
            tables = mapper.tables
            # Identity check, independent of the concrete sequence type of
            # ``mapper.tables``: only single-table mappers on ``table`` match.
            if len(tables) == 1 and tables[0] is table:
                return mapper.class_
        return None

    @classmethod
    def get_class_manager(cls) -> sqlalchemy.orm.ClassManager:
        """Return the SQLAlchemy class manager attached to this class by mapping."""
        return getattr(cls, sqlalchemy.orm.ClassManager.MANAGER_ATTR)

    @classmethod
    def get_table(cls) -> sqlalchemy.Table | None:
        """Return ``__table__`` (``None`` unless a subclass or mapping sets it)."""
        return cls.__table__

    @classmethod
    def get_table_name(cls) -> str | None:
        """Return ``__tablename__`` (``None`` unless a subclass sets it)."""
        return cls.__tablename__

    @classmethod
    def get_table_schema_name(cls) -> str | None:
        """Return ``__table_schema_name__`` (``None`` unless a subclass sets it)."""
        return cls.__table_schema_name__

    @classmethod
    @functools.cache
    def get_fields(cls) -> tuple[dataclasses.Field, ...]:
        """Return the regular dataclass fields declared for ``cls``.

        Only ``cls.__dict__`` is consulted (inherited attributes are ignored), so
        ``cls`` itself must have been processed by ``@dataclass``. ``ClassVar``
        and ``InitVar`` pseudo-fields are filtered out.

        :raises TypeError: if ``cls`` is not itself a dataclass.
        """
        # NOTE: relies on private dataclasses internals (_FIELDS, _FIELD, _field_type).
        fields = cls.__dict__.get(dataclasses._FIELDS)
        if fields is None:
            raise TypeError('must be called with a dataclass type or instance')
        return tuple(f for f in fields.values() if f._field_type is dataclasses._FIELD)

    @classmethod
    @functools.cache
    def get_field_to_sqlalchemy_data(cls) -> dict[dataclasses.Field, Any]:
        """Map every field of ``get_fields()`` to ``get_sqlalchemy_metadata(field)``.

        The value is whatever the helper returns: a ``Column``, a ``Relationship``,
        or presumably something else for non-persisted fields (TODO confirm).
        The dict is cached and shared; callers must not mutate it.
        """
        return {field: get_sqlalchemy_metadata(field) for field in cls.get_fields()}

    @classmethod
    @functools.cache
    def get_key_to_relation(cls) -> Mapping[str, sqlalchemy.orm.Relationship]:
        """Return a read-only ``field name -> Relationship`` mapping, in field order."""
        return MappingProxyType({
            field.name: property_
            for field, property_ in cls.get_field_to_sqlalchemy_data().items()
            if isinstance(property_, sqlalchemy.orm.Relationship)
        })

    @classmethod
    @functools.cache
    def get_relations(cls) -> tuple[sqlalchemy.orm.Relationship, ...]:
        """Return all relationships of the entity, in field order."""
        return tuple(cls.get_key_to_relation().values())

    @classmethod
    @functools.cache
    def get_relation(cls, relation_name: str) -> sqlalchemy.orm.Relationship:
        """Return the relationship stored in field ``relation_name``.

        :raises KeyError: if the field is not a relationship.
        """
        return cls.get_key_to_relation()[relation_name]

    @classmethod
    @functools.cache
    def get_relation_names(cls) -> tuple[str, ...]:
        """Return the names of the relationship fields, in field order."""
        return tuple(cls.get_key_to_relation())

    @classmethod
    @functools.cache
    def get_key_to_column(cls) -> Mapping[str, sqlalchemy.Column]:
        """Return a read-only ``field name -> Column`` mapping, in field order."""
        return MappingProxyType({
            field.name: property_
            for field, property_ in cls.get_field_to_sqlalchemy_data().items()
            if isinstance(property_, sqlalchemy.Column)
        })

    @classmethod
    @functools.cache
    def get_columns(cls) -> tuple[sqlalchemy.Column, ...]:
        """Return all columns of the entity, in field order."""
        return tuple(cls.get_key_to_column().values())

    @classmethod
    @functools.cache
    def get_column(cls, column_name: str) -> sqlalchemy.Column:
        """Return the column stored in field ``column_name``.

        :raises KeyError: if the field is not a column.
        """
        return cls.get_key_to_column()[column_name]

    @classmethod
    @functools.cache
    def get_column_names(cls) -> tuple[str, ...]:
        """Return the names of the column fields, in field order."""
        return tuple(cls.get_key_to_column())

    @classmethod
    @functools.cache
    def get_key_to_column_field(cls) -> Mapping[str, dataclasses.Field]:
        """Return a read-only ``field name -> dataclasses.Field`` mapping for column fields."""
        return MappingProxyType({
            field.name: field
            for field, property_ in cls.get_field_to_sqlalchemy_data().items()
            if isinstance(property_, sqlalchemy.Column)
        })

    @classmethod
    @functools.cache
    def get_primary_key_columns(cls) -> tuple[sqlalchemy.Column, ...]:
        """Return the columns flagged ``primary_key``, in field order."""
        return tuple(col for col in cls.get_columns() if col.primary_key)

    @classmethod
    @functools.cache
    def get_primary_key_names(cls) -> tuple[str, ...]:
        """Return the field names of primary-key columns, in field order."""
        return tuple(name for name, col in cls.get_key_to_column().items() if col.primary_key)

    @classmethod
    @functools.cache
    def get_non_primary_key_names(cls) -> tuple[str, ...]:
        """Return the field names of columns that are not part of the primary key."""
        primary_key_names = set(cls.get_primary_key_names())
        return tuple(name for name in cls.get_column_names() if name not in primary_key_names)

    @classmethod
    @functools.cache
    def get_identifier_columns(cls) -> tuple[Column, ...]:
        """Return the SQLAlchemy data of fields whose metadata sets ``identifier``."""
        # Reuse the cached field -> metadata map so these are the same objects
        # as the ones returned by get_columns().
        field_to_data = cls.get_field_to_sqlalchemy_data()
        return tuple(field_to_data[field] for field in cls.get_fields() if field.metadata.get("identifier"))

    def get_primary_key_name_to_value(self) -> dict[str, Any]:
        """Return ``{primary-key field name: current value}`` for this instance."""
        return {name: getattr(self, name) for name in self.get_primary_key_names()}

    def get_primary_key_values(self) -> tuple[Any, ...]:
        """Return the current primary-key values, ordered like ``get_primary_key_names()``."""
        return tuple(getattr(self, name) for name in self.get_primary_key_names())

    def as_dict(self, plain: bool = False) -> dict[str, Any]:
        """Serialize the entity into a dict keyed by field name.

        Column fields are skipped when their value is the ``NoValue`` or
        ``NoDefault`` sentinel, or when it is ``None`` for a non-nullable column.
        Unless ``plain`` is true, relationship fields are serialized recursively:
        a list of dicts for ``uselist`` relationships, a single dict otherwise,
        and ``None`` when the related value is ``None``. Sentinel-valued
        relationships are skipped.

        :param plain: if true, only column fields are included.
        """
        result: dict[str, Any] = {}
        for name, column in self.get_key_to_column().items():
            # Read by field name: the Column's own ``name`` may differ from it.
            value = getattr(self, name)
            # Sentinels are compared by identity to avoid invoking __eq__ on values.
            if value is NoValue or value is NoDefault:
                continue
            if value is None and not column.nullable:
                continue
            result[name] = value

        if plain:
            return result

        for name, relation in self.get_key_to_relation().items():
            value = getattr(self, name)
            if value is NoValue or value is NoDefault:
                continue
            if value is None:
                result[name] = None
            elif relation.uselist:
                result[name] = [item.as_dict() for item in value]
            else:
                result[name] = value.as_dict()
        return result

    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # after sqlalchemy applied ->

    # @classmethod
    # @functools.cache
    # def get_table(cls) -> sqlalchemy.Table:
    #     column: sqlalchemy.Column = next(iter(cls.get_primary_key_columns()))
    #     result = column.table
    #     if result is None:
    #         raise RuntimeError("Table not yet initialised")
    #     return result

    # def get_foreign_keys(self) -> list[sqlalchemy.ForeignKey]:


    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # insert_default ->

    @classmethod
    @functools.cache
    def get_key_to_insert_default(cls) -> Mapping[str, Callable[..., Awaitable[Any]]]:
        """Return a read-only ``field name -> insert_default`` mapping.

        Only fields whose dataclass metadata has a truthy ``"insert_default"``
        entry are included. The callable receives the insert context; its result
        is passed through ``acall`` (presumably so both sync and async defaults
        work — TODO confirm against async_tools).
        """
        # noinspection PyDataclass
        return MappingProxyType({
            field.name: insert_default
            for field in dataclasses.fields(cls)
            if (insert_default := field.metadata.get("insert_default"))
        })

    @classmethod
    @functools.cache
    def get_key_to_relation_with_insert_default(cls) -> Mapping[str, sqlalchemy.orm.Relationship]:
        """Return a read-only mapping of the relationships whose target entity has insert defaults."""
        return MappingProxyType({
            key: relation
            for key, relation in cls.get_key_to_relation().items()
            if get_related_entity_type(relation).has_insert_defaults()
        })

    @classmethod
    @functools.cache
    def has_insert_defaults(cls) -> bool:
        """Return whether this entity or any related entity type declares insert defaults."""
        return bool(cls.get_key_to_insert_default() or cls.get_key_to_relation_with_insert_default())

    async def resolve_insert_defaults(self, context) -> None:
        """Populate insert-time defaults on this entity and on related entities.

        For every field with an ``insert_default``, the result of
        ``insert_default(context)`` is passed through ``acall`` and awaited. The
        outcome is then assigned to the field, overwriting its current value.
        Afterwards every relationship whose target type has insert defaults is
        resolved recursively with the same ``context``. For ``uselist``
        relationships holding a list, all items are resolved concurrently with
        ``asyncio.gather``. For other relationships, a single related ``Entity``
        is resolved. Any other value (e.g. sentinels, ``None``) is left untouched.
        """
        for key, insert_default in self.get_key_to_insert_default().items():
            logger.debug("async default: %s, %s", type(self), key)
            setattr(self, key, await acall(insert_default(context)))

        for key, relation in self.get_key_to_relation_with_insert_default().items():
            value = getattr(self, key)
            if relation.uselist:
                if isinstance(value, list):
                    logger.debug("insert default one-to-many relation: %s, %s", type(self), key)
                    await asyncio.gather(*[val.resolve_insert_defaults(context) for val in value])
            elif isinstance(value, Entity):
                logger.debug("insert default one-to-one relation: %s, %s", type(self), key)
                await value.resolve_insert_defaults(context)

    # <- insert_default
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #


def get_related_entity_type(relation: sqlalchemy.orm.RelationshipProperty) -> type[Entity]:
    """Return the entity class targeted by ``relation``.

    ``relationship()`` accepts its target as a class, a mapper, a string name
    or a callable, and ``relation.argument`` keeps whatever was passed. A plain
    class is returned directly, as before. Any other form is resolved through
    ``relation.mapper.class_``, which makes SQLAlchemy resolve the target and
    may trigger mapper configuration.
    """
    argument = relation.argument
    if isinstance(argument, type):
        return argument
    return relation.mapper.class_
