#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import dataclasses
import datetime
import functools
import itertools
import json
import types
import typing
from typing import Any, TypeVar, Union, Callable, Optional

from .annotation import try_extract_type_notes
from .eval_forward_refs import get_evaled_dataclass_fields
from .dataclass_protocol import DataclassProtocol
from .is_instance import get_container_type, is_instance
from .str_to_bool import to_bool


@functools.cache
def get_dataclass_field_name_to_field(
        dataclass_: Any,
        with_init_vars: bool = True, with_class_vars: bool = False
) -> dict[str, dataclasses.Field]:
    """Map the field names of ``dataclass_`` to their ``dataclasses.Field`` objects.

    Unlike ``dataclasses.fields()``, this can also include ``InitVar`` and
    ``ClassVar`` pseudo-fields. Results are memoized per argument combination,
    so the same dict object is returned on every call -- do not mutate it.

    Args:
        dataclass_: anything exposing the dataclass fields attribute
            (``dataclasses._FIELDS``).
        with_init_vars: also include ``InitVar`` pseudo-fields.
        with_class_vars: also include ``ClassVar`` pseudo-fields.

    Returns:
        ``{name: Field}`` in the order of the dataclass's fields mapping.

    Raises:
        TypeError: ``dataclass_`` is not a dataclass.
    """
    try:
        # noinspection PyUnresolvedReferences,PyProtectedMember
        declared = getattr(dataclass_, dataclasses._FIELDS)
    except AttributeError:
        raise TypeError("expected dataclass type")

    # noinspection PyUnresolvedReferences,PyProtectedMember
    kinds = {dataclasses._FIELD}
    if with_init_vars:
        # noinspection PyUnresolvedReferences,PyProtectedMember
        kinds.add(dataclasses._FIELD_INITVAR)
    if with_class_vars:
        # noinspection PyUnresolvedReferences,PyProtectedMember
        kinds.add(dataclasses._FIELD_CLASSVAR)

    selected = {}
    for name, field in declared.items():
        # noinspection PyProtectedMember
        if field._field_type in kinds:
            selected[name] = field
    return selected


class FieldErrors(ValueError):
    """Aggregate of per-field errors, raised by ``dict_to_dataclass``.

    ``field_to_error`` maps a field name to the exception recorded for it;
    errors of nested dataclass fields use dotted names (``"outer.inner"``).
    """

    def __init__(self, field_to_error: dict[str, Exception]):
        # Forward the mapping to ValueError so ``args`` and ``str(exc)`` are
        # populated, and so the exception survives pickling/copying
        # (BaseException.__reduce__ re-creates it as ``FieldErrors(*self.args)``).
        super().__init__(field_to_error)
        self.field_to_error = field_to_error


class MissingField(ValueError):
    """Recorded for a required field that is absent from the input."""


def to_datetime(value: Union[str, int, float, datetime.datetime]) -> datetime.datetime:
    """Convert ``value`` to a ``datetime.datetime``.

    * ``datetime.datetime`` instances are returned unchanged;
    * strings are parsed with ``datetime.datetime.fromisoformat``;
    * anything else is treated as a POSIX timestamp and converted with
      ``datetime.datetime.fromtimestamp`` (a naive datetime in local time).
    """
    if isinstance(value, datetime.datetime):
        # Already converted (e.g. read from another object): fromtimestamp()
        # would reject it with TypeError.
        return value
    if isinstance(value, str):
        return datetime.datetime.fromisoformat(value)
    return datetime.datetime.fromtimestamp(value)


def pass_value(value: Any) -> Any:
    """Identity converter: hand ``value`` back untouched."""
    unchanged = value
    return unchanged


