import logging
from dataclasses import dataclass, field
from typing import Type, Self

from dict_caster.extras import to_list

from ..elements.abstract.sqlable import Sqlable
from ..elements.order import Order
from ..elements.filter_by import Condition
from ..elements.output_format import OutputFormat
from ..elements.result import SelectResult
from ..elements.table import Table
from .abstract_statement import AbstractStatement
from ..elements.abstract.selectable import Selectable

# Module-scoped logger, named after this module so output can be filtered per module.
logger = logging.getLogger(name=__name__)


@dataclass
class Select(AbstractStatement, Sqlable):
    """A SELECT statement assembled from clause objects and rendered to SQL text.

    Fields:
        select_from: Row source -- either a ``Table`` subclass or another
            ``Select``, which is rendered as a parenthesised subquery.
        selectables: Items rendered with ``to_selector()`` into the select list.
        condition: Optional filter rendered with ``to_sql()`` into ``WHERE``.
        group_by: Items rendered with ``to_sql()`` into ``GROUP BY``.
        order_by: ``Order`` items rendered with ``to_sql()`` into ``ORDER BY``.
        offset: Optional ``OFFSET`` value; ``None`` or ``0`` omits the clause.
        limit: Optional ``LIMIT`` value; only ``None`` omits the clause, so
            ``limit=0`` is emitted as ``LIMIT 0``.
        output_format: Forwarded to ``SelectResult`` by ``form_result``.
        distinct_on: Items rendered into ``DISTINCT ON(...)``; when non-empty,
            ``order_by`` must be non-empty too.

    The list-like fields (``selectables``, ``distinct_on``, ``group_by``,
    ``order_by``) are passed through ``to_list`` in ``__post_init__``, with
    ``None`` becoming ``[]``.
    """

    select_from: Type[Table] | Self
    selectables: list[Selectable] | Selectable | None = field(default_factory=list)
    condition: Condition.get_condition_type() | None = None
    group_by: list[Sqlable] | Sqlable | None = field(default_factory=list)
    order_by: list[Order] | Order | None = field(default_factory=list)
    offset: int | None = None
    limit: int | None = None
    output_format: OutputFormat = OutputFormat.Dict
    distinct_on: list[Selectable] | Selectable | None = field(default_factory=list)

    def __post_init__(self):
        """Run the table engine's preprocessing hook, then normalise list fields."""
        # NOTE(review): the engine hook runs BEFORE the list-like fields are
        # normalised, so it may observe a bare item or ``None`` here. Confirm
        # whether preprocess_select_request relies on that before reordering.
        self.table.get_table_engine().preprocess_select_request(self)

        self.selectables = self._as_list(self.selectables)
        self.distinct_on = self._as_list(self.distinct_on)
        self.group_by = self._as_list(self.group_by)
        self.order_by = self._as_list(self.order_by)

    @staticmethod
    def _as_list(value) -> list:
        """Return ``to_list(value)``, or ``[]`` when *value* is ``None``."""
        return to_list(value) if value is not None else []

    def to_sql(self) -> str:
        """Render this statement as SQL text (alias of ``generate_sql``)."""
        return self.generate_sql()

    @property
    def table(self):
        """The ``Table`` class this query ultimately reads from.

        Nested ``Select`` sources are unwrapped recursively.

        Raises:
            ValueError: if ``select_from`` is neither a ``Select`` nor a
                ``Table`` subclass.
        """
        if isinstance(self.select_from, Select):
            return self.select_from.table
        elif isinstance(self.select_from, type) and issubclass(self.select_from, Table):
            return self.select_from
        else:
            raise ValueError(f'unexpected {self.select_from=}')

    def generate_sql(self):
        """Assemble the full SELECT statement from its clause fragments.

        Clauses appear in the order SELECT [DISTINCT ON(...)] <columns>
        FROM <source> [WHERE] [GROUP BY] [ORDER BY] [LIMIT] [OFFSET]. Empty
        fragments render as empty strings, so the result may contain runs of
        spaces.

        Raises:
            NotImplementedError: if ``distinct_on`` is set without ``order_by``.
        """
        select_sql = self._construct_select_sql()
        from_sql = self._construct_from_sql()
        distinct_on_sql = self._construct_distinct_on_sql()
        where_sql = self._construct_where_sql()
        group_sql = self._construct_group_sql()
        order_sql = self._construct_order_sql()
        limit_sql = self._construct_limit_sql()
        offset_sql = self._construct_offset_sql()
        sql = f"SELECT {distinct_on_sql} {select_sql} FROM {from_sql} {where_sql} {group_sql} {order_sql} {limit_sql} {offset_sql}"
        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        # %r reproduces the repr that an f"{sql=}" would have produced.
        logger.debug("constructed sql=%r", sql)
        return sql

    def _construct_from_sql(self):
        """Render the FROM source, wrapping a nested ``Select`` in parentheses."""
        sql = self.select_from.to_sql()
        if isinstance(self.select_from, Select):
            # NOTE(review): no alias is emitted for the subquery; some SQL
            # dialects require one for derived tables -- confirm the target
            # engine accepts this.
            sql = f"({sql})"
        return sql

    def _construct_select_sql(self) -> str:
        """Render the comma-separated select list from ``selectables``."""
        return ', '.join(selectable.to_selector() for selectable in self.selectables)

    def _construct_distinct_on_sql(self) -> str:
        """Render ``DISTINCT ON(...)`` (with a trailing space), or ``''`` if unused.

        Raises:
            NotImplementedError: if ``distinct_on`` is non-empty but ``order_by``
                is empty.
        """
        distinct_on_sql = ""
        if self.distinct_on:
            if not self.order_by:
                raise NotImplementedError("order by required when selecting distinct_on")
            distinct_on_sql = \
                f"DISTINCT ON({', '.join(selectable.to_selector() for selectable in self.distinct_on)}) "
        return distinct_on_sql

    def _construct_where_sql(self) -> str:
        """Render the WHERE clause from ``condition``, or ``''`` when falsy."""
        return f'\nWHERE {self.condition.to_sql()}' if self.condition else ''

    def _construct_group_sql(self) -> str:
        """Render the GROUP BY clause, or ``''`` when ``group_by`` is empty."""
        group_by_part = ', '.join(g.to_sql() for g in self.group_by)
        return f'\nGROUP BY {group_by_part}' if group_by_part else ''

    def _construct_order_sql(self) -> str:
        """Render the ORDER BY clause, or ``''`` when ``order_by`` is empty."""
        order_by_part = ', '.join(o.to_sql() for o in self.order_by)
        return f'\nORDER BY {order_by_part}' if order_by_part else ''

    def _construct_limit_sql(self) -> str:
        """Render the LIMIT clause, or ``''`` when ``limit`` is ``None``."""
        # Compare against None rather than truthiness: LIMIT 0 is a valid
        # request for zero rows and must not be silently dropped (which
        # would return every row instead).
        return f'\nLIMIT {self.limit}' if self.limit is not None else ''

    def _construct_offset_sql(self) -> str:
        """Render the OFFSET clause, or ``''`` when ``offset`` is ``None``/``0``."""
        # Truthiness is deliberate here: OFFSET 0 skips nothing, so omitting
        # the clause is equivalent.
        return f'\nOFFSET {self.offset}' if self.offset else ''

    def form_result(self, payload: str) -> SelectResult:
        """Wrap *payload* in a ``SelectResult`` using this query's selectables
        and ``output_format``.

        *payload* is presumably the raw response text returned by the engine
        for this query -- confirm against callers.
        """
        return SelectResult(payload, self.selectables, self.output_format)
