#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
import itertools
from typing import Optional

from .select_atoms import EColumn, SqlContext, Node, EOrder, EFunction


@dataclasses.dataclass
class LowerSelector:
    """Describe one level of a nested JSON select over a single table.

    A selector names the table to read, the values emitted for each row, how rows
    are filtered, ordered and paged, and child selectors (``step_to_sub_selector``)
    whose results are embedded as nested JSON arrays. ``generate_sql`` renders a
    selector tree into SQL.
    """

    # Table read by this selector; interpolated into the SQL text as-is.
    table_name: str
    # JSON key -> expression whose value is emitted for each row.
    result_name_to_ecolumn: dict[str, Node] = dataclasses.field(default_factory=dict)
    # Conditions rendered by ``get_where`` and combined with " AND ".
    filter_by: list[Node] = dataclasses.field(default_factory=list)
    # When True, a parent does not put this selector's array into its JSON objects.
    is_hidden: bool = False
    # Ordering terms rendered by ``get_order``.
    order: list[EOrder] = dataclasses.field(default_factory=list)
    # Optional LIMIT / OFFSET, rendered by ``get_limit_offset_sql``.
    limit: Optional[int] = None
    offset: Optional[int] = None

    # Link to the parent selector: this table's column -> parent table's column.
    # Filled in by ``attach_selector``.
    local_column_to_remote: dict[str, str] = dataclasses.field(default_factory=dict)
    # Child step name -> child selector.
    step_to_sub_selector: dict[str, 'LowerSelector'] = dataclasses.field(default_factory=dict)

    def attach_selector(self: 'LowerSelector', child_name: str, child_selector: 'LowerSelector',
                        child_column_name_to_parent_column_name: dict[str, str]):
        """Register ``child_selector`` under ``child_name`` and record how it links to ``self``.

        :param child_name: key under which the child is stored in ``step_to_sub_selector``.
        :param child_selector: selector to attach; must not already have link columns.
        :param child_column_name_to_parent_column_name: pairs of (child column, parent column)
            that must be equal; stored on the child as ``local_column_to_remote``.
        """
        # Check before mutating anything so a failed attach leaves ``self`` unchanged.
        assert not child_selector.local_column_to_remote, (
            f"selector for {child_name!r} already has link columns"
        )
        self.step_to_sub_selector[child_name] = child_selector
        child_selector.local_column_to_remote = child_column_name_to_parent_column_name
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    def get_subquery_conditions(self, my_name: str, parent_name: str) -> str:
        """Return the SQL condition linking this selector's rows to the parent row.

        Each link pair becomes ``my_name.local = parent_name.remote``; the pairs are
        combined with `` AND ``. Returns ``"TRUE"`` when there are no link columns.
        """
        # The separator needs surrounding spaces: with two or more link columns a bare
        # "AND" would glue tokens together and produce invalid SQL.
        return str.join(" AND ", [
            f'{my_name}.{local} = {parent_name}.{remote}'
            for local, remote in self.local_column_to_remote.items()
        ]) or "TRUE"

    def get_local_relation_columns(self, my_name: str) -> dict[str, str]:
        """Map each local link column name to its qualified reference ``my_name.column``."""
        return {local: f'{my_name}.{local}' for local in self.local_column_to_remote}

    def get_where(self, context: SqlContext) -> str:
        """Render ``filter_by`` joined with " AND ", without the WHERE keyword ("" if none)."""
        filters_sql = str.join(" AND ", [condition.to_sql(context) for condition in self.filter_by])
        return filters_sql

    def get_order(self, context: SqlContext) -> str:
        """Render an ``ORDER BY`` clause from ``order``, or "" when there is no ordering."""
        order_sql = str.join(", ", [order.to_order_sql(context) for order in self.order])
        if order_sql:
            order_sql = "ORDER BY " + order_sql
        return order_sql

    def get_limit_offset_sql(self) -> str:
        """Render ``LIMIT``/``OFFSET`` on a new line, or "" when neither is set."""
        parts: list[str] = []
        if self.limit is not None:
            parts.append(f'LIMIT {self.limit}')
        if self.offset is not None:
            parts.append(f'OFFSET {self.offset}')
        result = '\n' + str.join(' ', parts) if parts else ''
        return result
