#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import enum

import sqlalchemy as sa
from alembic import op

from . import get_trigger_name, get_function_name
from .extract_columns import extract_columns, get_column_to_dependency_columns
from .utils import apply_prefix


# Newline kept in a named constant so it can be used inside f-string
# replacement fields (e.g. ``{new_line.join(parts)}``), where a literal
# backslash is a SyntaxError on Python versions before 3.12.
new_line: str = "\n"


class When(enum.StrEnum):
    BEFORE = 'BEFORE'
    AFTER = 'AFTER'
    INSTEAD_OF = 'INSTEAD OF'


class Operation(enum.StrEnum):
    INSERT = 'INSERT'
    UPDATE = 'UPDATE'
    DELETE = 'DELETE'
    TRUNCATE = 'TRUNCATE'


def create_triggers(
        table_name: str,
        old_column_to_expression: dict[sa.Column, sa.ColumnElement | None],
        new_column_to_expression: dict[sa.Column, sa.ColumnElement | None],
        revision: str
) -> None:
    """Create the BEFORE INSERT and BEFORE UPDATE triggers for ``table_name``.

    Both mappings are forwarded unchanged to ``create_insert_trigger`` and
    ``create_update_trigger``. Their annotations admit ``None`` values to
    match those functions: a column mapped to ``None`` has no expression and
    is never assigned by the generated SQL.

    Args:
        table_name: table the triggers are attached to.
        old_column_to_expression: column -> expression mapping for the
            previous state.
        new_column_to_expression: column -> expression mapping for the new
            state. Its entries override same-key entries of the old mapping.
        revision: revision identifier, passed to ``get_trigger_name`` /
            ``get_function_name`` to build the object names.
    """
    create_insert_trigger(table_name, old_column_to_expression, new_column_to_expression, revision)
    create_update_trigger(table_name, old_column_to_expression, new_column_to_expression, revision)


def drop_triggers(table_name: str, revision: str) -> None:
    """Drop the BEFORE INSERT and BEFORE UPDATE triggers of ``table_name`` along with their functions.

    This is the counterpart of ``create_triggers`` for the same ``revision``.
    """
    for event in (Operation.INSERT, Operation.UPDATE):
        drop_trigger_and_function(table_name, When.BEFORE, event, revision)


def drop_trigger_and_function(table_name: str, when: When, operation: Operation, revision: str) -> None:
    """Drop one trigger on ``table_name`` and then its trigger function, tolerating their absence."""
    trigger_name = get_trigger_name(table_name, when, operation, revision)
    op.execute(f"DROP TRIGGER IF EXISTS {trigger_name} ON {table_name};")
    # The trigger goes first: it executes the function, so the function
    # cannot be dropped (without CASCADE) while the trigger still exists.
    function_name = get_function_name(table_name, when, operation, revision)
    op.execute(f"DROP FUNCTION IF EXISTS {function_name};")


def _render_psql(column: sa.Column, expression: sa.ColumnElement) -> str:
    """Render ``expression`` as a PL/pgSQL assignment ``NEW.<column> := <sql>;``.

    Bound parameters are inlined as literals (``literal_binds``) so the SQL
    can be pasted verbatim into the trigger body. Every column that
    ``extract_columns`` reports is then passed to ``apply_prefix`` with the
    ``"NEW."`` prefix so the expression reads from the row being written.
    """
    referenced_columns = extract_columns(expression)
    # NOTE(review): no dialect is passed, so SQLAlchemy's default compiler is
    # used rather than PostgreSQL's; confirm the output is valid PL/pgSQL for
    # the expressions used in practice.
    sql = str(expression.compile(compile_kwargs={"literal_binds": True}))
    for referenced in referenced_columns:
        sql = apply_prefix(sql, "NEW.", referenced.name)
    return f"NEW.{column.name} := {sql};"


