#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
from logging import getLogger
from typing import Union, Type

from more_itertools import first

import entity_tools.utils.move
from .atoms.context import SqlContext
from .atoms.node import Node
from .atoms.order import EOrder
from .atoms.root import ERoot
from .atoms.selectable.alias import EAlias
from .atoms.selectable.column import EColumn
from .atoms.selectable.function import EFunction
from .atoms.selectable.literal import ELiteral
from .lower import LowerSelector, wrap_sql_as_count, wrap_sql_as_list, _generate_sql
# import typing

import sqlalchemy
import sqlalchemy.orm
# from dict_caster.extras import first
#
# from extended_logger import get_logger
# from .select_atoms import Node, EColumn, ELiteral, ERoot, EFunction, EAlias, EOrder, SqlContext
# from .select_lower import LowerSelector, wrap_sql_as_list, generate_sql, wrap_sql_as_count
# from ..components import EntityProxy, AttributeProxy, ProxyStep
from ..entity import Entity
from ..utils.proxy import EntityProxy, AttributeProxy, ProxyStep

logger = getLogger(__name__)


@dataclasses.dataclass
class MiddleSelector(LowerSelector):
    """Selector bound to one ORM entity, with lazily created child selectors.

    Extends ``LowerSelector`` (which supplies fields such as ``table_name`` and
    ``local_column_to_remote``) with the entity whose attributes are resolved
    via ``getattr`` and a cache of sub-selectors, one per proxy step.
    """
    # Child selectors keyed by the step's alias, or by its attribute name when
    # the step has no alias.
    step_to_sub_selector: dict[str, 'MiddleSelector'] = dataclasses.field(default_factory=dict)
    # Entity class (or instance) this selector reads from.
    entity: Union[Type[Entity], Entity, None] = None

    @classmethod
    def init_with_entity(cls, entity: Union[Type[Entity], Entity],
                         local_column_to_remote: Union[dict[str, str], None] = None) -> 'MiddleSelector':
        """Create a selector for *entity*, targeting ``entity.get_table_name()``.

        :param entity: entity class (or instance) providing ``get_table_name()``.
        :param local_column_to_remote: optional column-name mapping forwarded to
            the selector; ``None`` (or any falsy value) becomes a new empty dict.
        :return: a new instance of ``cls`` (so subclasses construct themselves).
        """
        return cls(
            entity=entity,
            table_name=entity.get_table_name(),
            local_column_to_remote=local_column_to_remote or {},
        )

    def get_sub_selector_for(self, step: ProxyStep) -> 'MiddleSelector':
        """Return the child selector for *step*, creating and caching it on first use.

        The cache key is ``step.alias`` when set, otherwise ``step.name``.
        Relationship attributes are followed with ``entity_tools.utils.move.descend``;
        column attributes with ``entity_tools.utils.move.ascend``, which is given
        the step's alias (so column steps must carry one).

        :raises AttributeError: if ``self.entity`` has no attribute ``step.name``.
        :raises AssertionError: if a column step has no alias.
        :raises TypeError: if the attribute's ORM property is neither a
            relationship nor a column property.
        """
        # Resolved before the cache lookup so an unknown attribute name always
        # raises, even when the key is already cached.
        attribute = getattr(self.entity, step.name)
        key = step.name if step.alias is None else step.alias
        if key in self.step_to_sub_selector:
            return self.step_to_sub_selector[key]

        prop = attribute.property
        if isinstance(prop, sqlalchemy.orm.relationships.RelationshipProperty):
            next_entity_proxy = entity_tools.utils.move.descend(attribute)
        elif isinstance(prop, sqlalchemy.orm.properties.ColumnProperty):
            if step.alias is None:
                # Explicit raise (not ``assert``) so the check survives ``python -O``;
                # otherwise ascend() would silently receive ``None``.
                raise AssertionError(f"column step {step.name!r} requires an alias")
            next_entity_proxy = entity_tools.utils.move.ascend(attribute, step.alias)
        else:
            # Report the property type: ``type(attribute)`` is always the
            # instrumented attribute and says nothing about what was unexpected.
            raise TypeError(f"attempt to get SubSelector for unexpected type: {type(prop)}")

        sub_selector = MiddleSelector.init_with_entity(
            entity=next_entity_proxy.get_entity(), local_column_to_remote=step.local_column_to_remote
        )
        self.step_to_sub_selector[key] = sub_selector
        return sub_selector


