#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>
import dataclasses
from typing import Any

import sqlalchemy
import sqlalchemy.orm

from .fields import get_column_from_field


def _ensure_column(column: str | sqlalchemy.Column | dataclasses.Field) -> str | sqlalchemy.Column:
    """Normalize a foreign-key target into a value ``sqlalchemy.ForeignKey`` accepts.

    Accepted inputs:

    * ``str`` -- returned unchanged;
    * ``sqlalchemy.Column`` -- returned unchanged;
    * ``dataclasses.Field`` -- first passed through ``get_column_from_field``;
      its result is then handled by the rules below;
    * ``sqlalchemy.orm.attributes.InstrumentedAttribute`` -- unwrapped to the
      single column behind its ``ColumnProperty``.

    Raises:
        TypeError: if the (resolved) value is not a ``str`` or ``Column``, or
            an instrumented attribute is not backed by a ``ColumnProperty``.
        ValueError: if the ``ColumnProperty`` does not map exactly one column.
    """
    if isinstance(column, dataclasses.Field):
        column = get_column_from_field(column)

    if isinstance(column, sqlalchemy.orm.attributes.InstrumentedAttribute):
        # Explicit checks rather than ``assert``: asserts are stripped under
        # ``python -O``, which would also drop the assignments that used to
        # live inside them and leave ``columns`` unbound (NameError).
        column_property = column.property
        if not isinstance(column_property, sqlalchemy.orm.ColumnProperty):
            raise TypeError(
                f"Expected a ColumnProperty behind {column!r}, got {type(column_property)}"
            )
        columns = column_property.columns
        if len(columns) != 1:
            raise ValueError(
                f"Expected exactly one column behind {column!r}, got {len(columns)}"
            )
        column = columns[0]

    if isinstance(column, str | sqlalchemy.Column):
        return column

    raise TypeError(f"Unsupported column type: {type(column)}")


def _ensure_name(column: str | sqlalchemy.Column, dialect_kw: dict[str, Any]) -> None:
    """Fill in a default ``name`` for the foreign-key constraint, in place.

    If ``dialect_kw`` already has a ``"name"`` key (whatever its value), it is
    left untouched. Otherwise the name is the referenced column spec with
    every ``.`` replaced by ``_``, followed by ``_fkey``. A ``Column`` is
    first rendered as ``"<table name>.<column name>"``.

    Raises:
        TypeError: if *column* is neither a ``str`` nor a ``sqlalchemy.Column``.
    """
    if "name" in dialect_kw:
        return

    if isinstance(column, sqlalchemy.Column):
        spec = f'{column.table.name}.{column.name}'
    elif isinstance(column, str):
        spec = column
    else:
        raise TypeError(f"Unexpected type {type(column)}")

    # NOTE(review): the default depends only on the referenced column, so two
    # foreign keys in one table pointing at the same target get the same
    # name; pass ``name=`` explicitly in that case.
    dialect_kw["name"] = spec.replace(".", "_") + "_fkey"


class _ReferentialActionForeignKey(sqlalchemy.ForeignKey):
    """``ForeignKey`` whose ON UPDATE and ON DELETE actions are fixed per subclass.

    Subclasses define ``_referential_action``. Any ``onupdate``/``ondelete``
    supplied by the caller is overwritten with that value. The target is
    normalized with ``_ensure_column`` (``str``, ``Column``, dataclass field or
    instrumented attribute). Unless ``name`` is given, a default constraint
    name is derived from the target by ``_ensure_name``.
    """

    # SQL referential action used for both ON UPDATE and ON DELETE.
    _referential_action: str

    def __init__(self, column: str | sqlalchemy.Column | dataclasses.Field, **dialect_kw):
        dialect_kw["onupdate"] = dialect_kw["ondelete"] = self._referential_action
        column = _ensure_column(column)
        _ensure_name(column, dialect_kw)
        super().__init__(column, **dialect_kw)


class CascadeForeignKey(_ReferentialActionForeignKey):
    """Foreign key with ``ON UPDATE CASCADE`` and ``ON DELETE CASCADE``."""

    _referential_action = "CASCADE"


class RestrictForeignKey(_ReferentialActionForeignKey):
    """Foreign key with ``ON UPDATE RESTRICT`` and ``ON DELETE RESTRICT``."""

    _referential_action = "RESTRICT"


class SetNullForeignKey(_ReferentialActionForeignKey):
    """Foreign key with ``ON UPDATE SET NULL`` and ``ON DELETE SET NULL``."""

    _referential_action = "SET NULL"
