#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>

import functools
import inspect
from dataclasses import dataclass, fields
from typing import TypeVar, Final, Type

from .camel_case_to_underscore import camel_case_to_underscore


# Type variable standing for an entity class in the signatures of this module.
Entity = TypeVar('Entity')
# Name of the class attribute in which entity_dataclass stores the entity name;
# is_entity_dataclass tests for the presence of this attribute.
ENTITY_NAME_FIELD: Final[str] = '__entity_name__'
# Name of the class attribute in which entity_dataclass stores the ColumnGetter.
ENTITY_COLUMNS_FIELD: Final[str] = 'c'
# Appended to a relationship field's name by prepare_entity_column_names to
# build the prefix of the nested entity's column names.
RELATIONSHIP_COLUMN_SEPARATOR: Final[str] = '.'


@dataclass
class EntityColumn:
    """Name/type pair describing one annotated attribute of an entity class."""

    name: str  # attribute name, as it appears in the entity's annotations
    type: type  # annotation recorded for that attribute


def get_all_annotations(cls) -> dict[str, type]:
    all_annotations = {}
    for c in cls.__mro__:
        if annotations := inspect.get_annotations(c):
            all_annotations.update(annotations)
    return all_annotations


class ColumnGetter:
    """Attribute-style access to the columns of an entity class.

    For every annotation found on the entity (including inherited ones, see
    ``get_all_annotations``) an ``EntityColumn`` is created; ``getter.<name>``
    then returns the column for ``<name>``.
    """

    def __init__(self, entity: Type[Entity]) -> None:
        """Build the name -> EntityColumn table for ``entity``.

        :param entity: class whose annotations define the available columns.
        """
        self._name_to_column = {
            name: EntityColumn(name, type_) for name, type_ in get_all_annotations(entity).items()
        }

    def __getattr__(self, name: str) -> EntityColumn:
        """Return the column called ``name``.

        :param name: attribute name that normal lookup failed to resolve.
        :raises AttributeError: if no column with that name exists.
        """
        # __getattr__ only runs after normal lookup has failed. Read the table
        # straight from the instance __dict__: on an instance whose __init__
        # has not run (e.g. one created by copy/pickle) ``self._name_to_column``
        # would re-enter __getattr__ and recurse until RecursionError.
        columns = self.__dict__.get('_name_to_column')
        column = columns.get(name) if columns is not None else None
        if column is None:
            raise AttributeError(f'attribute not found: {name}')
        return column


def entity_dataclass(class_: type = None, *, entity_name: str = None, **kwargs):
    """Class decorator that turns a class into a dataclass-based entity.

    Usable both bare (``@entity_dataclass``) and with arguments
    (``@entity_dataclass(entity_name=...)``).

    The decorated class is passed through ``dataclasses.dataclass`` and then
    receives two class attributes: the entity name under ``ENTITY_NAME_FIELD``
    and a ``ColumnGetter`` under ``ENTITY_COLUMNS_FIELD``.

    :param class_: the class being decorated when used without parentheses;
        ``None`` when the decorator is called with arguments.
    :param entity_name: explicit entity name; when falsy, the result of
        ``camel_case_to_underscore`` applied to the class name is used.
    :param kwargs: forwarded unchanged to ``dataclasses.dataclass``.
    :return: the entity class, or the actual decorator when ``class_`` is
        ``None``.
    """
    def decorate(cls):
        entity = dataclass(cls, **kwargs)
        # An empty or missing entity_name falls back to the derived name.
        resolved_name = entity_name or camel_case_to_underscore(cls.__name__)
        setattr(entity, ENTITY_NAME_FIELD, resolved_name)
        setattr(entity, ENTITY_COLUMNS_FIELD, ColumnGetter(cls))
        return entity

    # Bare usage: the class itself arrives as the first positional argument.
    if class_ is not None:
        return decorate(class_)

    return decorate


@functools.cache
def prepare_entity_column_names(entity_type: Type[Entity], prefix: str = '') -> list[str]:
    """List the column names of an entity, descending into relationships.

    A field whose annotation is a parameterised generic with a class origin
    (for example ``list[Other]``) and whose first type argument is itself an
    entity dataclass is treated as a relationship: the nested entity's columns
    are listed instead, each prefixed with the full path to it
    (``<prefix><field><separator>``). Every other field contributes
    ``<prefix><field>``.

    NOTE: results are memoised by ``functools.cache``, so repeated calls with
    the same arguments return the very same list object; callers should treat
    it as read-only.

    :param entity_type: entity class (or instance) to inspect.
    :param prefix: string prepended to every produced column name.
    :return: column names in dataclass field order.
    :raises TypeError: if ``entity_type`` is not an entity dataclass.
    """
    if not is_entity_dataclass(entity_type):
        raise TypeError("expected entity_dataclass")

    result = []
    for field in fields(entity_type):
        type_origin = getattr(field.type, '__origin__', None)
        type_args = getattr(field.type, '__args__', None)
        # ``type_args`` may be None or empty (bare generic alias, ``tuple[()]``),
        # so check it before indexing.
        is_relationship = (
            isinstance(type_origin, type)
            and bool(type_args)
            and is_entity_dataclass(type_args[0])
        )
        if is_relationship:
            # Keep the incoming prefix so columns nested more than one level
            # deep retain their full path.
            nested_prefix = prefix + field.name + RELATIONSHIP_COLUMN_SEPARATOR
            result += prepare_entity_column_names(type_args[0], nested_prefix)
        else:
            result.append(prefix + field.name)
    return result


def is_entity_dataclass(entity) -> bool:
    """Tell whether ``entity`` is an entity class or an instance of one.

    The check looks for the ``ENTITY_NAME_FIELD`` attribute on the class
    (the class of ``entity`` when an instance is given). Because ``hasattr``
    follows inheritance, subclasses of an entity class also qualify.

    :param entity: a class or an arbitrary object.
    :return: ``True`` if the class carries the entity-name attribute.
    """
    # Instances are judged by their class, never by instance attributes.
    if not isinstance(entity, type):
        entity = type(entity)
    return hasattr(entity, ENTITY_NAME_FIELD)