#
# @dataclasses.dataclass
# class RecursiveSelector(LowerSelector):
#     table_name: str = dataclasses.field(init=False)
#     entity: str
#     def __post_init__(self):
#
#     def get_sub_selector_for(self, step: ProxyStep) -> 'RecursiveSelector':
#         attribute_name = step.name
#         attribute = getattr(self.entity, attribute_name)
#         attribute_name = step.name if step.alias is None else step.alias
#         if attribute_name not in self.step_to_sub_selector:
#             print(f"sub selector {attribute_name} does not exist, create")
#             if isinstance(attribute.property, sqlalchemy.orm.relationships.RelationshipProperty):
#                 next_entity_proxy = attribute.descend()
#             elif isinstance(attribute.property, sqlalchemy.orm.properties.ColumnProperty):
#                 assert step.alias is not None
#                 next_entity_proxy = attribute.ascend(step.alias)
#             else:
#                 raise TypeError(f"attempt to get SubSelector for unexpected type: {type(attribute)}")
#             self.step_to_sub_selector[attribute_name] = RecursiveSelector(
#                 next_entity_proxy.get_entity(), local_column_to_remote=step.local_column_to_remote
#             )
#         else:
#             print(f"sub selector {attribute_name} already exists")
#         return self.step_to_sub_selector[attribute_name]


def _secure_alias(alias: str, *, alias_max_len: int = 63, separator: str = "HIDDEN"):
    if len(alias) < alias_max_len:
        return alias

    part_size = int((alias_max_len - len(separator)) / 2)
    return f'{alias[:part_size]}{separator}{alias[-part_size:]}'


# Column holding the aggregated JSONB array returned by every generated query level.
payload_name: str = "payload"
# Column holding one row's JSONB object before it is aggregated.
inner_payload_name: str = "inner_payload"


def _sql_string_literal(text: str) -> str:
    """Quote ``text`` as an SQL string constant, doubling embedded single quotes.

    Assumes ``standard_conforming_strings`` is on (the PostgreSQL default), so
    backslashes need no escaping.
    """
    return "'" + text.replace("'", "''") + "'"


def _jsonb_object_sql(name_to_sql: dict[str, str], max_pairs: int = 50) -> str:
    """Render ``name_to_sql`` (quoted key -> SQL value) as one JSONB object expression.

    A PostgreSQL function accepts at most ``max_function_args`` arguments (100 by
    default), so ``jsonb_build_object`` is called on chunks of at most ``max_pairs``
    key/value pairs and the partial objects are merged with ``||``. An empty mapping
    yields ``jsonb_build_object()``.
    """
    pairs = list(name_to_sql.items())
    chunks = [pairs[start:start + max_pairs] for start in range(0, len(pairs), max_pairs)]
    objects = [
        "jsonb_build_object(" + ", ".join(f"{name}, {value}" for name, value in chunk) + ")"
        for chunk in chunks
    ]
    return " || ".join(objects) or "jsonb_build_object()"


