#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import operator
import typing
from dataclasses import dataclass, field
from types import UnionType
from typing import Any, ClassVar, Iterable

from dict_caster.extras import is_iterable
from http_tools import IncomingRequest
from http_tools.mime_types import ContentType
from http_tools.multipart_form_data import BodyPart
from init_helpers import raise_if
from init_helpers.dict_to_dataclass import NoValue, FieldErrors
from multidict import CIMultiDictProxy

from ..example import BaseExample
from .aggregation.call_template import CallTemplate
from .body.body_parameter import BodyParameter
from .header_parameter import HeaderParameter
from .parameter_style import ParameterStyle


@dataclass(frozen=True)
class ParsedBodyPart:
    """Immutable, already-cast view of a single multipart body part.

    Built by ``MultipartParameter._cast_single_value`` when the parameter is
    declared with ``unpack=False``; the payload is cast to the parameter
    schema while the part metadata is kept alongside it.
    """
    # Filename taken from the source BodyPart (None when it has none).
    filename: str | None
    # Declared headers of the part, keyed by header name, each cast by its HeaderParameter.
    header_name_to_value: dict[str, Any]
    # Parsed part body cast to the parameter schema.
    payload: Any
    # Encoding taken from the source BodyPart, if any.
    encoding: str | None
    # Content type taken from the source BodyPart, if any.
    content_type: ContentType | None


class MultipartParameterHeaders:
    """Subscriptable accessor for the headers of a multipart parameter.

    ``MultipartParameterHeaders(param)[name]`` returns a ``CallTemplate`` that
    chains ``attrgetter('header_name_to_value')`` applied to ``param`` with
    ``itemgetter(name)``, i.e. it describes reading ``header_name_to_value[name]``
    from the value ``param`` resolves to.
    """

    def __init__(self, parent: 'MultipartParameter') -> None:
        self.parent = parent
        headers_getter = operator.attrgetter("header_name_to_value")
        self._parent_header_call = CallTemplate(headers_getter, self.parent)

    def __getitem__(self, item: str):
        # Only string header names are accepted.
        is_not_str = not isinstance(item, str)
        error = TypeError(f'Header names must be string, got: {type(item)}: {item}')
        raise_if(is_not_str, error)
        value_getter = operator.itemgetter(item)
        return CallTemplate(value_getter, self._parent_header_call)

    def __repr__(self):
        return f'MultipartParameterHeaders(parent={self.parent})'

    __str__ = __repr__