@functools.cache
def get_converter(type_: type | types.GenericAlias | types.UnionType) -> Callable[[Any], Any]:
    """Build (memoized per ``type_``) a callable that converts a raw value to ``type_``.

    Supported targets:

    * ``typing.Any`` -- the value is returned unchanged;
    * dataclasses -- converted with ``dict_to_dataclass``;
    * ``bool`` -- parsed with ``to_bool``;
    * ``datetime.time`` / ``datetime.date`` -- parsed with ``fromisoformat``;
    * ``datetime.datetime`` -- parsed with ``to_datetime``;
    * any other plain class -- the class itself is called with the value;
    * ``Union[...]`` / ``X | Y`` -- the value is kept if ``is_instance`` reports
      it already matches a member, otherwise the first member converter that
      does not raise TypeError/ValueError/AttributeError wins;
    * parametrized containers (``dict[K, V]``, ``tuple[A, B]``, ``tuple[A, ...]``,
      and one-argument containers such as ``list[T]``) -- a ``str`` value is
      first decoded as JSON, then every element is converted recursively;
    * ``dataclasses.InitVar[T]`` -- converted as ``T``.

    For the ``datetime`` targets, values that already are instances of the
    target class are returned unchanged.

    Raises:
        TypeError: ``type_`` is not supported, or a tuple type has ``...``
            anywhere but last. Problems with the value itself are raised later,
            by the returned callable.
    """
    type_, _ = try_extract_type_notes(type_)
    # Checked before the plain-class branch: on Python 3.10 ``typing.Any`` is
    # not a class, so ``isinstance(type_, type)`` would send it to the
    # "unexpected type" error below.
    if type_ is typing.Any:
        return pass_value
    type_origin = getattr(type_, '__origin__', None)
    type_args = getattr(type_, '__args__', None)
    # ``not type_origin`` keeps parametrized generics out of the plain-class branch.
    if not type_origin and isinstance(type_, type):
        return _make_plain_converter(type_)
    if type_origin is Union or isinstance(type_, types.UnionType):
        return _make_union_converter(type_args)
    if type_origin is not None:
        container_type = get_container_type(type_origin)
        if container_type == dict:
            return _make_dict_converter(container_type, type_args)
        if container_type == tuple:
            return _make_tuple_converter(type_args)
        return _make_sequence_converter(container_type, type_args)
    if isinstance(type_, dataclasses.InitVar):
        return get_converter(type_.type)
    raise TypeError(f'Attempt to cast to unexpected type: {type_} ({type(type_).__name__})')


def _make_plain_converter(type_: type) -> Callable[[Any], Any]:
    """Return the converter for a non-parametrized class ``type_``."""
    if dataclasses.is_dataclass(type_):
        return functools.partial(dict_to_dataclass, dataclass=type_)
    if type_ == bool:
        return to_bool
    if type_ == datetime.time:
        return _make_instance_passthrough(datetime.time, datetime.time.fromisoformat)
    if type_ == datetime.date:
        return _make_instance_passthrough(datetime.date, datetime.date.fromisoformat)
    if type_ == datetime.datetime:
        return _make_instance_passthrough(datetime.datetime, to_datetime)
    # Any other class is its own converter (e.g. ``int(value)``, ``str(value)``).
    return type_


