#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
from dataclasses import dataclass

import sqlalchemy.orm
from more_itertools import first

from entity_read.entity import Entity
from entity_read.utils.sqlalchemy_dataclass_entity import default_registry
from .order import Order
from .expression import Expression
from .query_token import QueryToken


@dataclass(frozen=True, kw_only=True, repr=False)
class Reference(QueryToken):
    """Query token that follows a foreign-key column of an entity to the entity it points at.

    ``key`` names a column of the owning entity (looked up through
    ``Entity.get_key_to_column``); that column's single foreign key defines
    which table/entity the reference leads to.
    """

    # Key of the local column that carries the foreign key; also used by shortcut().
    key: str

    def get_related_entity(self, entity_type: type[Entity]) -> type[Entity]:
        """Return the entity class mapped to the table this reference points at.

        :param entity_type: entity class that owns the ``key`` column.
        :return: the entity class whose table contains the referenced column.
        :raises KeyError: if the column is missing, has no foreign key, or the
            referenced table is not mapped to any ``Entity`` subclass.
        :raises NotImplementedError: if the column has more than one foreign key.
        :raises ValueError: if several entity classes map the referenced table.
        """
        referenced_table = self._get_referenced_column(entity_type).table
        referenced_entity = get_entity_by_table(referenced_table)
        if referenced_entity is None:
            # get_entity_by_table reports "not found" as None; surface it here instead
            # of leaking None through this method's non-Optional return type.
            raise KeyError(
                f"Table {referenced_table.name!r} referenced by column {self.key!r} "
                f"in {entity_type.__name__} is not mapped to any Entity"
            )
        return referenced_entity

    def get_remote_column_name_to_local(self, entity_type: type[Entity]) -> dict[str, str]:
        """Map the referenced (remote) column name to this reference's local key.

        :param entity_type: entity class that owns the ``key`` column.
        :return: a one-item dict ``{referenced_column.name: self.key}``.
        :raises KeyError, NotImplementedError: as raised by ``_get_referenced_column``.
        """
        referenced_column = self._get_referenced_column(entity_type)
        return {referenced_column.name: self.key}

    def _get_referenced_column(self, entity_type: type[Entity]) -> sqlalchemy.Column:
        """Resolve the column that the ``key`` column's foreign key points to.

        :param entity_type: entity class that owns the ``key`` column.
        :raises KeyError: if ``key`` is not a column of ``entity_type`` or the
            column has no foreign key.
        :raises NotImplementedError: if the column has more than one foreign key.
        """
        if (column := entity_type.get_key_to_column().get(self.key)) is None:
            raise KeyError(f"Not found column {self.key!r} in {entity_type.__name__}({entity_type.get_column_names()})")
        foreign_keys = column.foreign_keys
        if not foreign_keys:
            raise KeyError(f"Column {self.key!r} in {entity_type.__name__} references nothing")
        if len(foreign_keys) > 1:
            raise NotImplementedError(f"Column {self.key!r} in {entity_type.__name__} references multiple tables")
        # Exactly one foreign key remains at this point.
        foreign_key: sqlalchemy.ForeignKey = first(foreign_keys)
        return foreign_key.column

    def subquery(
            self,
            attrs: list[Expression] | None = None,
            vars: dict[str, Expression] | None = None,
            filters: list[Expression] | None = None,
            searches: list[Expression] | None = None,
            orders: list[Order] | None = None,
            limit: int | None = None,
            offset: int | None = None,
    ):
        """Build a ``SubQuery`` over this reference.

        Any collection argument left as ``None`` (or passed empty) becomes a
        fresh empty list/dict, so the resulting ``SubQuery`` never shares a
        mutable default.

        :return: ``SubQuery(over=self, ...)`` with the given clauses.
        """
        # Imported lazily, presumably to avoid a circular import with the package __init__.
        from . import SubQuery
        return SubQuery(
            over=self, attrs=attrs or [], vars=vars or {},
            filters=filters or [], searches=searches or [],
            orders=orders or [], limit=limit, offset=offset
        )

    def exists(
            self,
            vars: dict[str, Expression] | None = None,
            filters: list[Expression] | None = None,
            searches: list[Expression] | None = None,
    ):
        """Return an ``exists(...)`` function token wrapping a subquery over this reference.

        :return: ``Function(key="exists", args=[subquery])`` where ``subquery`` is
            built by :meth:`subquery` with the given vars/filters/searches.
        """
        from .function import Function
        subquery = self.subquery(vars=vars, filters=filters, searches=searches)
        return Function(key="exists", args=[subquery])

    def shortcut(self) -> str:
        """Return the compact textual form of this token, ``"ref.<key>"``."""
        return f"ref.{self.key}"


def get_entity_by_table(
        table: sqlalchemy.Table,
        registry: sqlalchemy.orm.registry = default_registry
) -> type[Entity] | None:
    """Find the ``Entity`` subclass mapped to ``table`` in ``registry``.

    Walks the registry's public ``mappers`` collection rather than the private,
    class-name-keyed ``_class_registry``. That dict collapses same-named classes
    from different modules into a single marker object (which has no
    ``__table__``), so those entities would otherwise be missed.

    :param table: table to look up; compared by identity with each class's ``__table__``.
    :param registry: SQLAlchemy registry to search; defaults to ``default_registry``.
    :return: the single matching entity class, or ``None`` if no class maps ``table``.
    :raises ValueError: if more than one entity class maps ``table``.
    """
    found_classes = {
        cls
        for cls in (mapper.class_ for mapper in registry.mappers)
        if getattr(cls, '__table__', None) is table and issubclass(cls, Entity)
    }
    if len(found_classes) > 1:
        raise ValueError(f"Multiple classes found for table {table.name}.")
    if found_classes:
        return first(found_classes)
    return None