def generate_sql(parent_alias: str, outer_alias: str, selector: LowerSelector, key_to_param: Optional[dict] = None):
    """Build a query returning the rows described by ``selector`` as one JSONB array.

    The generated SQL has three nested levels:

    * the innermost SELECT reads ``selector.table_name`` and projects every column
      needed by filters, ordering and results. For each child selector it also
      ``LEFT JOIN LATERAL``-s a recursively generated subquery;
    * the middle SELECT applies the link condition to the parent, ``filter_by``,
      ``ORDER BY`` and ``LIMIT``/``OFFSET``, and packs each row into a JSONB object;
    * the outer SELECT aggregates those objects with ``jsonb_agg`` into a single
      ``payload`` column.

    NOTE(review): table, column and step names are interpolated into the SQL
    unquoted; they must come from trusted metadata, never from user input.

    :param parent_alias: alias of the parent's table in the enclosing query, used in
        the link condition; ``""`` for a root selector.
    :param outer_alias: alias under which the caller embeds this subquery; child
        subquery aliases are derived from it.
    :param selector: what to select.
    :param key_to_param: forwarded to every ``SqlContext`` created here, including
        those of child selectors.
    :return: the SQL text of the subquery.
    """
    table_name = selector.table_name
    # Alias of the middle-level subquery; filters, ordering and result values are
    # rendered against it.
    inner_alias_1 = _secure_alias(f'inner_{outer_alias}')
    # Alias of the base table inside the innermost SELECT.
    inner_alias_2 = _secure_alias(f'{table_name}_alias')

    inner_context = SqlContext(inner_alias_2, key_to_param)
    outer_context = SqlContext(inner_alias_1, key_to_param)

    # Column name -> SQL expression projected by the innermost SELECT.
    select__name_to_sql: dict[str, str] = {}
    # Quoted JSON key -> SQL value placed into each row's JSONB object.
    aggregate_name_to_sql: dict[str, str] = {}
    joins: list[str] = []

    def _add_required(columns) -> None:
        # Project each required column from the base table under its own name.
        select__name_to_sql.update({column.get_name(): column.to_sql(inner_context) for column in columns})

    for filter_ in selector.filter_by:
        _add_required(filter_.requires())
    for order in selector.order:
        _add_required(order.requires())

    for col_alias, col in selector.result_name_to_ecolumn.items():
        # EColumn is checked first so it takes precedence, as in the original.
        if isinstance(col, EColumn):
            _add_required([col])
        elif isinstance(col, EFunction):
            _add_required(col.requires())
        aggregate_name_to_sql[_sql_string_literal(col_alias)] = col.to_sql(outer_context)

    for step, child_selector in selector.step_to_sub_selector.items():
        child_alias = _secure_alias(f'{outer_alias}_{step}')
        child_column = _secure_alias(step)
        if not child_selector.is_hidden:
            aggregate_name_to_sql[_sql_string_literal(child_column)] = child_column
        # Alias the expression here, by the (possibly shortened) column name the
        # middle SELECT refers to. This must not be keyed by the raw step name.
        select__name_to_sql[child_column] = (
            f"coalesce({child_alias}.{payload_name}, '[]'::jsonb) as {child_column}"
        )
        # Propagate key_to_param so child selectors render with the same parameters.
        child_sql = generate_sql(inner_alias_2, child_alias, child_selector, key_to_param)
        joins.append(f"\nLEFT JOIN LATERAL ({child_sql}) as {child_alias} ON TRUE")

    select__name_to_sql.update(selector.get_local_relation_columns(inner_alias_2))
    joins_str = "".join(joins)
    where_str = selector.get_where(outer_context)
    order_str = selector.get_order(outer_context)
    agg_str = _jsonb_object_sql(aggregate_name_to_sql)

    return f"""
    SELECT jsonb_agg({inner_payload_name}) as {payload_name}
    FROM (
    SELECT {agg_str} as {inner_payload_name}
        FROM (
            SELECT {", ".join(select__name_to_sql.values())}
            FROM {table_name} as {inner_alias_2} {joins_str}
        ) as {inner_alias_1}
        WHERE {selector.get_subquery_conditions(inner_alias_1, parent_alias)}
        {("AND " + where_str) if where_str else ""}
        {order_str} {selector.get_limit_offset_sql()}
    ) as jsoned
    """


# <- generate_sql
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# -> Utils

def wrap_sql_as_list(sql: str) -> str:
    """Wrap a ``generate_sql`` query so it yields its JSONB array, ``'[]'`` instead of NULL."""
    prefix = "SELECT coalesce(" + payload_name + ", '[]'::jsonb) FROM ("
    return prefix + sql + ") as _main_"


def wrap_sql_as_count(sql: str) -> str:
    """Wrap a ``generate_sql`` query so it yields the number of rows (0 instead of NULL).

    The payload column is referenced through ``payload_name``, like ``wrap_sql_as_list``,
    so the two wrappers cannot drift apart if the column is renamed.
    """
    return f"SELECT COALESCE(jsonb_array_length({payload_name}), 0) FROM ({sql}) as _main_"


def generate_sql_from_recursive_selector(sub_selector: LowerSelector) -> str:
    """Render the complete query for a root selector as a JSONB-array-returning SELECT."""
    # A root selector has no parent table and no enclosing alias.
    return wrap_sql_as_list(generate_sql("", "", sub_selector))


def align_sql(sql: str, *, openers: str = '([{', closers: str = ')]}') -> str:
    """Re-indent ``sql`` by bracket nesting depth for readability.

    Each line is stripped, then prefixed with two spaces per bracket still open
    after the line's own brackets are counted. A line that opens a bracket is
    therefore already indented one level deeper.

    ``<``/``>`` are not counted by default. In SQL they are comparison operators
    and part of jsonb operators such as ``->`` and ``@>``, not brackets. Brackets
    inside string literals are counted like any others.

    :param openers: characters that increase the depth.
    :param closers: characters that decrease the depth; a character present in both
        sets counts as an opener.
    """
    result_lines = []
    depth = 0
    for line in sql.splitlines():
        line = line.strip()
        for char in line:
            if char in openers:
                depth += 1
            elif char in closers:
                depth -= 1
        result_lines.append("  " * depth + line)
    return str.join("\n", result_lines)
