from dataclasses import dataclass, field
from typing import Type, Optional, List

from ..elements.filter_by import Filter
from ..elements.result import CountResult
from ..elements.table import Table
from ..statements.abstract_statement import AbstractStatement


@dataclass
class Count(AbstractStatement):
    """Statement producing a ``SELECT COUNT(*)`` query over ``table``.

    The optional ``WHERE`` clause is built from two groups of conditions:

    * ``filter_by`` -- rendered with ``str()`` and joined with ``AND``;
    * ``search_by`` -- rendered with ``str()`` and joined with ``OR``.

    A non-empty group is wrapped in parentheses. When both groups are
    non-empty they are combined with ``AND``. A group set to ``None`` is
    treated like an empty list.
    """

    # Table class to count rows in; its name is obtained via ``get_name()``.
    table: Type[Table]

    # Conditions that are AND-ed together in the WHERE clause.
    filter_by: Optional[List[Filter]] = field(default_factory=list)
    # Conditions that are OR-ed together in the WHERE clause.
    search_by: Optional[List[Filter]] = field(default_factory=list)

    def generate_sql(self) -> str:
        """Render the counting query for this statement.

        Returns:
            The SQL text. The ``WHERE`` clause, when present, begins on a new
            line.
        """
        table_name = self.table.get_name()

        # Render each group in order (filter_by first, then search_by) and
        # keep only the groups that produced non-empty text.
        groups = []
        for conditions, joiner in ((self.filter_by, ' AND '), (self.search_by, ' OR ')):
            source = conditions if conditions is not None else ()
            combined = joiner.join(str(condition) for condition in source)
            if combined:
                groups.append(f'({combined})')

        predicate = ' AND '.join(groups)
        where_clause = f'\nWHERE {predicate}' if predicate else ''

        # NOTE: the space before ``where_clause`` is intentionally kept. When
        # there is no WHERE clause, the output therefore ends with a trailing
        # space. ``self.database`` is presumably supplied by AbstractStatement.
        return f"SELECT COUNT(*) FROM {self.database}.{table_name} {where_clause}"

    def form_result(self, payload: str) -> CountResult:
        """Wrap ``payload`` in a ``CountResult``."""
        result = CountResult(payload)
        return result