def _make_instance_passthrough(target: type, parse: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Wrap ``parse`` so that values already of type ``target`` are returned as-is.

    ``fromisoformat`` rejects non-string input with TypeError, so without this
    an already-converted value would be reported as a conversion error. Like
    ``dict_to_dataclass``, this uses ``isinstance``, so subclass instances
    (e.g. a ``datetime`` for a ``date`` target) are passed through too.
    """
    def converter(value: Any) -> Any:
        if isinstance(value, target):
            return value
        return parse(value)

    return converter


def _make_union_converter(type_args: tuple) -> Callable[[Any], Any]:
    """Return a converter trying each member of a ``Union``/``X | Y`` in order."""
    def union_converter(value: Any) -> Any:
        if is_instance(value, type_args):
            return value
        for possible_type in type_args:
            try:
                return get_converter(possible_type)(value)
            except (TypeError, ValueError, AttributeError):
                # Deliberate: this member can't take the value, try the next one.
                pass
        raise ValueError(f"Expected one of: {type_args}, got {type(value)}: {value}")

    return union_converter


def _make_dict_converter(container_type: type, type_args: tuple) -> Callable[[Any], Any]:
    """Return a converter for ``dict[K, V]`` (a ``str`` input is JSON-decoded first)."""
    key_converter = get_converter(type_args[0])
    value_converter = get_converter(type_args[1])

    def converter(base_value: Any) -> Any:
        raw_values = json.loads(base_value) if isinstance(base_value, str) else base_value
        if not isinstance(raw_values, dict):
            raise ValueError(f'Expected dict, got {type(raw_values).__name__}: {raw_values}')
        parsed_values = [(key_converter(key), value_converter(value)) for key, value in raw_values.items()]
        return container_type(parsed_values)

    return converter


def _make_tuple_converter(type_args: tuple) -> Callable[[Any], tuple]:
    """Return a converter for ``tuple[A, B]`` (fixed length) or ``tuple[A, ...]`` (variable).

    With a trailing ``...`` the element converters are applied cyclically.
    """
    has_ellipsis = False
    element_converters = []
    for element_type in type_args:
        if element_type is ...:
            has_ellipsis = True
        elif has_ellipsis:
            raise TypeError("Ellipsis (three dots (...)) should be the last"
                            " in tuple to make it variable-length")
        else:
            element_converters.append(get_converter(element_type))

    def converter(base_value: Any) -> tuple:
        raw_values = json.loads(base_value) if isinstance(base_value, str) else base_value
        if not isinstance(raw_values, (list, tuple)):
            raise ValueError(f'Expected iterable, got {type(raw_values).__name__}: {raw_values}')
        if has_ellipsis:
            # A fresh cycle on every call: a cycle captured by the closure would
            # resume where the previous call (or another thread) stopped and
            # mis-pair elements with converters for ``tuple[A, B, ...]``.
            converters = itertools.cycle(element_converters)
        elif len(raw_values) != len(element_converters):
            raise ValueError(f'Expected exactly {len(element_converters)} elements, got {len(raw_values)}')
        else:
            converters = element_converters
        return tuple(convert_element(element) for element, convert_element in zip(raw_values, converters))

    return converter


def _make_sequence_converter(container_type: type, type_args: tuple) -> Callable[[Any], Any]:
    """Return a converter for one-argument containers such as ``list[T]`` or ``set[T]``."""
    element_converter = get_converter(type_args[0])

    def converter(base_value: Any) -> Any:
        raw_values = json.loads(base_value) if isinstance(base_value, str) else base_value
        if not isinstance(raw_values, (list, tuple)):
            raise ValueError(f'Expected iterable, got {type(raw_values).__name__}: {raw_values}')
        parsed_values = [element_converter(element) for element in raw_values]
        return container_type(parsed_values)

    return converter


class UnsupportedType(TypeError):
    """A field's annotated type has no converter (raised by ``convert_to_type``)."""


# Target type of a single-value conversion (convert_to_type / convert).
T = typing.TypeVar('T')
# Dataclass type accepted and produced by dict_to_dataclass.
DP = typing.TypeVar('DP', bound=DataclassProtocol)


class NoValue:
    """Sentinel meaning "no value".

    The class object itself is used (it is never instantiated here). It marks
    both a field that is absent from the input and a converted result that
    must not be passed to the dataclass constructor.
    """


def write_to_dict_or_raise(container: Optional[dict[str, Exception]], key: str, exception: Exception) -> None:
    """Store ``exception`` under ``key`` in ``container``; raise it if there is no container.

    Lets callers either collect errors (pass a dict) or fail fast (pass ``None``).
    """
    if container is None:
        raise exception
    container[key] = exception


def convert_to_type(field_type: type[T], field_name: str, value: Any,
                    field_name_to_error: dict[str, Exception] = None) -> T:
    """Convert ``value`` to ``field_type`` on behalf of the field ``field_name``.

    Dataclass targets are built with ``dict_to_dataclass``; their per-field
    errors are re-keyed as ``"<field_name>.<nested field>"``. Other exceptions
    from building the nested dataclass propagate unchanged. Every other target
    uses ``get_converter``.

    On a conversion error the error is stored in ``field_name_to_error`` and
    the unconverted ``value`` is returned. If no dict is given, the error is
    raised instead; for a nested dataclass that is its first field error.

    Raises:
        UnsupportedType: no converter exists for ``field_type``.
    """
    field_type, _ = try_extract_type_notes(field_type)
    if not dataclasses.is_dataclass(field_type):
        try:
            converter = get_converter(field_type)
        except TypeError as e:
            raise UnsupportedType(f'Field "{field_name}": {str(e)}') from e
        try:
            return converter(value)
        except Exception as e:
            write_to_dict_or_raise(field_name_to_error, field_name, e)
            return value

    try:
        # noinspection PyTypeChecker
        return dict_to_dataclass(value, field_type)
    except FieldErrors as nested:
        for field, error in nested.field_to_error.items():
            write_to_dict_or_raise(field_name_to_error, key=f'{field_name}.{field}', exception=error)
        return value


def convert(name: str, type_: type[T], init: bool, default: Any, default_factory: Callable, value: Any,
            field_name_to_error: Optional[dict[str, Exception]] = None) -> T:
    """Convert the raw ``value`` of one dataclass field, or decide to omit it.

    Returns ``NoValue`` (the class itself) when the field must not be passed
    to the dataclass constructor. That covers:

    * ``init=False`` fields;
    * missing fields that have a default or default factory;
    * missing required fields, which are also reported as ``MissingField``.

    A missing dataclass-typed field without a default is built from an empty
    dict, so it succeeds only if none of its own fields are required.

    Conversion errors are stored in ``field_name_to_error`` (nested ones under
    ``"<name>.<field>"``). If no dict is given, they are raised instead.
    """
    if not init:
        return NoValue

    type_, _ = try_extract_type_notes(type_)
    # Sentinels are compared by identity: ``!=`` would invoke the operand's
    # ``__ne__``, which for array-like values returns an array and makes the
    # ``if`` raise "truth value is ambiguous".
    if value is not NoValue:
        return convert_to_type(type_, name, value, field_name_to_error)

    if default is not dataclasses.MISSING or default_factory is not dataclasses.MISSING:
        # Missing but defaulted: leave it to the dataclass constructor.
        return NoValue

    if dataclasses.is_dataclass(type_):
        try:
            return dict_to_dataclass({}, type_)
        except FieldErrors as e:
            for field, error in e.field_to_error.items():
                write_to_dict_or_raise(field_name_to_error, key=f'{name}.{field}', exception=error)
            return NoValue

    write_to_dict_or_raise(field_name_to_error, name, MissingField())
    return NoValue


def dict_to_dataclass(dict_value: dict[str, Any], dataclass: type[DP]) -> DP:
    """Build an instance of ``dataclass`` from ``dict_value``, converting every field.

    ``dict_value`` may be a mapping (anything with a ``get`` method) or any
    other object, whose attributes are then read by field name. An object that
    already is an instance of ``dataclass`` is returned unchanged.
    ``try_extract_type_notes`` is applied to ``dataclass`` first.

    Field types come from ``typing.get_type_hints``, overridden by
    ``get_evaled_dataclass_fields``. Fields for which ``convert`` returns
    ``NoValue`` are not passed to the constructor.

    Raises:
        TypeError: ``dataclass`` is not a dataclass.
        FieldErrors: one or more fields were missing or failed to convert.
        UnsupportedType: a field's type has no converter.
    """
    dataclass, _ = try_extract_type_notes(dataclass)
    if not dataclasses.is_dataclass(dataclass):
        raise TypeError("expected dataclass type")
    if isinstance(dict_value, dataclass):
        return dict_value

    name_to_field = get_dataclass_field_name_to_field(dataclass)
    name_to_type = typing.get_type_hints(dataclass) | get_evaled_dataclass_fields(dataclass)
    errors = {}
    # Mappings are read with .get(); any other object via getattr(obj, name, default).
    read_field = dict_value.get if hasattr(dict_value, "get") else functools.partial(getattr, dict_value)

    init_kwargs = {}
    for field_name, field in name_to_field.items():
        converted = convert(
            name=field_name,
            type_=name_to_type[field_name],
            init=field.init,
            default=field.default,
            default_factory=field.default_factory,
            value=read_field(field_name, NoValue),
            field_name_to_error=errors,
        )
        if converted is not NoValue:
            init_kwargs[field_name] = converted

    if errors:
        raise FieldErrors(errors)
    return dataclass(**init_kwargs)
