#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
import itertools
from typing import Optional

from .atoms.context import SqlContext
from .atoms.node import SQL
from .atoms.order import EOrder
from .atoms.selectable.abstract import Selectable, BooleanSelectable


@dataclasses.dataclass
class LowerSelector:
    """One node of a selector tree that ``_generate_sql`` renders into nested jsonb SQL.

    Each selector reads from ``table_name``, turns every matching row into a jsonb
    object built from ``result_name_to_selectable``, and aggregates those objects into
    a list. Child selectors attached via :meth:`attach_selector` are rendered as
    LATERAL sub-queries joined on ``local_column_to_remote``.
    """
    # Table the rows are read from (also used to derive the SQL alias).
    table_name: str
    # JSON key -> expression placed into each row's jsonb object.
    result_name_to_selectable: dict[str, Selectable] = dataclasses.field(default_factory=dict)
    # Conditions AND-ed together into the WHERE clause.
    filter_by: list[BooleanSelectable] = dataclasses.field(default_factory=list)
    # When True, the parent does not expose this selector's list as a key of its objects.
    is_hidden: bool = False
    # ORDER BY items, in the given order.
    order: list[EOrder] = dataclasses.field(default_factory=list)
    # Optional LIMIT / OFFSET values rendered verbatim by get_limit_offset_sql().
    limit: Optional[int] = None
    offset: Optional[int] = None

    # Join columns towards the parent: column of this selector's table -> parent column.
    local_column_to_remote: dict[str, str] = dataclasses.field(default_factory=dict)
    # Child selectors keyed by step name; each one becomes a LATERAL join.
    step_to_sub_selector: dict[str, 'LowerSelector'] = dataclasses.field(default_factory=dict)

    def attach_selector(self: 'LowerSelector', child_name: str, child_selector: 'LowerSelector',
                        child_column_name_to_parent_column_name: dict[str, str]):
        """Attach *child_selector* under the step *child_name*.

        :param child_name: step name; must not already be used on this selector.
        :param child_selector: selector to attach; must not already have join columns
            (as it would after a previous ``attach_selector`` call).
        :param child_column_name_to_parent_column_name: child column -> column of this
            selector's table. The mapping is copied, so later changes by the caller
            do not leak into the selector.
        :raises ValueError: if *child_name* is already used or *child_selector* already
            has join columns. Nothing is modified in either case.
        """
        # Validate everything before mutating anything, so a failed attach leaves
        # both selectors untouched. An explicit raise (not `assert`) also survives -O.
        if child_name in self.step_to_sub_selector:
            raise ValueError(f"attempt to attach selector with name: {child_name!r}, but it is already used")
        if child_selector.local_column_to_remote:
            raise ValueError(
                f"attempt to attach selector with name: {child_name!r}, but it already has join columns")
        self.step_to_sub_selector[child_name] = child_selector
        child_selector.local_column_to_remote = dict(child_column_name_to_parent_column_name)
    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    def get_subquery_join_conditions(self, my_name: str, parent_name: str) -> list[SQL]:
        """Return one ``my_name.local = parent_name.remote`` equality per join column."""
        return [f'{my_name}.{local} = {parent_name}.{remote}' for local, remote in self.local_column_to_remote.items()]

    def get_local_relation_columns(self, my_name: str) -> dict[str, str]:
        """Map each local join column name to its ``my_name``-qualified SQL reference."""
        return {local: f'{my_name}.{local}' for local in self.local_column_to_remote}

    def get_conditions(self, context: SqlContext) -> list[SQL]:
        """Render every ``filter_by`` condition to SQL within *context*."""
        return [condition.to_sql(context) for condition in self.filter_by]

    def get_order(self, context: SqlContext) -> str:
        """Return an ``ORDER BY ...`` clause, or ``''`` when no ordering is set."""
        order_sql = str.join(", ", [order.to_order_sql(context) for order in self.order])
        return "ORDER BY " + order_sql if order_sql else ''

    def get_limit_offset_sql(self) -> str:
        """Return ``'\\nLIMIT n OFFSET m'`` (either part optional), or ``''`` if neither is set."""
        parts: list[str] = []
        if self.limit is not None:
            parts.append(f'LIMIT {self.limit}')
        if self.offset is not None:
            parts.append(f'OFFSET {self.offset}')
        result = '\n' + str.join(' ', parts) if parts else ''
        return result


