import dataclasses
import datetime
import functools
import logging
import types
import typing
from dataclasses import dataclass, is_dataclass
from decimal import Decimal
from enum import Enum
from types import UnionType
from typing import ClassVar, Optional, Any, Iterable, Callable, Literal

import apispec.utils
from apispec import APISpec
from http_tools import Answer
from init_helpers import raise_if, try_extract_type_notes
from init_helpers.dict_to_dataclass import NoValue

from .example import Example, AnswerExample
from .extras import DataclassProtocol
from .parameter import SpecParameter
from .security import Security, SecurityScheme

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class Content:
    """Single media-type entry of a request body: a MIME type paired with a payload schema."""

    # used as the key of the "content" mapping of the request body
    mime_type: str
    # Python type from which the schema description is generated
    schema: type


@dataclasses.dataclass
class RequestBody:
    """Request body of an endpoint, described by exactly one content entry."""

    # media type and schema of the body
    content: Content


@dataclasses.dataclass
class Endpoint:
    """Declarative description of one HTTP operation to be published in the spec."""

    path: str
    # HTTP method; both upper-case and lower-case spellings are accepted
    method: Literal[
        'GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'HEAD', 'OPTIONS', 'TRACE',
        'get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace',
    ]
    # must be unique among the endpoints registered in one wrapper
    operation_id: str
    securities: list[Security]
    parameters: list[SpecParameter]
    # None means the operation has no request body
    request_body: Optional[RequestBody]
    # response status code -> answer type describing that response
    code_to_answer: dict[int, type[Answer]]
    summary: str = ''
    description: str | None = None
    tags: list[str] | None = None
    deprecated: bool = False


@dataclasses.dataclass
class GetIdsFromXlsxResponseSchema:
    """Response payload carrying a completion flag and a list of integer results."""

    # NOTE(review): presumably signals that processing has finished - confirm with the producer
    done: bool
    # NOTE(review): presumably the ids extracted from the xlsx file - confirm with the producer
    result: list[int]