@dataclass(frozen=True, kw_only=True)
class MultipartParameter(BodyParameter):
    """Parameter read from a named field of a ``multipart/form-data`` request body.

    Behaviour is controlled by the constructor flags:

    * ``unpack=True`` (default): each matching body part is parsed and cast to ``schema``.
    * ``unpack=False``: the cast payload is wrapped into a :class:`ParsedBodyPart` together with
      the part's filename, encoding, content type and the declared ``headers`` (each cast by its
      own ``HeaderParameter``).
    * ``explode=True``: all parts sharing this name are collected into a list (``schema`` must be
      a ``list``/``tuple``/``set`` generic); otherwise exactly one part is required.
    * ``allowed_content_type``: when given, every matching part's content type is checked.
    """
    body_mime_type: ClassVar[ContentType] = ContentType.MultipartForm
    schema: type | UnionType
    allowed_content_type: ContentType | tuple[ContentType, ...] | None = None
    headers: tuple[HeaderParameter, ...] = field(init=False)
    unpack: bool = True
    explode: bool | None = None

    def __hash__(self):
        # An explicitly defined __hash__ is kept by @dataclass instead of the generated field-based one.
        # TODO: discover why this helps (reason for this exact tuple is undocumented).
        return hash((
            self.name, self.schema, self.is_optional, self.default, self.description, self.examples,
            self.allowed_content_type, self.headers, self.unpack, self.explode, self.is_filename_required, self.style,
        ))

    def __init__(
            self,
            name: str,
            schema: type | UnionType = bytes,
            is_optional: bool = False,
            default: Any = NoValue,
            description: str | NoValue = NoValue,
            examples: tuple[BaseExample, ...] | NoValue = NoValue,
            allowed_content_type: ContentType | Iterable[ContentType] | None = None,
            headers: Iterable[HeaderParameter] | None = None,
            unpack: bool = True,
            explode: bool | None = None,
            is_filename_required: bool = False,
            style: ParameterStyle | None = None
    ) -> None:
        """Validate and store the parameter configuration.

        :param name: multipart field name to read.
        :param schema: type the part payload is cast to.
        :param default: value used when the field is absent; with ``unpack=False`` it MUST be a BodyPart.
        :param allowed_content_type: single content type or iterable of them accepted for the part.
        :param headers: header parameters extracted from the part when ``unpack=False``; names must be unique.
        :param unpack: return the cast payload (True) or a ParsedBodyPart wrapper (False).
        :param explode: accept several parts with this name; requires a list/tuple/set generic schema.
        :raises TypeError: on a non-BodyPart default with ``unpack=False`` or on ``explode`` with a non-iterable schema.
        :raises KeyError: when ``headers`` contains duplicate header names.
        """
        if not unpack and default is not NoValue:
            raise_if(not isinstance(default, BodyPart),
                     TypeError(f'With unpack=False default MUST be BodyPart, got {default=}'))
        if explode and typing.get_origin(schema) not in (list, tuple, set):
            raise TypeError(f'explode=True can be used only with iterable schema, got {schema=}')
        # Dataclass is frozen, so attributes are assigned through object.__setattr__.
        object.__setattr__(self, 'name', name)
        object.__setattr__(self, 'schema', schema)
        object.__setattr__(self, 'is_optional', is_optional)
        object.__setattr__(self, 'default', default)
        object.__setattr__(self, 'description', description)
        object.__setattr__(self, 'examples', examples)

        if allowed_content_type:
            if is_iterable(allowed_content_type):
                # NOTE(review): this branch stores the members' `.value`s while the single-type branch
                # stores the ContentType itself - confirm both compare correctly with part content types.
                # noinspection PyUnresolvedReferences
                types_tuple = tuple(t.value for t in allowed_content_type)
                object.__setattr__(self, 'allowed_content_type', types_tuple)
                # Instance attribute shadows the no-op `_check_content_type` method below.
                object.__setattr__(
                    self, '_check_content_type',
                    lambda x: raise_if(x not in types_tuple, KeyError(f'Expected {types_tuple}, got: {x!r}')))
            else:
                object.__setattr__(self, 'allowed_content_type', allowed_content_type)
                object.__setattr__(
                    self, '_check_content_type',
                    lambda x: raise_if(x != allowed_content_type, KeyError(f'Expected {allowed_content_type}, got: {x}')))
        else:
            object.__setattr__(self, 'allowed_content_type', None)

        if headers:
            # Materialize once: `headers` may be a one-shot iterator.
            headers_tuple = tuple(headers)
            object.__setattr__(self, 'headers', headers_tuple)
            header_names = [h.name for h in headers_tuple]
            duplicate_names = list(dict.fromkeys(n for n in header_names if header_names.count(n) > 1))
            raise_if(
                bool(duplicate_names),
                KeyError(f'Got header parameters with duplicate names: {duplicate_names}')
            )
        else:
            object.__setattr__(self, 'headers', tuple())
        object.__setattr__(self, 'unpack', unpack)
        object.__setattr__(self, 'explode', explode)
        object.__setattr__(self, 'is_filename_required', is_filename_required)
        object.__setattr__(self, 'style', style)

    def _check_content_type(self, content_type: str | None) -> None:
        """Accept any content type; replaced per instance when ``allowed_content_type`` is set."""
        pass

    def _check_headers(self, header_name_to_value: CIMultiDictProxy[str]) -> None:
        """No-op hook for header validation (not called within this class)."""
        pass

    def _get(self, incoming_request: IncomingRequest) -> BodyPart | Any:
        """Fetch the raw body part(s) named ``self.name`` from the parsed multipart body.

        :returns: a single BodyPart, or the list of all matching parts when ``explode`` is set.
        :raises KeyError: when the request has no parsed body, when the field is absent
            (``getall`` without default), or when a part's content type is not allowed.
        :raises TypeError: when the parsed body is not a multipart ``CIMultiDictProxy``.
        :raises ValueError: when ``explode`` is falsy and the field does not occur exactly once.
        """
        raise_if(incoming_request.parsed_body is None, KeyError(self.name))
        raise_if(not isinstance(incoming_request.parsed_body, CIMultiDictProxy),
                 TypeError('Multipart parameter got non multipart body'))
        body: CIMultiDictProxy = incoming_request.parsed_body
        result = body.getall(self.name)
        for body_part in result:
            self._check_content_type(body_part.content_type)
        if not self.explode:
            if len(result) != 1:
                raise ValueError(f"Multipart got {len(result)} values, expected 1")
            result = result[0]

        return result

    def _extract_headers(self, value: BodyPart) -> dict[str, Any]:
        """Read and cast the declared ``self.headers`` from a body part.

        Missing optional headers fall back to their default (cast by the header parameter) or are
        omitted when they have none. Non-exploded headers must occur exactly once.

        :raises FieldErrors: collecting every missing required header and every header with
            a wrong number of values.
        """
        result: dict[str, Any] = {}
        field_name_to_error: dict[str, Exception] = {}
        for header in self.headers:
            # NoValue acts as a "header absent" marker, distinguishable from real header values.
            values = value.get_all_header_values(header.name, NoValue)
            if values is NoValue:
                if header.is_required:
                    field_name_to_error[header.name] = KeyError("Missing multipart header")
                elif header.default is not NoValue:
                    # A default is already one aggregated value: cast it directly instead of applying
                    # the "exactly one value" check meant for raw header value lists.
                    result[header.name] = header._cast_value_to_schema(header.default)
                continue
            if not header.explode:
                if len(values) != 1:
                    field_name_to_error[header.name] = ValueError(
                        f"Multipart header got {len(values)} values, expected 1")
                    continue
                values = values[0]
            result[header.name] = header._cast_value_to_schema(values)

        if field_name_to_error:
            raise FieldErrors(field_name_to_error)
        return result

    def _cast_single_value(self, value: BodyPart, schema: type, name: str) -> ParsedBodyPart | Any:
        """Cast one body part.

        With ``unpack`` the parsed payload cast to ``schema`` is returned; otherwise a
        :class:`ParsedBodyPart` carrying the cast payload plus part metadata and headers.

        :raises TypeError: when ``unpack`` is False and ``value`` is not a BodyPart.
        """
        if self.unpack:
            if isinstance(value, BodyPart):
                value = value.parse()
            return super()._cast_single_value(value=value, schema=schema, name=name)
        if not isinstance(value, BodyPart):
            raise TypeError(f'UNEXPECTED: Non unpack cast got non BodyPart {value=}')
        return ParsedBodyPart(
            filename=value.filename,
            header_name_to_value=self._extract_headers(value),
            payload=super()._cast_single_value(value=value.parse(), schema=schema, name=name),
            encoding=value.encoding,
            content_type=value.content_type
        )

    def _cast_value_to_schema(self, value: Any):
        """Delegate to the base cast, parsing a BodyPart first when ``unpack`` is set."""
        if self.unpack and isinstance(value, BodyPart):
            value = value.parse()
        return super()._cast_value_to_schema(value=value)

    # The accessors below build CallTemplates reading attributes of this parameter's resolved value.
    # Their names match ParsedBodyPart fields, so they presumably target unpack=False parameters.

    @property
    def payload(self) -> 'CallTemplate[bytes]':
        """CallTemplate reading ``.payload`` of the resolved value."""
        return CallTemplate(operator.attrgetter("payload"), self)

    @property
    def filename(self) -> 'CallTemplate[str | None]':
        """CallTemplate reading ``.filename`` of the resolved value."""
        return CallTemplate(operator.attrgetter("filename"), self)

    @property
    def content_type(self) -> 'CallTemplate[ContentType | None]':
        """CallTemplate reading ``.content_type`` of the resolved value."""
        return CallTemplate(operator.attrgetter("content_type"), self)

    @property
    def header(self) -> MultipartParameterHeaders:
        """Subscriptable accessor: ``param.header['X-Name']`` builds a header-reading CallTemplate."""
        return MultipartParameterHeaders(self)