def _ensure_label(alias: str, *, alias_max_len: int = 63, separator: str = "HIDDEN") -> str:
    if len(alias) < alias_max_len:
        return alias

    part_size = int((alias_max_len - len(separator)) / 2)
    return f'{alias[:part_size]}{separator}{alias[-part_size:]}'


# Column alias under which each generated (sub)query returns its aggregated jsonb array.
JSON_LIST__NAME = 'json_list'
# Column alias of the per-row jsonb object that is aggregated into JSON_LIST__NAME.
JSON_OBJECT__NAME = 'json_object'
# Maximum key/value *pairs* per jsonb_build_object() call. Each pair is two function
# arguments, so 50 pairs = 100 arguments, PostgreSQL's default max_function_args.
MAX_FUNC_ARGS = 50


def _generate_sql(
        parent_alias: str, outer_alias: str, selector: LowerSelector, key_to_param: Optional[dict] = None) -> SQL:
    """Build SQL that aggregates *selector*'s rows (and its children) into one jsonb array.

    The generated query has three nested levels:

    * innermost: every column required by the filters, ordering and result
      selectables, read from ``selector.table_name``, plus one ``LEFT JOIN LATERAL``
      per child selector. Each child's list is exposed as a column named after the
      (length-limited) step, defaulting to ``'[]'::jsonb`` when it is NULL;
    * middle: parent join conditions, ``filter_by``, ordering and limit/offset,
      producing one jsonb object per row;
    * outer: ``jsonb_agg`` of those objects into the ``JSON_LIST__NAME`` column.

    :param parent_alias: alias of the parent's table referenced by this selector's
        join conditions (``""`` at the root).
    :param outer_alias: alias path of this selector, used to derive aliases for it and
        its children (``""`` at the root).
    :param selector: the selector to render.
    :param key_to_param: passed to every ``SqlContext`` created here, including those
        of child selectors rendered recursively.
    :return: the SQL text, not yet wrapped (see ``wrap_sql_as_list``).
    """
    required__name_to_sql = {}
    select__name_to_sql = {}
    current_table_with_subqueries_alias = _ensure_label(f'inner_{outer_alias}')
    current_table_alias = _ensure_label(f'{selector.table_name}_alias')

    # Expressions in the innermost SELECT refer to the base table alias; everything
    # evaluated one level up (results, filters, ordering) refers to the innermost
    # subquery's alias instead.
    inner_context = SqlContext(current_table_alias, key_to_param)
    outer_context = SqlContext(current_table_with_subqueries_alias, key_to_param)

    for node in itertools.chain(selector.filter_by, selector.order, selector.result_name_to_selectable.values()):
        required__name_to_sql.update({column.get_name(): column.to_sql(inner_context) for column in node.requires()})

    for result_name, selectable in selector.result_name_to_selectable.items():
        select__name_to_sql[f"'{result_name}'"] = selectable.to_sql(outer_context)

    # Labels of the columns exposed by the LATERAL joins below. These columns are
    # computed expressions, so they need an explicit "as <label>" to be addressable
    # from the middle level. The label, not the raw step, is what must be matched,
    # because _ensure_label may shorten long step names.
    child_column_names: set[str] = set()
    joins = []
    for step, child_selector in selector.step_to_sub_selector.items():
        child_alias: str = _ensure_label(f'{outer_alias}_{step}')
        child_column_name: str = _ensure_label(step)
        child_column_names.add(child_column_name)
        if not child_selector.is_hidden:
            select__name_to_sql[f"'{child_column_name}'"] = child_column_name
        # Overrides any same-named column collected from requires() above.
        required__name_to_sql[child_column_name] = f"coalesce({child_alias}.{JSON_LIST__NAME}, '[]'::jsonb)"
        # Forward key_to_param so nested selectors share the caller's parameter context.
        child_sql = _generate_sql(current_table_alias, child_alias, child_selector, key_to_param)
        joins.append(f"\nLEFT JOIN LATERAL ({child_sql}) as {child_alias} ON TRUE")

    required__name_to_sql.update(selector.get_local_relation_columns(current_table_alias))
    required__name_to_sql = {
        k: (v + " as " + k if k in child_column_names else v)
        for k, v in required__name_to_sql.items()
    }
    where_conditions = " AND ".join(
        selector.get_subquery_join_conditions(current_table_with_subqueries_alias, parent_alias)
        + selector.get_conditions(outer_context))
    where_sql = "WHERE " + where_conditions if where_conditions else ""

    return f"""
    SELECT jsonb_agg({JSON_OBJECT__NAME}) as {JSON_LIST__NAME}
    FROM (
        SELECT {_prepare_json_object_sql(select__name_to_sql)} as {JSON_OBJECT__NAME}
        FROM (
            SELECT {", ".join(required__name_to_sql.values())}
            FROM {selector.table_name} as {current_table_alias} {"".join(joins)}
        ) as {current_table_with_subqueries_alias}
        {where_sql} {selector.get_order(outer_context)} {selector.get_limit_offset_sql()}
    ) as jsoned
    """


