#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import functools
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar, Self, Generic

from entity_read import sql
from entity_read.entity import Entity
from entity_read.sql.atoms import Selectable

from .expression import Expression
from .remote_entity import RemoteEntity


if TYPE_CHECKING:
    from .condition import (
        ConditionEq, ConditionNe, ConditionLt, ConditionLe, ConditionGe, ConditionGt, ConditionIn,
        ConditionNotIn, ConditionIsNull, ConditionIsNotNull, ConditionOr, ConditionAnd, ConditionNot,
    )


# Type parameter of ``Column``, declared covariant. It serves as the ``Generic``
# parameter of the class and appears in the return annotation of ``Column._type_arg``.
ColumnType = TypeVar('ColumnType', covariant=True)


@dataclass(frozen=True, kw_only=True, repr=False)
class Column(Expression, Generic[ColumnType]):
    """Expression that refers to an entity column by its key.

    ``Column[T]`` yields a subclass of ``Column`` that stores ``T`` as its
    ``_column_type``. Calling an instance with a parent returns a copy of the
    column bound to that parent.

    Attributes:
        key: Column key; ``eval`` uses it to look the column up on an entity type.
        parent: Remote entity (instance or class) the column is bound to, or
            ``None`` when unbound. Declared with ``hash=False``.
    """

    key: str
    parent: RemoteEntity | type[RemoteEntity] | None = field(hash=False, default=None)

    # Specializations already built by ``__class_getitem__``, keyed by
    # ``(base class, type argument)``. Intentionally left unannotated so that
    # ``@dataclass`` does not turn it into an instance field.
    _specializations = {}

    def __class_getitem__(cls, item: type[ColumnType]) -> type[Self]:
        """Return the subclass of ``cls`` specialized for column type ``item``.

        Repeated subscriptions with the same hashable ``item`` return the same
        class object. This matters because the dataclass-generated ``__eq__``
        only compares instances of the very same class, so columns created
        from two separate ``Column[int]`` expressions would otherwise never be
        equal.
        """
        cache_key = (cls, item)
        try:
            specialized = cls._specializations.get(cache_key)
        except TypeError:
            # Unhashable type argument: it cannot be cached, build a one-off class.
            return cls._specialize(item)

        if specialized is None:
            specialized = cls._specializations[cache_key] = cls._specialize(item)
        return specialized

    @classmethod
    def _specialize(cls, item: type[ColumnType]) -> type[Self]:
        """Create a new subclass of ``cls`` that stores ``item`` as ``_column_type``."""
        # Not every valid type argument has ``__name__`` (``int | None`` does
        # not), so fall back to its repr when naming the generated class.
        item_name = getattr(item, "__name__", None) or repr(item)
        return type(f"{cls.__name__}[{item_name}]", (cls,), {"_column_type": item})

    def __call__(self, parent: RemoteEntity | type[RemoteEntity] | None) -> "Column[ColumnType]":
        """Return a copy of this column (same class, same key) bound to ``parent``."""
        return type(self)(key=self.key, parent=parent)

    @functools.cached_property
    def _type_arg(self) -> type[ColumnType]:
        """Column type stored on the class by ``Column[T]``.

        Raises:
            AttributeError: If no ``_column_type`` attribute is available, e.g.
                the column was created from a class that was never subscripted.
        """
        return self._column_type

    def shortcut(self) -> str:
        """Return the short textual form of the column: ``col.<key>``."""
        return f"col.{self.key}"

    def get_parent(self) -> RemoteEntity | type[RemoteEntity] | None:
        """Return the parent this column is bound to, or ``None`` if unbound."""
        return self.parent

    def base(self) -> Self:
        """Return an unbound copy of this column (same class and key, no parent).

        The copy is built with ``type(self)`` so that a specialized column
        (``Column[T]``) keeps its specialization, in line with ``__call__``
        and with the ``Self`` return annotation.
        """
        return type(self)(key=self.key, parent=None)

    def eval(self, entity_type: type[Entity], variables: dict[str, Selectable]) -> sql.atoms.column.Column:
        """Resolve this column against ``entity_type`` into an SQL column atom.

        Args:
            entity_type: Entity class whose column fields are searched by ``self.key``.
            variables: Not used by this implementation.

        Returns:
            An SQL column atom carrying this key and the type of the matching column field.

        Raises:
            KeyError: If ``entity_type`` has no column field registered under ``self.key``.
        """
        column_field = entity_type.get_key_to_column_field().get(self.key)
        if column_field is None:
            raise KeyError(f"Not found column {self.key!r} in {entity_type.__name__}({entity_type.get_column_names()})")

        return sql.atoms.column.Column(key=self.key, type_=column_field.type)

    # NOTE(review): the condition classes below are imported inside each method
    # rather than at module level (the module-level import is guarded by
    # TYPE_CHECKING); presumably this avoids a circular import with
    # ``.condition`` - confirm before hoisting them.

    def eq(self, selectable: Expression) -> 'ConditionEq':
        """Return a ``ConditionEq`` built from this column and ``selectable``."""
        from .condition import ConditionEq
        return ConditionEq(self, selectable)

    def ne(self, selectable: Expression) -> 'ConditionNe':
        """Return a ``ConditionNe`` built from this column and ``selectable``."""
        from .condition import ConditionNe
        return ConditionNe(self, selectable)

    def lt(self, selectable: Expression) -> 'ConditionLt':
        """Return a ``ConditionLt`` built from this column and ``selectable``."""
        from .condition import ConditionLt
        return ConditionLt(self, selectable)

    def le(self, selectable: Expression) -> 'ConditionLe':
        """Return a ``ConditionLe`` built from this column and ``selectable``."""
        from .condition import ConditionLe
        return ConditionLe(self, selectable)

    def ge(self, selectable: Expression) -> 'ConditionGe':
        """Return a ``ConditionGe`` built from this column and ``selectable``."""
        from .condition import ConditionGe
        return ConditionGe(self, selectable)

    def gt(self, selectable: Expression) -> 'ConditionGt':
        """Return a ``ConditionGt`` built from this column and ``selectable``."""
        from .condition import ConditionGt
        return ConditionGt(self, selectable)

    def in_(self, selectable: Expression) -> 'ConditionIn':
        """Return a ``ConditionIn`` built from this column and ``selectable``."""
        from .condition import ConditionIn
        return ConditionIn(self, selectable)

    def not_in(self, selectable: Expression) -> 'ConditionNotIn':
        """Return a ``ConditionNotIn`` built from this column and ``selectable``."""
        from .condition import ConditionNotIn
        return ConditionNotIn(self, selectable)

    def is_null(self) -> 'ConditionIsNull':
        """Return a ``ConditionIsNull`` built from this column."""
        from .condition import ConditionIsNull
        return ConditionIsNull(self)

    def is_not_null(self) -> 'ConditionIsNotNull':
        """Return a ``ConditionIsNotNull`` built from this column."""
        from .condition import ConditionIsNotNull
        return ConditionIsNotNull(self)

    def or_(self, selectable: Expression) -> 'ConditionOr':
        """Return a ``ConditionOr`` built from this column and ``selectable``."""
        from .condition import ConditionOr
        return ConditionOr(self, selectable)

    def and_(self, selectable: Expression) -> 'ConditionAnd':
        """Return a ``ConditionAnd`` built from this column and ``selectable``."""
        from .condition import ConditionAnd
        return ConditionAnd(self, selectable)

    def not_(self) -> 'ConditionNot':
        """Return a ``ConditionNot`` built from this column."""
        from .condition import ConditionNot
        return ConditionNot(self)
