import collections.abc
import functools
import itertools
from typing import Any, Callable, Dict, Hashable, List, Mapping, MutableMapping, Set, Type, Union

from .grouping import Grouping
from .default_caster import BaseCaster, DefaultCaster
from .exception import ItemException, ComplexItemException
from .extras import to_list, NoDefault
from .item import Item

# Type of the arguments accepted by DictCaster.__init__: either a single Item,
# or a Grouping whose items are all registered (and linked to that grouping).
Arg = Union[Item, Grouping]


class DictCaster:
    """Cast a source mapping into a mapping/object as described by ``Item`` objects.

    Each ``Item`` says where its value is read from (``item.source``), how it
    is transformed (``preprocess`` / ``cast_as`` / ``postprocess``) and under
    which key the result is written (``item.result_key``).  Items may be given
    on their own or wrapped in a ``Grouping``; every value read for a grouped
    item is reported to its grouping via ``Grouping.on_item_got``, and all
    groupings are re-initialised before and finished after each :meth:`cast`.
    """

    def __init__(self, items: Union[List[Arg], Arg] = None, *args: Arg, caster: BaseCaster = None) -> None:
        """Collect the items to cast.

        Args:
            items: a single ``Item``/``Grouping`` or a list of them (normalised
                with ``to_list``); ``None`` means no items.
            *args: further ``Item``/``Grouping`` objects, appended after ``items``.
            caster: converter used to cast scalar values to ``item.cast_as``;
                defaults to a new ``DefaultCaster()``.

        Raises:
            TypeError: if any argument is neither an ``Item`` nor a ``Grouping``.
        """
        args = ([] if items is None else to_list(items)) + list(args)
        self.items: List[Item] = []
        # Keyed by id(item): every registered item is kept alive by self.items,
        # so its id cannot be reused while this caster exists.
        self._item_to_grouping: Dict[int, Grouping] = {}
        # NOTE(review): not referenced anywhere in this class; presumably read by
        # external code, so it is kept as-is.
        self.default_grouping = Grouping
        for arg in args:
            if not isinstance(arg, (Grouping, Item)):
                raise TypeError(f"got unexpected type in DictCaster __init__: {type(arg)}")
            self._register(arg)
        self.caster = DefaultCaster() if caster is None else caster

    def _register(self, arg: Arg) -> None:
        """Append ``arg`` to ``self.items``; a ``Grouping`` contributes each of its items."""
        if isinstance(arg, Grouping):
            for item in arg.items:
                self.items.append(item)
                self._item_to_grouping[id(item)] = arg
        else:
            self.items.append(arg)

    def get_keys(self) -> Set[Hashable]:
        """Return the union of ``item.get_keys()`` over all registered items."""
        return set(itertools.chain.from_iterable(item.get_keys() for item in self.items))

    def _groupings(self) -> List[Grouping]:
        """Return each registered grouping exactly once, in registration order.

        De-duplicating by identity keeps the reinit/finish order deterministic
        and does not require ``Grouping`` to be hashable.
        """
        return list({id(grouping): grouping for grouping in self._item_to_grouping.values()}.values())

    def _reinit_groupings(self) -> None:
        """Call ``reinit()`` on every registered grouping."""
        for grouping in self._groupings():
            grouping.reinit()

    def _finish_groupings(self) -> None:
        """Call ``finish()`` on every registered grouping."""
        for grouping in self._groupings():
            grouping.finish()

    def _notify_grouping(self, item: Item, value: Any) -> None:
        """Forward ``value`` to the grouping ``item`` belongs to, if any."""
        grouping = self._item_to_grouping.get(id(item))
        if grouping is not None:
            grouping.on_item_got(item, value)

    def add_item(self, item: Arg) -> 'DictCaster':
        """Append an ``Item`` (or every item of a ``Grouping``) and return ``self`` for chaining.

        Unlike ``__init__`` no type check is performed, so objects that are
        neither ``Item`` nor ``Grouping`` are still appended as-is.
        """
        self._register(item)
        return self

    def _process_item(
            self, item: Item, getter: Callable[[Hashable], Any],
            write_callback: Callable[[Hashable, Any], None],
            optional_is_default_none: bool = False
    ) -> None:
        """Resolve one item and hand ``(item.result_key, value)`` to ``write_callback``.

        If ``item.source`` is hashable it is a key looked up with ``getter``
        (which must return ``NoDefault`` for a missing key).  A present value is
        passed through ``item.preprocess``; a missing one falls back to
        ``item.default``.  A missing optional item without a default is skipped,
        or written as ``None`` when ``optional_is_default_none`` is true.  Any
        other missing item is an error.  The value is then converted with
        ``self.caster.cast(value, item.cast_as)``.

        Otherwise ``item.source`` is iterated as sub-items.  Each sub-item is
        resolved recursively into a dict, and ``item.cast_as(**that_dict)``
        builds the value.

        Except on the skipped/``None`` optional path, the result goes through
        ``item.postprocess`` before being written.

        Raises:
            ValueError: a required field (not optional, no default) is missing.
        """
        if isinstance(item.source, collections.abc.Hashable):
            value = getter(item.source)
            # The grouping sees the raw lookup result (possibly NoDefault),
            # before preprocessing or default substitution.
            self._notify_grouping(item, value)
            if value is not NoDefault:
                value = item.preprocess(value)
            elif item.default is not NoDefault:
                value = item.default
            elif item.optional:
                if optional_is_default_none:
                    write_callback(item.result_key, None)
                return
            else:
                raise ValueError(f'required, but missing field: "{item.source}"')
            result_value = self.caster.cast(value, item.cast_as)
        else:
            kwargs = {}
            # The grouping receives the (still empty) dict object; it is filled
            # in place by the sub-items below.
            self._notify_grouping(item, kwargs)
            # Sub-items read from the same top-level getter.  The
            # optional_is_default_none flag is not forwarded, so missing optional
            # sub-items are simply omitted from the keyword arguments.
            # NOTE(review): confirm that omission (not None) is intended here.
            for sub_item in item.source:
                self._process_item(sub_item, getter, kwargs.__setitem__)
            result_value = item.cast_as(**kwargs)

        result_value = item.postprocess(result_value)
        write_callback(item.result_key, result_value)

    def cast(self, target: Mapping[Any, Any], write_into: Union[MutableMapping, Any] = None,
             write_callback: Callable[[Any, Any], None] = None, exception_type: Type[Exception] = ItemException,
             extra_keys_allowed: bool = True,
             optional_is_default_none: bool = False) -> Union[MutableMapping, Any]:
        """Cast ``target`` item by item and write the results.

        Args:
            target: source mapping; values are read with ``target.get(key, NoDefault)``.
            write_into: destination.  Defaults to a new ``dict``.  A
                ``MutableMapping`` is written with ``__setitem__``; any other
                object is written with ``setattr``.
            write_callback: if given, called as ``write_callback(result_key, value)``
                instead of writing into ``write_into``.
            exception_type: ``ValueError`` / ``ComplexItemException`` raised while
                casting is re-raised as this type (chained with ``from``).
            extra_keys_allowed: if false, keys of ``target`` not covered by
                :meth:`get_keys` are an error.
            optional_is_default_none: write ``None`` for missing top-level
                optional items instead of skipping them.

        Returns:
            ``write_into`` (a fresh dict if none was passed, and one that stays
            empty when ``write_callback`` does not write to it).

        Raises:
            exception_type: on missing required fields, disallowed extra keys,
                or any ``ValueError`` / ``ComplexItemException`` raised while casting.
                Other exception types propagate unchanged.
        """
        if write_into is None:
            write_into = {}

        if write_callback is None:
            if isinstance(write_into, collections.abc.MutableMapping):
                write_callback = write_into.__setitem__
            else:
                write_callback = functools.partial(setattr, write_into)

        try:
            if not extra_keys_allowed:
                extra_items = set(target) - self.get_keys()
                if extra_items:
                    raise ValueError(f'extra items: {extra_items}')

            def getter(key: Hashable) -> Any:
                return target.get(key, NoDefault)

            self._reinit_groupings()
            for item in self.items:
                self._process_item(item, getter, write_callback, optional_is_default_none)
            self._finish_groupings()
        except (ValueError, ComplexItemException) as e:
            # An exception raised without arguments has empty args; indexing it
            # would raise IndexError and hide the real error.
            raise exception_type(e.args[0] if e.args else str(e)) from e

        return write_into

    def cast_and_return(self, target: Mapping[Any, Any], exception_type: Type[Exception] = ItemException,
                        extra_keys_allowed: bool = True) -> Union[List[Any], Any]:
        """Cast ``target`` and return the produced values instead of writing them.

        Values are collected in item order, and missing optional items yield
        ``None`` so positions stay aligned.  When exactly one value is produced
        it is returned bare; otherwise the list is returned (possibly empty).
        """
        result = []
        self.cast(target, write_callback=lambda _key, value: result.append(value),
                  exception_type=exception_type, extra_keys_allowed=extra_keys_allowed,
                  optional_is_default_none=True)
        if len(result) == 1:
            return result[0]
        return result