def process_attribute(selector: MiddleSelector, attribute: Node):
    """Register one selected attribute node on *selector*.

    Columns, aliases and functions are recorded in
    ``selector.result_name_to_selectable``. Nested ``ERoot`` nodes and
    entity/attribute proxies are routed to a child selector obtained via
    ``selector.get_sub_selector_for`` and processed recursively with
    ``get_sub_selector_from_nodes``.

    :raises AssertionError: for ``ELiteral`` nodes (not supported) and for any
        node type not handled below. Raised explicitly rather than via
        ``assert False`` so the checks are not stripped under ``python -O``.
        Under ``-O`` the node would otherwise be silently dropped.
    """
    if isinstance(attribute, EAlias):
        # The alias name maps to the wrapped value, not to the alias node itself.
        selector.result_name_to_selectable[attribute.get_name()] = attribute.value
    elif isinstance(attribute, EColumn):
        selector.result_name_to_selectable[attribute.get_name()] = attribute
    elif isinstance(attribute, ELiteral):
        raise AssertionError(f'process_attribute>- literal attributes are not supported: {attribute!r}')
    elif isinstance(attribute, ERoot):
        # ``attribute.entity`` is the attribute name on the current entity to descend through.
        relation = getattr(selector.entity, attribute.entity)
        proxy: EntityProxy = entity_tools.utils.move.descend(relation, attribute.get_name())
        sub_selector = selector.get_sub_selector_for(first(proxy.get_path()))
        get_sub_selector_from_nodes(attribute, sub_selector)
    elif isinstance(attribute, EntityProxy):
        sub_selector = selector.get_sub_selector_for(first(attribute.get_path()))
        get_sub_selector_from_nodes(ERoot(attribute.get_entity()), sub_selector)
    elif isinstance(attribute, AttributeProxy):
        sub_selector = selector.get_sub_selector_for(first(attribute.path))
        # Select only this one attribute from its owning class.
        get_sub_selector_from_nodes(ERoot(attribute.attribute.class_, attrs=[attribute.attribute]), sub_selector)
    elif isinstance(attribute, EFunction):
        # Functions are keyed by their SQL rendering in this selector's table context.
        name = attribute.to_sql(SqlContext(selector.table_name))
        selector.result_name_to_selectable[name] = attribute
    else:
        logger.warning('process_attribute>- unexpected attribute: %s(%s)', attribute, type(attribute))
        raise AssertionError(f'process_attribute>- unexpected attribute type: {type(attribute)}')


def process_variable(selector: MiddleSelector, variable: Node):
    """Attach a child selector for the ``ERoot`` *variable*, flagged as hidden.

    The child is located by descending from ``selector.entity`` through the
    attribute named ``variable.entity``. It is marked ``is_hidden = True``
    before *variable*'s own nodes are processed into it.
    NOTE(review): the exact effect of ``is_hidden`` lives in ``LowerSelector``; confirm there.
    """
    assert isinstance(variable, ERoot)
    descended = entity_tools.utils.move.descend(
        getattr(selector.entity, variable.entity),
        variable.get_name(),
    )
    child = selector.get_sub_selector_for(first(descended.get_path()))
    child.is_hidden = True
    get_sub_selector_from_nodes(variable, child)


def process_filter(selector: LowerSelector, filter_: Node):
    """Append the ``EFunction`` *filter_* to ``selector.filter_by``."""
    assert isinstance(filter_, EFunction)
    pending_filters = selector.filter_by
    pending_filters.append(filter_)


def process_order(selector: LowerSelector, order: EOrder):
    """Append the ``EOrder`` node *order* to ``selector.order``."""
    assert isinstance(order, EOrder)
    orderings = selector.order
    orderings.append(order)


def get_sub_selector_from_nodes(root: ERoot, selector: LowerSelector = None) -> LowerSelector:
    """Populate *selector* from the contents of *root* and return it.

    If *selector* is omitted, a fresh ``MiddleSelector`` is built for
    ``root.entity``, which must then be an ``Entity`` subclass.
    ``root.attrs``, ``root.vars``, ``root.filter`` and ``root.order`` are
    dispatched, in that order, to their ``process_*`` handler. Finally
    ``limit`` and ``offset`` are copied from *root* (overwriting any previous values).
    """
    if selector is None:
        entity = root.entity
        assert isinstance(entity, type) and issubclass(entity, Entity)
        selector = MiddleSelector.init_with_entity(entity=entity)

    # Field names are resolved lazily, one group at a time, exactly as sequential loops would.
    dispatch = (
        ('attrs', process_attribute),
        ('vars', process_variable),
        ('filter', process_filter),
        ('order', process_order),
    )
    for field_name, handle in dispatch:
        for node in getattr(root, field_name):
            handle(selector, node)

    selector.limit = root.limit
    selector.offset = root.offset
    return selector


def e_root_to_sql(root: ERoot, count: bool = False) -> tuple[str, dict]:
    """Render *root* as SQL together with its parameter mapping.

    :param root: query tree to render.
    :param count: when true, wrap the generated query with ``wrap_sql_as_count``;
        otherwise with ``wrap_sql_as_list``.
    :return: ``(sql, params)`` where *params* is the dict handed to
        ``_generate_sql`` (presumably filled in by it with bound values).
    """
    selector = get_sub_selector_from_nodes(root)
    params = {}
    wrap = wrap_sql_as_count if count else wrap_sql_as_list
    return wrap(_generate_sql("", "", selector, params)), params