def _make_insert_func_part(
        column_name: str,
        dependent_columns: set[sa.Column],
        col_to_expr: dict[sa.Column, sa.ColumnElement]
) -> str:
    """Render the INSERT-trigger fragment that fills columns derived from ``column_name``.

    Every column in ``dependent_columns`` with a non-``None`` expression in
    ``col_to_expr`` is assigned via ``_render_psql``. The assignments are
    guarded by ``NEW.<column_name> IS NOT NULL``.

    Assignments are emitted in ``col_to_expr`` iteration order, not in
    ``dependent_columns`` (set) order. ``sa.Column`` uses the default
    identity-based hash, so set order can differ from run to run. That made
    the generated SQL non-reproducible and left the relative order of
    chained derived columns arbitrary. If one derived column feeds another,
    the feeding column must come first in ``col_to_expr``.

    Args:
        column_name: name of the source column tested in the ``IF`` guard.
        dependent_columns: columns whose expressions read ``column_name``.
        col_to_expr: column -> expression mapping. ``None`` means the column
            has no expression and is skipped.

    Returns:
        The ``IF ... END IF;`` fragment, or ``''`` when nothing needs
        assigning. The ``IF`` line carries no indentation because the caller
        prefixes it. The assignments are indented 12 columns and ``END IF;``
        8 columns, matching the UPDATE fragment layout.
    """
    parts = [
        _render_psql(column, expr)
        for column, expr in col_to_expr.items()
        if expr is not None and column in dependent_columns
    ]
    if not parts:
        return ''
    # Indent every assignment, not only the first, so multi-column fragments
    # stay aligned in the generated function body.
    assignments = new_line.join(' ' * 12 + part for part in parts)
    return f"""IF NEW.{column_name} IS NOT NULL THEN
{assignments}
        END IF;"""


def _make_update_func_part(
        column_name: str,
        involved_column: set[sa.Column],
        col_to_expr: dict[sa.Column, sa.ColumnElement]
) -> str:
    """Render the UPDATE-trigger fragment that recomputes columns derived from ``column_name``.

    Every column in ``involved_column`` with a non-``None`` expression in
    ``col_to_expr`` is re-assigned via ``_render_psql``. This only happens
    when ``NEW.<column_name> IS DISTINCT FROM OLD.<column_name>``, i.e. the
    source value actually changed (NULL-safe comparison).

    Assignments are emitted in ``col_to_expr`` iteration order, not in
    ``involved_column`` (set) order. ``sa.Column`` uses the default
    identity-based hash, so set order can differ from run to run. That made
    the generated SQL non-reproducible and left the relative order of
    chained derived columns arbitrary. If one derived column feeds another,
    the feeding column must come first in ``col_to_expr``.

    Args:
        column_name: name of the source column compared between NEW and OLD.
        involved_column: columns whose expressions read ``column_name``.
        col_to_expr: column -> expression mapping. ``None`` means the column
            has no expression and is skipped.

    Returns:
        The ``IF ... END IF;`` fragment, or ``''`` when nothing needs
        assigning. The ``IF`` line carries no indentation because the caller
        prefixes it.
    """
    parts = [
        _render_psql(column, expr)
        for column, expr in col_to_expr.items()
        if expr is not None and column in involved_column
    ]
    if not parts:
        return ''
    assignments = new_line.join(' ' * 12 + part for part in parts)
    return f"""IF NEW.{column_name} IS DISTINCT FROM OLD.{column_name} THEN
{assignments}
        END IF;"""


def _create_trigger(
        trigger_name: str, table_name: str, function_name: str, when: When, operation: Operation
) -> None:
    """Attach row-level trigger ``trigger_name`` to ``table_name``.

    The trigger fires ``function_name()`` at ``when`` (e.g. BEFORE) for
    ``operation`` (e.g. INSERT). Both enums render as their SQL keyword values.
    """
    statement = (
        "\n"
        f"CREATE TRIGGER {trigger_name} {when} {operation} ON {table_name}\n"
        f"FOR EACH ROW EXECUTE FUNCTION {function_name}()\n"
    )
    op.execute(statement)


def _create_function(function_name: str, function_body: str) -> None:
    """Create (or replace) PL/pgSQL trigger function ``function_name``.

    ``function_body`` is inserted verbatim between ``BEGIN`` and the final
    ``RETURN NEW;``. The body is dollar-quoted with ``$$``.
    """
    sql_lines = [
        "",
        f"    CREATE OR REPLACE FUNCTION {function_name}() RETURNS TRIGGER AS $$",
        "    BEGIN",
        function_body,
        "        RETURN NEW;",
        "    END;",
        "    $$ LANGUAGE plpgsql",
        "    ",
    ]
    op.execute("\n".join(sql_lines))


