#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
import functools
from logging import getLogger
from typing import Type, Optional, Any, Union, Iterable

from sqlalchemy import Column
from sqlalchemy.orm import InstrumentedAttribute

from .atoms.node import Node
from .atoms.order import EOrder
from .atoms.root import ERoot
from .atoms.selectable.alias import EAlias
from .atoms.selectable.column import EColumn
from .atoms.selectable.function import EFunction, EFunctionGT, EFunctionGE, EFunctionLT, EFunctionLE, EFunctionEQ, \
    EFunctionLen, EFunctionIncluded, EFunctionStringContains, EFunctionCastToStr, EFunctionGetJsonField, \
    EFunctionGetJsonFieldFromEach, EFunctionIsNull, EFunctionNotNull, EFunctionOr, EFunctionAnd, EFunctionNE
from .atoms.selectable.literal import ELiteral
from ..entity import Entity
from ..utils.proxy import EntityProxy, AttributeProxy

# Module-level logger; _auto_cast uses it to report values it cannot convert.
logger = getLogger(__name__)


def _auto_cast(value, e_root_allowed: bool = True):
    """Convert *value* into an expression node usable inside a query.

    Rules, checked in this order:

    * ``bool`` / ``int`` / ``float`` / ``list`` / ``set`` / ``str`` values are
      wrapped in ``ELiteral``;
    * SQLAlchemy ``InstrumentedAttribute`` / ``Column`` objects become an
      ``EColumn`` built from their ``key``;
    * an ``ERoot`` is returned as is when ``e_root_allowed`` is true, otherwise
      it is replaced by ``EColumn(value.alias)``;
    * ``EColumn``, ``ELiteral``, ``EFunction``, ``EAlias``, ``EntityProxy`` and
      ``AttributeProxy`` instances are returned unchanged.

    :param value: object to convert.
    :param e_root_allowed: whether an ``ERoot`` may be passed through untouched.
    :return: the converted (or original) node.
    :raises AssertionError: if *value* matches none of the supported types.
    """
    if isinstance(value, (bool, int, float, list, set, str)):
        return ELiteral(value)
    if isinstance(value, (InstrumentedAttribute, Column)):
        return EColumn(value.key)
    if isinstance(value, ERoot):
        return value if e_root_allowed else EColumn(value.alias)
    # Already-built nodes and proxies need no conversion. ERoot is handled above.
    if isinstance(value, (EColumn, ELiteral, EFunction, EAlias, EntityProxy, AttributeProxy)):
        return value
    logger.error('_auto_cast>- unexpected value: %s(%s)', value, type(value))
    # Raised explicitly instead of `assert False`: assert statements are removed
    # under `python -O`, which would make this function silently return None.
    raise AssertionError(f'_auto_cast: unsupported value type {type(value).__name__}')


# Variant of _auto_cast with e_root_allowed=False: an ERoot operand is replaced
# by an EColumn built from its alias instead of being passed through.
_auto_cast_strict = functools.partial(_auto_cast, e_root_allowed=False)


class FunctionCollection:
    """Namespace of static factories that build ``EFunction*`` nodes.

    Unless stated otherwise, every operand is converted with
    ``_auto_cast_strict`` before it is handed to the node constructor, so plain
    Python values, SQLAlchemy columns and already-built nodes are all accepted.
    """

    @staticmethod
    def greater(left, right) -> EFunctionGT:
        """Build the "greater than" node for ``left`` and ``right``."""
        return EFunctionGT([_auto_cast_strict(left), _auto_cast_strict(right)])

    @staticmethod
    def greater_or_equal(left, right) -> EFunctionGE:
        """Build the "greater or equal" node for ``left`` and ``right``."""
        return EFunctionGE([_auto_cast_strict(left), _auto_cast_strict(right)])

    @staticmethod
    def less(left, right) -> EFunctionLT:
        """Build the "less than" node for ``left`` and ``right``."""
        return EFunctionLT([_auto_cast_strict(left), _auto_cast_strict(right)])

    @staticmethod
    def less_or_equal(left, right) -> EFunctionLE:
        """Build the "less or equal" node for ``left`` and ``right``."""
        return EFunctionLE([_auto_cast_strict(left), _auto_cast_strict(right)])

    @staticmethod
    def equal(left, right) -> EFunctionEQ:
        """Build the equality node for ``left`` and ``right``."""
        return EFunctionEQ([_auto_cast_strict(left), _auto_cast_strict(right)])

    @staticmethod
    def non_equal(left, right) -> EFunctionNE:
        """Build the inequality node for ``left`` and ``right``."""
        return EFunctionNE([_auto_cast_strict(left), _auto_cast_strict(right)])

    @staticmethod
    def len(value) -> EFunctionLen:
        """Build an ``EFunctionLen`` node over ``value``."""
        return EFunctionLen([_auto_cast_strict(value)])

    @staticmethod
    def included(what, into) -> EFunctionIncluded:
        """Build an ``EFunctionIncluded`` node with operands ``what`` and ``into``."""
        return EFunctionIncluded([_auto_cast_strict(what), _auto_cast_strict(into)])

    @staticmethod
    def string_contains(what, into) -> EFunctionStringContains:
        """Build an ``EFunctionStringContains`` node with operands ``what`` and ``into``."""
        return EFunctionStringContains([_auto_cast_strict(what), _auto_cast_strict(into)])

    @staticmethod
    def to_str(what) -> EFunctionCastToStr:
        """Build an ``EFunctionCastToStr`` node over ``what``."""
        return EFunctionCastToStr([_auto_cast_strict(what)])

    @staticmethod
    def get_json_field(json_object, field_name: str) -> EFunctionGetJsonField:
        """Build an ``EFunctionGetJsonField`` node.

        Both ``json_object`` and ``field_name`` are converted with
        ``_auto_cast_strict`` (a plain ``str`` name becomes an ``ELiteral``).
        """
        return EFunctionGetJsonField([_auto_cast_strict(json_object), _auto_cast_strict(field_name)])

    @staticmethod
    def get_json_field_from_each(json_object, field_name: str) -> EFunctionGetJsonFieldFromEach:
        """Build an ``EFunctionGetJsonFieldFromEach`` node.

        Unlike :meth:`get_json_field`, ``field_name`` is passed to the node as a
        raw string, without conversion.

        :raises AssertionError: if ``field_name`` contains a single quote.
        """
        # The raw name is presumably embedded into a quoted expression later on
        # (TODO confirm against EFunctionGetJsonFieldFromEach), so a quote must
        # never get through. An explicit raise is used because `assert` is
        # stripped under `python -O`, which would silently disable this guard.
        if "'" in field_name:
            raise AssertionError(f"field_name must not contain a single quote: {field_name!r}")
        return EFunctionGetJsonFieldFromEach([_auto_cast_strict(json_object), field_name])

    @staticmethod
    def is_null(what) -> EFunctionIsNull:
        """Build an ``EFunctionIsNull`` node over ``what``."""
        return EFunctionIsNull([_auto_cast_strict(what)])

    @staticmethod
    def not_null(what) -> EFunctionNotNull:
        """Build an ``EFunctionNotNull`` node over ``what``."""
        return EFunctionNotNull([_auto_cast_strict(what)])

    @staticmethod
    def or_(*what) -> EFunctionOr:
        """Build an ``EFunctionOr`` node over all positional operands."""
        return EFunctionOr([_auto_cast_strict(val) for val in what])

    @staticmethod
    def and_(*what) -> EFunctionAnd:
        """Build an ``EFunctionAnd`` node over all positional operands."""
        return EFunctionAnd([_auto_cast_strict(val) for val in what])