class OpenApiWrapper:
    """Collects endpoint declarations and renders them into an OpenAPI document through apispec."""

    # OpenAPI specification version passed to the generated APISpec
    openapi_version: ClassVar[str] = "3.1.0"

    @dataclass
    class Config:
        """Settings of the wrapper."""

        # an empty string means the version is generated automatically
        version: str = ''

    @dataclass
    class Context:
        """Identity of the documented service; the three parts are joined into the spec title."""

        company: str
        project_group: str
        project_name: str

    def __init__(self, config: Config, context: Context):
        self.version = config.version or self._get_version()
        self.config = config
        self.context = context
        self.endpoints: list[Endpoint] = []
        self.dataclass_type_to_ref: dict[type[DataclassProtocol], str] = {}
        self.required_dataclass_types: set[type[DataclassProtocol]] = set()
        self.dataclass_type_to_name: dict[type[DataclassProtocol] | UnionType, str] = {}
        self.origin_type_to_factory: dict[type | UnionType, Callable[[Iterable[type | UnionType]], dict]] = {
            typing.Union: self._get_union_description,
            types.UnionType: self._get_union_description,
            list: self._get_list_description,
            dict: self._get_dict_description,
            set: self._get_set_description,
        }
        self.primitive_type_to_description: dict[type, dict] = {
            bool: {"type": "boolean"},  # MUST be before int
            int: {"type": "integer"},
            float: {"type": "number"},
            bytes: {"type": "string", "format": "binary"},
            str: {"type": "string"},
            Decimal: {"type": "string", 'pattern': r'^\d*\.?\d*$'},
            type(None): {"type": "null"},
        }

    @functools.cache
    def _get_dataclass_type_name(self, dataclass_type: type[DataclassProtocol]) -> str:
        if name := self.dataclass_type_to_name.get(dataclass_type):
            return name
        clean_type, notes = try_extract_type_notes(dataclass_type)
        base_name = f"schema_{clean_type.__name__}"
        occupied_names = self.dataclass_type_to_name.values()
        i = 1
        while (name := f'{base_name}{"_" + str(i) if i > 1 else ""}') in occupied_names:
            i += 1
        self.dataclass_type_to_name[dataclass_type] = name
        return name

    def register(self, endpoint: Endpoint) -> None:
        """Add an endpoint to the wrapper.

        :raises KeyError: when an endpoint with the same ``operation_id`` is already registered
        """
        known_operation_ids = {registered.operation_id for registered in self.endpoints}
        duplicate_error = KeyError(f"operation_id {endpoint.operation_id!r} already registered")
        raise_if(endpoint.operation_id in known_operation_ids, duplicate_error)
        self.endpoints.append(endpoint)

    def register_dataclass(self, dataclass_type: type[DataclassProtocol], name: str) -> None:
        self.dataclass_type_to_name[dataclass_type] = name

    def get_spec_dict(self) -> dict:
        """Build the spec and return it converted to a plain dictionary."""
        spec = self._get_spec()
        return spec.to_dict()

    def _get_spec(self) -> APISpec:
        """Create a new APISpec and fill it with the registered dataclasses and endpoints.

        :return: spec titled ``company:project_group:project_name`` containing every registered
            endpoint and a schema component for every dataclass referenced while describing them
        """
        spec = APISpec(
            title=f"{self.context.company}:{self.context.project_group}:{self.context.project_name}",
            version=self.version,
            openapi_version=self.openapi_version
        )
        # Iterate over a snapshot: describing a registered type may generate names for the
        # dataclasses it refers to (e.g. members of a registered union), which inserts new keys
        # into this very dict and would otherwise fail with
        # "dictionary changed size during iteration".
        for dataclass_type in tuple(self.dataclass_type_to_name):
            self._get_schema_description(dataclass_type)
        for endpoint in self.endpoints:
            self._add_endpoint_to_spec(spec, endpoint)
        # must run last: the loops above collect the dataclasses that need a schema component
        self._add_required_dataclass_types_to_spec(spec)
        return spec

    def _add_required_dataclass_types_to_spec(self, spec: APISpec) -> None:
        added_dataclass_types = set()
        while missing_dataclass_types := self.required_dataclass_types - added_dataclass_types:
            for dataclass_type in missing_dataclass_types:
                self._add_dataclass_to_spec(spec, dataclass_type)
                added_dataclass_types.add(dataclass_type)

    @staticmethod
    def _get_version() -> str:
        now = datetime.datetime.now()
        return f"{now.year - 2021}.{now.month}.{now.day}{now.hour:02d}{now.minute:02d}"

    def _add_endpoint_to_spec(self, spec: APISpec, endpoint: Endpoint) -> None:
        operation_dict = {
            "summary": endpoint.summary,
            "operationId": endpoint.operation_id,
            "responses": self._add_operation_code_to_answer(spec, endpoint.code_to_answer),
        }
        if security := self._add_operation_securities(spec, endpoint.securities):
            operation_dict["security"] = security
        if parameters := self._get_operation_parameters(endpoint.parameters):
            operation_dict["parameters"] = parameters
        if endpoint.description:
            operation_dict["description"] = endpoint.description
        if request_body_datum := self._get_operation_request_body(endpoint.request_body):
            operation_dict["requestBody"] = request_body_datum

        spec.path(endpoint.path, operations={endpoint.method.lower(): operation_dict})

    def _get_operation_parameters(self, parameters: list[SpecParameter]) -> list[dict[str, Any]]:
        result: list[dict[str, Any]] = []
        for param in sorted(parameters, key=lambda x: x.name):
            schema_name = self._get_schema_description(param.schema, default=param.default)
            param_dict = {
                "name": param.name, "in": param.location, "required": param.is_required, "schema": schema_name}
            if param.description is not NoValue:
                param_dict['description'] = param.description
            if param.examples is not NoValue:
                if len(param.examples) == 1 and isinstance(ex := param.examples[0], Example) and not ex.name:
                    param_dict['example'] = ex.value
                else:
                    param_dict['examples'] = {
                        (example.name or str(i)): example.as_dict() for i, example in enumerate(param.examples, 1)}
            result.append(param_dict)
        return result

    def _add_operation_securities(self, spec: APISpec, securities: list[Security]) -> list[dict[str, list[str]]]:
        return [
            {self._add_security_schema_to_spec(spec, security.scheme): list(security.scopes)}
            for security in securities
        ]

    def _add_operation_code_to_answer(self, spec: APISpec, code_to_answer: dict[int, type[Answer]]) -> dict[str, str]:
        result: dict[str, str] = {}
        for code, answer in code_to_answer.items():
            response_name = self._add_answer_to_spec(spec, answer)
            result[str(int(code))] = response_name
        return result

    def _get_operation_request_body(self, request_body: Optional[RequestBody]) -> dict[str, Any] | None:
        result = None
        if request_body is not None:
            schema_description = self._get_schema_description(request_body.content.schema)
            result = {"required": True, "content": {request_body.content.mime_type: {"schema": schema_description}}}
        return result

    @functools.cache
    def _get_union_description(self, type_args: tuple[type | UnionType]) -> dict:
        assert type_args and len(type_args) >= 1, f'bad union element type: {type_args}'
        return {"anyOf": [self._get_schema_description(arg) for arg in type_args]}

    @functools.cache
    def _get_list_description(self, type_args: tuple[type | UnionType] | None) -> dict:
        if type_args:
            assert type_args and len(type_args) == 1, f'bad list element type: {type_args}'
            return {"type": "array", "items": self._get_schema_description(type_args[0])}
        else:
            return {"type": "array"}

    def _get_set_description(self, type_args: tuple[type | UnionType] | None) -> dict:
        description = {"type": "array", 'uniqueItems': True}
        if type_args:
            assert type_args and len(type_args) == 1, f'bad set element type: {type_args}'
            description |= {"items": self._get_schema_description(type_args[0])}
        return description

    def _get_dict_description(self, type_args: tuple[type | UnionType] | None) -> dict:
        if type_args:
            assert len(type_args) == 2, f'bad dict item type: {type_args}'
            return {"type": "object", "additionalProperties": self._get_schema_description(type_args[1])}
        else:
            return {"type": "object"}

    def _get_dataclass_description(self, schema: type[DataclassProtocol]) -> dict:
        """Describe a dataclass as an object with one property per field.

        Fields declared with ``repr=False`` are left out. The first ``Example`` note attached to
        the type, if any, becomes the ``example`` of the object.
        """
        schema, notes = try_extract_type_notes(schema)
        resolved_hints = typing.get_type_hints(schema)
        properties = {}
        for field_ in dataclasses.fields(schema):
            if not field_.repr:
                continue
            try:
                field_description = self._get_schema_description(field_.type)
            except TypeError:
                # the raw annotation could not be described: retry with the resolved type hint
                field_description = self._get_schema_description(resolved_hints[field_.name])
            properties[field_.name] = field_description
        description = {"type": "object", "properties": properties}

        first_example = next((note for note in notes if isinstance(note, Example)), None)
        if first_example is not None:
            description["example"] = first_example.value
        return description

    @functools.cache
    def _get_generic_description(
            self, origin_type: type | UnionType, type_args: tuple[type | UnionType] | None = None) -> dict:
        if factory := self.origin_type_to_factory.get(origin_type):
            return factory(type_args)
        else:
            raise TypeError(f"Unknown origin_type: {origin_type}")

    @classmethod
    def _save_schema_to_spec(cls, spec: APISpec, description: dict, name: str) -> dict:
        """Register a schema component under ``name`` and return the reference to it."""
        logger.debug(f'save schema {name!r}: {description}')
        components = spec.components
        components.schema(component_id=name, component=description)
        return components.get_ref("schema", name)

    @functools.cache
    def _get_primitive_description(self, type_: type) -> dict:
        for primitive_type, description in self.primitive_type_to_description.items():
            if issubclass(type_, primitive_type):
                return description
        raise TypeError(f"Unknown type: {type_}")

    def _get_schema_description(self, schema: type | UnionType, default: Any = NoValue) -> dict:
        """Return the schema description of a type, optionally carrying a default value.

        Results are cached per ``(schema, default)`` pair. Unhashable defaults (``list`` and
        ``dict`` are explicitly supported as default values) cannot be part of a cache key, so
        for them the description is built on every call instead of failing inside the cache.

        :param schema: type to describe; may carry notes readable by ``try_extract_type_notes``
        :param default: default value; written to the description when it is None or an
            int / float / str / list / dict
        :return: for dataclasses a reference to the schema component, otherwise the description
        :raises TypeError: when the type cannot be described
        """
        try:
            hash(default)
        except TypeError:
            return self._build_schema_description(schema, default)
        return self._get_cached_schema_description(schema, default)

    # typed=True keeps equal defaults of different types apart (1, True and 1.0 compare and
    # hash equal and would otherwise share one cached description)
    @functools.lru_cache(maxsize=None, typed=True)
    def _get_cached_schema_description(self, schema: type | UnionType, default: Any) -> dict:
        """Memoizing front of ``_build_schema_description`` for hashable defaults."""
        return self._build_schema_description(schema, default)

    def _build_schema_description(self, schema: type | UnionType, default: Any) -> dict:
        """Build the description of ``schema`` without consulting the cache of the entry point."""
        logger.debug('_get_schema_description: %s', schema)
        raw_schema = schema
        schema, notes = try_extract_type_notes(schema)
        if schema is typing.Any:
            # "anything" is described as a union of the types that have a description
            schema = dict | list | str | float | int | bool | None
        if origin_type := typing.get_origin(schema):
            description = self._get_generic_description(origin_type, typing.get_args(schema))
        elif schema in self.origin_type_to_factory:
            # bare container type without type arguments
            description = self._get_generic_description(schema)
        elif dataclasses.is_dataclass(schema):
            # dataclasses become separate components: remember the type (with its notes) and
            # return only a reference; default and notes are not applied to references
            self.required_dataclass_types.add(raw_schema)
            name = self._get_dataclass_type_name(raw_schema)
            return apispec.utils.build_reference("schema", 3, name)
        else:
            description = self._get_primitive_description(schema)

        # AVOID MODIFICATION OF CACHED VALUES!!!
        # (every extension below creates a new dict instead of updating the shared one)
        if isinstance(schema, type) and issubclass(schema, Enum):
            description = description | {'enum': [e.value for e in schema]}
        if default is None or isinstance(default, (int, float, str, list, dict)):
            description = description | {'default': default}
        if notes:
            if examples := tuple(note for note in notes if isinstance(note, Example)):
                description = description | {'example': examples[0].value}
            if descriptions := [note for note in notes if isinstance(note, str)]:
                description = description | {'description': descriptions[0]}
        # if default == NoValue:  # TODO: think about it, should we place "required" inside schema properties or not?
        #     description['required'] = True
        return description

    @functools.cache
    def _add_dataclass_to_spec(self, spec: APISpec, schema: DataclassProtocol) -> None:
        """Add the schema component of a dataclass to the spec.

        :raises TypeError: when ``schema``, stripped of its notes, is not a dataclass
        """
        logger.debug(f'_add_dataclass_to_spec: {schema}')
        bare_schema, _ = try_extract_type_notes(schema)
        is_not_dataclass = not is_dataclass(bare_schema)
        raise_if(is_not_dataclass, TypeError(f"Expected dataclass, got {schema}"))
        component = self._get_dataclass_description(schema)
        component_id = self._get_dataclass_type_name(schema)
        spec.components.schema(component_id=component_id, component=component)

    @functools.cache
    def _add_answer_to_spec(self, spec: APISpec, answer: type[Answer]) -> str:
        """Register an answer type as a response component and return the component name.

        The first string note of the answer becomes the response description and the first
        ``AnswerExample`` note becomes the example of the payload schema.

        :raises TypeError: when the answer declares no payload type
        """
        answer, notes = try_extract_type_notes(answer)
        descriptions = [note for note in notes if isinstance(note, str)]

        content_type = answer.get_class_content_type()
        schema = answer.get_class_payload_type()
        # Must be checked before the schema is wrapped below: typing.Annotated converts None to
        # NoneType, so after wrapping a missing payload type would no longer be detected.
        raise_if(schema is None, TypeError(f'Attempt to add {answer=} to spec, but answer schema is None'))
        if answer_examples := [note for note in notes if isinstance(note, AnswerExample)]:
            example_body = answer_examples[0].value
            schema = typing.Annotated[schema, Example(example_body)]
        schema_name_or_description = self._get_schema_description(schema)
        component = {"description": descriptions[0] if descriptions else ''}
        if content_type is not None:
            component["content"] = {content_type: {"schema": schema_name_or_description}}
        return self._safe_add_component_to_spec(spec, 'response', answer.__name__, component)

    def _safe_add_component_to_spec(
            self, spec: APISpec,
            component_type: Literal['schema', 'response', 'parameter', 'header', 'example', 'security_scheme'],
            base_name: str, description: dict
    ) -> str:
        components_dict = spec.components._subsections[component_type]
        i = 1
        while (name := f'{base_name}{"_" + str(i) if i > 1 else ""}') in components_dict:
            i += 1
        components_dict[name] = description
        return name

    @functools.cache
    def _add_security_schema_to_spec(self, spec: APISpec, security_schema: SecurityScheme) -> str:
        spec.components.security_scheme(component_id=security_schema.key, component=security_schema.to_spec())
        return security_schema.key

    @staticmethod
    def _get_name_for_object(object_: Any) -> str:
        return f"component_{object_.__name__}"
