import functools
import json
from dataclasses import dataclass
from enum import StrEnum
from typing import Any, TypeVar

from aiohttp import MultipartReader, BodyPartReader
from aiohttp.web_exceptions import HTTPBadRequest
from multidict import CIMultiDictProxy, CIMultiDict

from .mime_types import ContentType


class MultipartHeaders(StrEnum):
    content_type = 'Content-Type'


# Type of the caller-supplied ``default`` returned by BodyPart's header getters when a header is absent.
T = TypeVar("T")


@dataclass(frozen=True)
class BodyPart:
    """A single part of a multipart body, fully read into memory.

    Construction validates the part's headers: a part carrying more than one
    ``Content-Type`` header is rejected with ``HTTPBadRequest``.
    """

    # Filename of the part; may be None (e.g. for plain form fields), which __str__ handles.
    filename: str | None
    # All headers of the part, looked up case-insensitively.
    header_name_to_value: CIMultiDictProxy[str]
    # Payload bytes of the part.
    content: bytes
    # Charset used to decode text/JSON payloads; None means fall back to UTF-8 in parse().
    encoding: str | None

    def __post_init__(self) -> None:
        # Compute (and cache) content_type eagerly so duplicate Content-Type headers are
        # rejected at construction time. Do not wrap this in ``assert``: ``python -O``
        # strips asserts, which would silently skip the validation.
        _ = self.content_type

    @functools.cached_property
    def content_type(self) -> str | None:
        """Return the part's single ``Content-Type`` header value, or ``None`` if it has none.

        :raises HTTPBadRequest: if the part declares more than one ``Content-Type``.
        """
        content_types: list[str] | None = self.get_all_header_values(MultipartHeaders.content_type)
        if not content_types:
            return None
        if len(content_types) > 1:
            raise HTTPBadRequest(reason=f'Body part has more than one content type: {content_types}')
        return content_types[0]

    @functools.cached_property
    def size(self) -> int:
        """Return the payload length in bytes."""
        return len(self.content)

    def parse(self, default_content_type: ContentType = ContentType.Text) -> Any:
        """Decode the payload according to its content type.

        Text parts are returned as ``str``, JSON parts as the parsed JSON value, and
        anything else as the raw ``bytes``. ``default_content_type`` is used when the
        part has no ``Content-Type`` header; ``encoding`` (falling back to UTF-8) is
        used to decode text and JSON.
        """
        content_type = self.content_type or default_content_type
        # NOTE(review): this is an exact comparison. If ContentType members are bare media
        # types (e.g. "application/json"), a header carrying parameters such as
        # "application/json; charset=utf-8" falls through to raw bytes — confirm intent.
        if content_type == ContentType.Text:
            return self.content.decode(self.encoding or "utf-8")
        if content_type == ContentType.Json:
            return json.loads(self.content.decode(self.encoding or "utf-8"))
        return self.content

    def get_one_header_value(self, name: str, default: T = None) -> str | T:
        """Return the first value of header ``name`` (case-insensitive), or ``default`` if absent."""
        return self.header_name_to_value.getone(name, default)

    def get_all_header_values(self, name: str, default: T = None) -> list[str] | T:
        """Return every value of header ``name`` (case-insensitive), or ``default`` if absent."""
        return self.header_name_to_value.getall(name, default)

    def __str__(self) -> str:
        parts: list[str] = []
        if self.filename is not None:
            parts.append(f"filename={self.filename}")
        if self.content_type is not None:
            parts.append(f"content_type={self.content_type}")
        # Content-Type is already shown above. Header names may arrive in any case
        # ("content-type", "CONTENT-TYPE", ...), so compare case-insensitively.
        content_type_key = MultipartHeaders.content_type.casefold()
        other_headers = [
            repr(f'{name}:{value}')
            for name, value in self.header_name_to_value.items()
            if name.casefold() != content_type_key
        ]
        if other_headers:
            parts.append(f"headers={{{','.join(other_headers)}}}")
        # append (not ``+=`` with a bare str, which would add each character separately)
        parts.append(f'size={self.size}')
        return f'BodyPart({", ".join(parts)})'


async def parse_multipart_form_data(reader: MultipartReader) -> CIMultiDictProxy[BodyPart]:
    """Read every part of a multipart form body into memory.

    :param reader: multipart reader positioned at the start of the body.
    :returns: a read-only, case-insensitive multidict mapping each part's field name
        to its :class:`BodyPart`; repeated field names keep every part, in order.
    :raises HTTPBadRequest: if a part is not a plain body part (e.g. a nested
        multipart), has no field name, or declares more than one Content-Type.
    """
    result: CIMultiDict[BodyPart] = CIMultiDict()

    async for part in reader:
        # aiohttp may yield a nested MultipartReader instead of a BodyPartReader. Only flat
        # parts are handled here, so report a client error rather than crash on the
        # BodyPartReader-only calls below.
        if not isinstance(part, BodyPartReader):
            raise HTTPBadRequest(reason='Nested multipart bodies are not supported')
        # The field name is the result key; CIMultiDict keys must be strings, so a
        # nameless part cannot be stored.
        name = part.name
        if name is None:
            raise HTTPBadRequest(reason='Body part is missing a field name')
        # noinspection PyTypeChecker
        encoding: str | None = part.get_charset(default=None)
        content = await part.read(decode=True)
        result.add(name, BodyPart(part.filename, part.headers, content, encoding=encoding))

    return CIMultiDictProxy(result)
