import dataclasses
from typing import Union, Type, Any, List, Tuple, Callable, Dict, Optional

from sqlalchemy import ForeignKey, Column
from sqlalchemy.types import TypeEngine
from sqlalchemy.orm import relationship
from sqlalchemy.orm.attributes import InstrumentedAttribute


def sql_relation_field(remote_class, *, init: bool = False, use_list: bool = True,
                       order_by: Union[str, bool, list] = False, lazy: str = "select",
                       additional_metadata: Optional[Dict[str, Any]] = None,
                       **relationship_kwargs) -> dataclasses.Field:
    """Build a dataclass field whose metadata carries a SQLAlchemy ``relationship`` under ``"sa"``.

    Args:
        remote_class: Target of the relationship; passed as the first argument to ``relationship()``.
        init: Whether the field is an ``__init__`` parameter (excluded by default).
        use_list: Forwarded to ``relationship()`` as ``uselist``.
        order_by: Forwarded to ``relationship()`` as ``order_by``.
        lazy: Forwarded to ``relationship()`` as ``lazy``.
        additional_metadata: Extra entries merged into the field metadata *after* ``"sa"`` is set,
            so an ``"sa"`` key here replaces the relationship.
        **relationship_kwargs: Any further keyword arguments, forwarded to ``relationship()``.

    Returns:
        The ``dataclasses.Field`` produced by ``dataclasses.field``. It has no default.
    """
    metadata: Dict[str, Any] = {
        "sa": relationship(remote_class, uselist=use_list, order_by=order_by, lazy=lazy, **relationship_kwargs)
    }
    if additional_metadata is not None:
        metadata.update(additional_metadata)
    return dataclasses.field(metadata=metadata, init=init)


def sql_field(type_: Union[Type[TypeEngine], TypeEngine], foreign_key: Optional[ForeignKey] = None,
              *, primary_key: bool = False, composite_primary_key: bool = False, allow_init_primary_key: bool = False,
              nullable: Optional[bool] = None, unique: Optional[bool] = None, index: bool = False,
              default: Any = None, server_default: Any = None, default_factory: Optional[Callable] = None,
              repr_: bool = True, additional_metadata: Optional[Dict[str, Any]] = None
              ) -> dataclasses.Field:
    """Build a dataclass field whose metadata carries a SQLAlchemy ``Column`` under ``"sa"``.

    Args:
        type_: Column type (class or instance), passed positionally to ``Column``.
        foreign_key: Optional ``ForeignKey`` passed positionally to ``Column`` after the type.
        primary_key: Mark the column as a primary key. Such a field is excluded from ``__init__``
            unless ``allow_init_primary_key`` is set.
        composite_primary_key: Mark the column as one member of a multi-column key. This implies
            ``primary_key`` but keeps the field in ``__init__``.
        allow_init_primary_key: Keep a (non-composite) primary-key field in ``__init__``.
        nullable: Forwarded to ``Column`` when not None. If True and no default or factory is
            given, the dataclass default becomes ``None``.
        unique: Forwarded to ``Column`` when not None.
        index: Forwarded to ``Column`` only when truthy.
        default: When not None, forwarded to ``Column`` and also used as the dataclass default.
            A callable default becomes the dataclass ``default_factory`` and takes precedence
            over ``default_factory``.
        server_default: Forwarded to ``Column`` when not None.
        default_factory: Dataclass ``default_factory``. It is not forwarded to ``Column``.
        repr_: Forwarded as the field's ``repr`` flag.
        additional_metadata: Extra entries merged into the field metadata *after* ``"sa"`` is set,
            so an ``"sa"`` key here replaces the column.

    Returns:
        The ``dataclasses.Field`` produced by ``dataclasses.field``.

    Raises:
        ValueError: From ``dataclasses.field`` when a non-callable ``default`` and a
            ``default_factory`` are both supplied.
    """
    if composite_primary_key:
        # Each member of a composite key is itself a primary-key column.
        primary_key = True

    column_args = [type_, foreign_key]
    # Forward only the options the caller actually set, leaving Column's own defaults otherwise.
    column_kwargs: Dict[str, Any] = {"primary_key": primary_key}
    if nullable is not None:
        column_kwargs["nullable"] = nullable
    if unique is not None:
        column_kwargs["unique"] = unique
    if default is not None:
        column_kwargs["default"] = default
    if index:
        column_kwargs["index"] = index
    if server_default is not None:
        column_kwargs["server_default"] = server_default
    metadata: Dict[str, Any] = {"sa": Column(*column_args, **column_kwargs)}

    if additional_metadata is not None:
        metadata.update(additional_metadata)

    field_kwargs: Dict[str, Any] = {"metadata": metadata, "repr": repr_}
    if default_factory is not None:
        field_kwargs["default_factory"] = default_factory
    if default is not None:
        # dataclasses needs callables as factories; a callable default overrides default_factory.
        key = "default_factory" if callable(default) else "default"
        field_kwargs[key] = default
    elif nullable and "default_factory" not in field_kwargs:
        # Nullable columns default to None on the Python side, but only when no factory was
        # given: dataclasses.field rejects specifying both default and default_factory.
        field_kwargs["default"] = None
    if not allow_init_primary_key and primary_key and not composite_primary_key:
        # Single-column primary keys are kept out of __init__ by default
        # (presumably because the database assigns them — confirm with callers).
        field_kwargs["init"] = False
    return dataclasses.field(**field_kwargs)


def get_columns(entity) -> List[Column]:
    """Return the columns of *entity*'s ``__table__`` as a new list.

    *entity* must expose a ``__table__`` attribute (presumably a declaratively mapped class —
    confirm with callers). The result is built from ``__table__.c``. It therefore contains the
    table's ``Column`` objects, not the class's instrumented attributes.
    """
    return list(entity.__table__.c)


def get_primary_key_columns(entity) -> List[Column]:
    """Return the primary-key columns of *entity*'s ``__table__`` as a new list.

    The columns are read from ``__table__.primary_key.columns``.
    """
    pk_constraint = entity.__table__.primary_key
    return [*pk_constraint.columns]


def get_value_tuple_by_columns(instance, columns: List[Column]) -> Tuple[Any, ...]:
    """Return a tuple of *instance*'s values for *columns*, in the same order.

    Each value is read with ``getattr(instance, column.name)``. This assumes that every
    column's ``name`` is also the attribute name on *instance*. A column whose database name
    differs from its attribute key would raise ``AttributeError`` here.

    Args:
        instance: Object to read attribute values from.
        columns: Columns whose values are collected; any iterable works.

    Returns:
        A tuple with one value per column (empty when *columns* is empty).
    """
    return tuple(getattr(instance, column.name) for column in columns)


def get_column_from_field(field: dataclasses.Field) -> Column:
    """Return the SQLAlchemy object stored under the ``"sa"`` key of *field*'s metadata.

    Raises ``KeyError`` when the metadata has no ``"sa"`` entry. For fields built by
    ``sql_relation_field``, the stored object is the ``relationship()`` result rather than a
    ``Column``.
    """
    field_metadata = field.metadata
    return field_metadata["sa"]
