#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
from dataclasses import dataclass, field, replace

from sqlalchemy.util import to_list

import entity_read.sql.inter.agg_func
from entity_read import sql
from entity_read.entity import Entity
from .expression import Expression

# Maps the aggregation name used in ``Aggregation.key`` to the lower-level
# ``entity_read.sql.inter.agg_func`` class that ``Aggregation.eval`` instantiates.
key_to_function_factory = dict(
    array=entity_read.sql.inter.agg_func.AggregationArray,
    unique=entity_read.sql.inter.agg_func.AggregationUnique,
    count=entity_read.sql.inter.agg_func.AggregationCount,
    sum=entity_read.sql.inter.agg_func.AggregationSum,
    max=entity_read.sql.inter.agg_func.AggregationMax,
    min=entity_read.sql.inter.agg_func.AggregationMin,
)


@dataclass(frozen=True, kw_only=True, repr=False)
class Aggregation(Expression):
    """Expression node for an aggregate function call with optional filter conditions.

    ``key`` selects the aggregate implementation from ``key_to_function_factory``;
    ``eval`` lowers this node into an instance of that implementation.
    """

    # Aggregate name; must be a key of ``key_to_function_factory`` for ``eval`` to succeed.
    key: str
    # Argument expressions, each evaluated in ``eval``.
    args: list[Expression]
    # Filter condition expressions, evaluated in ``eval`` and passed on as ``filters``.
    filter: list[Expression] = field(default_factory=list)

    def _get_lower_type(self) -> type[sql.atoms.Aggregation]:
        """Return the lower-level aggregation class registered for ``self.key``.

        :raises KeyError: if ``self.key`` is not in ``key_to_function_factory``.
        """
        if (function_factory := key_to_function_factory.get(self.key)) is None:
            raise KeyError(f"Not found function {self.key!r}")
        return function_factory

    def eval(self, entity_type: type[Entity], variables: dict[str, sql.atoms.Selectable]) -> sql.atoms.Aggregation:
        """Lower this node into a ``sql`` aggregation.

        Every argument and filter expression is evaluated with the same
        ``entity_type`` and ``variables``.

        :param entity_type: entity class the expressions are evaluated against.
        :param variables: named selectables forwarded to each sub-expression's ``eval``.
        :return: an instance of the class registered for ``self.key``.
        :raises KeyError: if ``self.key`` is not a known aggregation.
        """
        lower_type = self._get_lower_type()
        # noinspection PyArgumentList
        return lower_type(
            args=[arg.eval(entity_type, variables) for arg in self.args],
            filters=[f.eval(entity_type, variables) for f in self.filter]
        )

    def shortcut(self) -> str:
        """Return a compact text form ``agg.<key>(<arg1>,<arg2>,...)``.

        Filter conditions are not included in the result.
        """
        return f"agg.{self.key}({','.join(arg.shortcut() for arg in self.args)})"

    def where(self, filter: Expression | list[Expression] | None) -> 'Aggregation':
        """Return a copy of this aggregation whose filter is set to ``filter``.

        Any previously set filter is replaced, not extended.

        :param filter: a single condition, a list of conditions, or ``None`` for no filter.
        :return: a new instance of the same class with every other field copied.
        """
        # ``to_list`` hands a list argument back unchanged. Copy it so later mutation of
        # the caller's list cannot alter this frozen instance. A ``None`` argument becomes
        # an empty filter rather than ``None``, which ``eval`` could not iterate.
        conditions = list(to_list(filter, []))
        # ``replace`` preserves the concrete subclass and all other fields, including
        # any inherited from ``Expression``.
        return replace(self, filter=conditions)