# Short alias: ``op`` is the FunctionCollection class itself under another name.
op = FunctionCollection


def order(column, asc: bool = True, nulls_last: Optional[bool] = None) -> EOrder:
    """Build an ``EOrder`` clause for ``column``.

    A ``column`` that is not already an ``EColumn`` or ``EFunction`` is first
    converted with ``_auto_cast_strict``. ``asc`` and ``nulls_last`` are
    forwarded to ``EOrder`` unchanged.
    """
    is_ready = isinstance(column, (EColumn, EFunction))
    target = column if is_ready else _auto_cast_strict(column)
    return EOrder(target, asc=asc, nulls_last=nulls_last)


def _query(entity: Union[Type[Entity], str], alias: Optional[str] = None, attrs: Optional[Iterable[Any]] = None,
           vars: Optional[Iterable[Any]] = None, filter: Optional[Iterable[Any]] = None,
           order: Optional[Iterable[EOrder]] = None, limit: Optional[int] = None,
           offset: Optional[int] = None) -> ERoot:
    """Assemble an ``ERoot`` node from loosely typed query parts.

    ``attrs`` and ``vars`` items are converted with ``_auto_cast``; ``filter``
    items go through ``_auto_cast_strict``. ``order`` is copied into a list and
    otherwise passed as given. Missing (falsy) collections become empty lists.
    """
    # Normalise the inputs first, in the same order as they are declared.
    raw_attrs = attrs or ()
    raw_vars = vars or ()
    raw_filter = filter or ()
    ordering = list(order) if order is not None else []

    return ERoot(
        entity,
        alias=alias,
        attrs=[_auto_cast(item) for item in raw_attrs],
        vars=[_auto_cast(item) for item in raw_vars],
        filter=[_auto_cast_strict(item) for item in raw_filter],
        order=ordering,
        limit=limit,
        offset=offset,
    )


# noinspection PyTypeChecker
def query(entity: Type[Entity], alias: Optional[str] = None, attrs: Optional[Iterable[Any]] = None,
          vars: Optional[Iterable[Any]] = None, filters: Optional[Iterable[Any]] = None,
          orders: Optional[Iterable[EOrder]] = None,
          limit: Optional[int] = None, offset: Optional[int] = None) -> ERoot:
    """Build an ``ERoot`` query over ``entity``.

    Thin wrapper around ``_query``: every argument is forwarded as is, with
    ``filters`` and ``orders`` mapped onto its ``filter`` and ``order``
    parameters.
    """
    return _query(
        entity=entity,
        alias=alias,
        attrs=attrs,
        vars=vars,
        filter=filters,
        order=orders,
        limit=limit,
        offset=offset,
    )


def alias(value: Node, alias: str) -> EAlias:
    """Wrap ``value`` into an ``EAlias`` named ``alias``.

    ``value`` is converted with ``_auto_cast_strict`` before wrapping.
    """
    target = _auto_cast_strict(value)
    return EAlias(target, alias)


def subquery(relation, alias: Optional[str] = None, attrs: Optional[Iterable[Any]] = None,
             vars: Optional[Iterable[Any]] = None, filter: Optional[Iterable[Any]] = None,
             order: Optional[Iterable[EOrder]] = None):
    """Build a nested query over ``relation``.

    ``relation.key`` is passed to ``_query`` as the entity and is also used as
    the alias when ``alias`` is not given. ``limit`` and ``offset`` are not
    exposed here, so ``_query`` receives its defaults for them.
    """
    entity_key = relation.key
    effective_alias = alias if alias is not None else relation.key
    return _query(
        entity=entity_key,
        alias=effective_alias,
        attrs=attrs,
        vars=vars,
        filter=filter,
        order=order,
    )