def _prepare_json_object_sql(select__name_to_sql: dict[str, str]) -> SQL:
    """Render *select__name_to_sql* as a single jsonb object expression.

    Keys and values are emitted verbatim as SQL, so keys must already be quoted
    string literals. At most ``MAX_FUNC_ARGS`` key/value pairs go into one
    ``jsonb_build_object`` call. Larger mappings are split across several calls
    whose results are merged with the jsonb ``||`` operator. An empty mapping
    yields ``jsonb_build_object()``.
    """
    pair_sqls = [f"{key_sql}, {value_sql}" for key_sql, value_sql in select__name_to_sql.items()]
    if not pair_sqls:
        return "jsonb_build_object()"

    calls = []
    for start in range(0, len(pair_sqls), MAX_FUNC_ARGS):
        chunk = pair_sqls[start:start + MAX_FUNC_ARGS]
        calls.append("jsonb_build_object(" + ", ".join(chunk) + ")")
    return " || ".join(calls)

# <- generate_sql
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# -> Utils


def wrap_sql_as_list(sql: str) -> str:
    """Select the aggregated list column of *sql*, substituting an empty jsonb array for NULL."""
    derived_table = f"({sql}) as _main_"
    return "SELECT coalesce(" + JSON_LIST__NAME + ", '[]'::jsonb) FROM " + derived_table


def wrap_sql_as_count(sql: str) -> str:
    """Select the number of elements in *sql*'s aggregated list column, 0 when it is NULL."""
    length_expr = f"jsonb_array_length({JSON_LIST__NAME})"
    return f"SELECT COALESCE({length_expr}, 0) FROM ({sql}) as _main_"


def generate_sql_from_recursive_selector(sub_selector: LowerSelector) -> str:
    """Render the selector tree rooted at *sub_selector* as one query.

    The query yields a single jsonb array column, with an empty array in place of NULL.
    """
    return wrap_sql_as_list(_generate_sql("", "", sub_selector))


def align_sql(sql: str) -> str:
    """Re-indent *sql* by bracket nesting depth, for readability only.

    Each line is stripped and prefixed with two spaces per bracket level still open
    after that line's own brackets are counted. ``( { [ <`` open a level and
    ``) } ] >`` close one. The original listed ``[`` as a closer instead of ``]``,
    so every ``'[]'`` in the SQL left the depth permanently increased.

    NOTE(review): brackets inside string literals and the comparison operators
    ``<`` / ``>`` are counted as well, so indentation can still drift on such SQL.
    """
    openers = '({[<'
    closers = ')}]>'
    result_lines = []
    depth = 0
    for line in sql.splitlines():
        line = line.strip()
        for char in line:
            if char in openers:
                depth += 1
            elif char in closers:
                depth -= 1
        result_lines.append("  " * depth + line)
    return str.join("\n", result_lines)