def create_insert_trigger(
        table_name: str,
        old_column_to_expression: dict[sa.Column, sa.ColumnElement | None],
        new_column_to_expression: dict[sa.Column, sa.ColumnElement | None],
        revision: str
) -> None:
    """Create the BEFORE INSERT trigger and function that fill derived columns.

    The function body holds one ``_make_insert_func_part`` fragment per
    column of the merged mapping that other columns depend on. Each fragment
    assigns those dependents when the source column is NOT NULL.

    Args:
        table_name: table the trigger is attached to.
        old_column_to_expression: column -> expression mapping for the
            previous state.
        new_column_to_expression: column -> expression mapping for the new
            state. Its entries override same-key entries of the old mapping.
        revision: revision identifier used to build the trigger/function names.
    """
    when = When.BEFORE
    operation = Operation.INSERT

    merged = old_column_to_expression | new_column_to_expression
    dependencies_of = (get_column_to_dependency_columns(old_column_to_expression)
                       | get_column_to_dependency_columns(new_column_to_expression))

    # Invert "derived column -> columns it reads" into "source column -> derived columns".
    dependents_of: dict[sa.Column, set[sa.Column]] = {}
    for derived, sources in dependencies_of.items():
        for source in sources:
            if source not in dependents_of:
                dependents_of[source] = set()
            dependents_of[source].add(derived)

    fragments = []
    # NOTE(review): only keys of ``merged`` are visited, so a dependency column
    # absent from both mappings gets no IF block. Presumably callers include
    # every column (mapping plain ones to None); confirm.
    for source in merged:
        dependents = dependents_of.get(source)
        if dependents is None:
            continue
        fragments.append(" " * 8 + _make_insert_func_part(source.name, dependents, merged))
    trigger_body = new_line.join(fragments)

    function_name = get_function_name(table_name, when, operation, revision)
    _create_function(function_name, trigger_body)
    trigger_name = get_trigger_name(table_name, when, operation, revision)
    _create_trigger(trigger_name, table_name, function_name, when, operation)


def create_update_trigger(
        table_name: str,
        old_column_to_expression: dict[sa.Column, sa.ColumnElement | None],
        new_column_to_expression: dict[sa.Column, sa.ColumnElement | None],
        revision: str
) -> None:
    """Create the BEFORE UPDATE trigger and function that keep derived columns in sync.

    The function body holds one ``_make_update_func_part`` fragment per
    column of the merged mapping that other columns depend on. Each fragment
    recomputes those dependents when ``NEW.<col> IS DISTINCT FROM OLD.<col>``.

    Args:
        table_name: table the trigger is attached to.
        old_column_to_expression: column -> expression mapping for the
            previous state.
        new_column_to_expression: column -> expression mapping for the new
            state. Its entries override same-key entries of the old mapping.
        revision: revision identifier used to build the trigger/function names.
    """
    when = When.BEFORE
    operation = Operation.UPDATE

    merged = old_column_to_expression | new_column_to_expression
    dependencies_of = (get_column_to_dependency_columns(old_column_to_expression)
                       | get_column_to_dependency_columns(new_column_to_expression))

    # Invert "derived column -> columns it reads" into "source column -> derived columns".
    dependents_of: dict[sa.Column, set[sa.Column]] = {}
    for derived, sources in dependencies_of.items():
        for source in sources:
            if source not in dependents_of:
                dependents_of[source] = set()
            dependents_of[source].add(derived)

    fragments = []
    # NOTE(review): only keys of ``merged`` are visited, so a dependency column
    # absent from both mappings gets no IF block. Presumably callers include
    # every column (mapping plain ones to None); confirm.
    for source in merged:
        dependents = dependents_of.get(source)
        if dependents is None:
            continue
        fragments.append(" " * 8 + _make_update_func_part(source.name, dependents, merged))
    trigger_body = new_line.join(fragments)

    function_name = get_function_name(table_name, when, operation, revision)
    _create_function(function_name, trigger_body)
    trigger_name = get_trigger_name(table_name, when, operation, revision)
    _create_trigger(trigger_name, table_name, function_name, when, operation)
