#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
from logging import getLogger
from typing import ClassVar

import sqlalchemy
import sqlalchemy.sql
import sqlalchemy.sql.schema
from apispec import APISpec
from openapi_tools.extras import DataclassProtocol
from openapi_tools.open_api_wrapper import OpenApiWrapper

from entity_read.entity import Entity

logger = getLogger(__name__)


class EntityWrapper(OpenApiWrapper):
    """OpenAPI wrapper that annotates dataclass schemas with relational metadata.

    On top of the description produced by ``OpenApiWrapper``, every dataclass
    description receives two vendor-extension entries:

    * ``x-primaryKeys`` -- list of primary-key column names;
    * ``x-foreignKeys`` -- mapping of foreign-key constraint name to a
      description of that constraint.

    Both entries are always present; they are only populated when the schema
    is an ``Entity`` subclass, and stay empty otherwise.
    """

    # Vendor-extension key holding the list of primary-key column names.
    primary_key_names__key: ClassVar[str] = 'x-primaryKeys'
    # Vendor-extension key holding the foreign-key name -> description mapping.
    fk_name_to_fk__key: ClassVar[str] = 'x-foreignKeys'

    def _get_dataclass_description(self, schema: DataclassProtocol) -> dict:
        """Describe *schema* and attach primary/foreign-key metadata to it.

        :param schema: dataclass being described.
        :return: the base wrapper's description, extended with the
            ``primary_key_names__key`` and ``fk_name_to_fk__key`` entries.
        :raises ValueError: if a foreign-key constraint of an ``Entity`` table
            has no name (it could not be keyed in ``x-foreignKeys``).
        """
        result = super()._get_dataclass_description(schema)
        # The base wrapper must not emit our extension keys itself, otherwise
        # they would be silently overwritten below.
        assert self.primary_key_names__key not in result
        primary_key_names = result[self.primary_key_names__key] = []
        assert self.fk_name_to_fk__key not in result
        fk_name_to_fk = result[self.fk_name_to_fk__key] = {}

        if not (isinstance(schema, type) and issubclass(schema, Entity)):
            return result

        logger.debug("add %s constraints to spec", schema.__name__)
        table = schema.get_table()
        # ``Table.constraints`` is a set, so its iteration order varies between
        # runs; sorting by name keeps the generated spec reproducible.
        for constraint in sorted(table.constraints, key=lambda c: str(c.name)):
            constraint: sqlalchemy.sql.schema.ColumnCollectionConstraint
            logger.debug(
                "found constraint.name=%r type(constraint)=%r cols:%s",
                constraint.name, type(constraint), [c.key for c in constraint.columns],
            )
            if isinstance(constraint, sqlalchemy.sql.schema.PrimaryKeyConstraint):
                # NOTE(review): primary keys are reported by Column.name while the
                # foreign-key column map uses Column.key; these differ when a
                # Column declares ``key=`` -- confirm which one the spec expects.
                primary_key_names.extend(col.name for col in constraint.columns)
                logger.debug("updated %r: %s", self.primary_key_names__key, primary_key_names)
            elif isinstance(constraint, sqlalchemy.sql.schema.ForeignKeyConstraint):
                # Explicit check (not ``assert``): under ``python -O`` unnamed
                # constraints would otherwise all collapse onto the key ``None``.
                if not constraint.name:
                    raise ValueError(f"Got constraint without name: {constraint}")
                fk_name_to_fk[constraint.name] = self._get_foreign_key_description(schema, constraint)
                logger.debug("updated %r: %s", self.fk_name_to_fk__key, fk_name_to_fk)
            elif isinstance(constraint, sqlalchemy.sql.schema.UniqueConstraint):
                logger.debug("Skipped UniqueConstraint")
            elif isinstance(constraint, sqlalchemy.sql.schema.CheckConstraint):
                logger.debug("Skipped CheckConstraint")
            else:
                logger.debug("Skipped %s", type(constraint).__name__)

        return result

    def _get_foreign_key_description(
            self,
            schema: type,
            constraint: sqlalchemy.sql.schema.ForeignKeyConstraint,
    ) -> dict:
        """Build the ``x-foreignKeys`` entry for a single foreign-key constraint.

        :param schema: the ``Entity`` subclass owning the constraint; used to
            resolve the referred table back to its entity.
        :param constraint: the foreign-key constraint to describe.
        :return: dict with ``target_table`` (schema description of the referred
            entity), ``local_to_remote_columns`` (local column key -> referenced
            column key), and the constraint's ``on_delete`` / ``on_update``
            values as declared.
        """
        referred_entity = schema.get_entity_by_table(constraint.referred_table)
        return {
            "target_table": self._get_schema_description(referred_entity),
            "local_to_remote_columns": {
                fk.parent.key: fk.column.key
                for fk in constraint.elements
            },
            "on_delete": constraint.ondelete,
            "on_update": constraint.onupdate,
        }
