#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import collections.abc
import dataclasses
import inspect
import logging
import traceback
import typing
import warnings
from dataclasses import dataclass, field, MISSING
from functools import cached_property
from typing import TypeVar, Any, Callable, Literal, Awaitable, Iterable, Mapping

from async_tools import acall
from dict_caster.extras import first, is_iterable
from frozendict import frozendict
from http_tools import Answer, HttpStatusCode, ContentType, IncomingRequest, StreamAnswer, JsonableAnswer

from http_tools.answer import BaseAnswer, ExceptionAnswer, ErrorDescription
from init_helpers import raise_if
from init_helpers.dict_to_dataclass import NoValue

from openapi_tools.freeze import repr_type
from .endpoint_body import EndpointBody
from .get_response_from_answer import get_response_from_answer
from openapi_tools.spec import SecurityRequirement, CallParameter, SpecParameter, SecurityScheme, RawBodyParameter, \
    Tag, get_securities_arg_names, BodyParameter, RequestBody, MultipartParameter, BodyPartDescription, \
    Unauthorized, Forbidden, ContentSchema, resolver_cache, Operation, Response
from ..utils import snake_case_to_camel_case

# module logger; Endpoint handlers use it to log exceptions raised by handler functions
logger = logging.getLogger(__name__)
# type of the value returned by an endpoint's handler function (`Endpoint.func`)
ResultType = TypeVar('ResultType', bound=Any)
# type of the HTTP answer an endpoint produces on success (`Endpoint.answer_type`)
AnswerType = TypeVar('AnswerType', bound=BaseAnswer)


@dataclass(frozen=True)
class Endpoint(typing.Generic[ResultType, AnswerType]):
    func: Callable[..., ResultType]  # request handler
    method: Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'HEAD', 'OPTIONS', 'TRACE',
                    'get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace']   # HTTP метод
    path: str  # HTTP path
    securities: tuple[SecurityRequirement, ...]  # possible ways to check auth, empty means "public"(no auth check)
    # maps handler parameter names to parameter objects, capable of producing value from incoming request
    arg_name_to_param: dict[str, CallParameter]
    answer_type: type[AnswerType]  # HTTP answer type
    # exception mapping to HTTP answer type, used to process <func> exceptions and make spec
    exception_type_to_answer_type: dict[type[Exception], type[Answer]]
    # HTTP answer type for unexpected exceptions
    unhandled_exception__answer_type: type[Answer]
    # main endpoint identification
    operation_id: str
    # Callable to convert <func> call result to an HTTP answer instance
    answer_factory: Callable[[ResultType], AnswerType]
    summary: str  # spec description
    description: str  # spec description
    tags: Iterable[str]  # mark endpoint with tags, spec only
    deprecated: bool  # mark endpoint as deprecated, spec only
    additional_parameters: tuple[SpecParameter, ...]  # add to spec, no other processing

    def __init__(
            self,
            func: Callable[..., ResultType],
            method: str,
            path: str = '',
            securities: Iterable[SecurityRequirement | SecurityScheme] = tuple(),
            arg_name_to_param: Mapping[str, CallParameter | type[RawBodyParameter]] = frozendict(),
            answer_type: type[AnswerType] = Answer[HttpStatusCode.OK],
            exception_type_to_answer_type: dict[type[Exception], type[Answer]] | None = None,
            unhandled_exception__answer_type: type[Answer] | None = ExceptionAnswer[HttpStatusCode.InternalServerError],
            operation_id: str = '',
            answer_factory: Callable[[ResultType], AnswerType] | None = None,
            summary: str | None = None,
            description: str | None = None,
            tags: list[str | Tag] | None = None,
            deprecated: bool = False,
            additional_parameters: tuple[SpecParameter, ...] | None = None,  # add to spec, no other processing
    ) -> None:
        error_text = 'path MUST be either specified directly(path="/test") or appended to "method"(method="GET/test")'
        error = ValueError(error_text + f" got {path=!r}, {method=!r}")
        if not path:
            raise_if('/' not in method, error)
            method, path = method.split("/", 1)
            path = '/' + path

        raise_if('/' in method or not path.startswith('/'), error)
        raise_if(not callable(func), ValueError(f'"func" arg value MUST be callable, got {func=!r}'))
        object.__setattr__(self, 'func', func)
        object.__setattr__(self, 'method', method)
        object.__setattr__(self, 'path', path)
        object.__setattr__(self, 'securities',
                           tuple(s.has() if isinstance(s, SecurityScheme) else s for s in securities))
        object.__setattr__(self, 'arg_name_to_param', {
            a: p() if isinstance(p, type) else p for a, p in arg_name_to_param.items()})
        object.__setattr__(self, 'answer_type', answer_type)
        object.__setattr__(self, 'exception_type_to_answer_type', exception_type_to_answer_type or {})
        object.__setattr__(self, 'unhandled_exception__answer_type', unhandled_exception__answer_type)
        object.__setattr__(self, 'operation_id', operation_id or func.__name__)
        object.__setattr__(self, 'answer_factory', answer_factory or answer_type)
        object.__setattr__(self, 'summary', summary or '')
        object.__setattr__(self, 'description', description or func.__doc__ or '')
        object.__setattr__(self, 'tags', tuple(tags or ()))
        object.__setattr__(self, 'deprecated', deprecated)
        object.__setattr__(self, 'additional_parameters', additional_parameters or tuple())
        get_securities_arg_names(self.securities)

    @cached_property
    def parameters(self) -> frozendict[CallParameter, None]:
        result: dict['SpecParameter', None] = {}
        for parameter in self.arg_name_to_param.values():
            result |= parameter.get_spec_parameters()
        if self.additional_parameters:
            for parameter in self.additional_parameters:
                result |= parameter.get_spec_parameters()
        return frozendict(result)

    @cached_property
    def body_parameters(self) -> tuple[BodyParameter, ...]:
        body_params: list[BodyParameter] = [param for param in self.parameters if isinstance(param, BodyParameter)]
        assert len(param_types := {type(param) for param in body_params}) <= 1, f"Can not combine {param_types} in body"
        assert len(param_mimes := {param.body_mime_type for param in body_params}) <= 1, \
            f"Can not combine {param_mimes} in body"
        return tuple(sorted(body_params, key=lambda x: x.default is not NoValue))

    @cached_property
    def body_parameter_type(self) -> type[BodyParameter] | None:
        return type(first(self.body_parameters, none_if_empty=True))

    @cached_property
    def non_body_parameters(self) -> tuple[CallParameter, ...]:
        return tuple(param for param in self.parameters if not isinstance(param, BodyParameter))

    @cached_property
    def body_content_schema(self) -> type | None:
        if not (body_params := self.body_parameters):
            return None

        if self.body_parameter_type and issubclass(self.body_parameter_type, RawBodyParameter):
            return bytes

        class_: type = type(f'{snake_case_to_camel_case(self.operation_id)}Body', (EndpointBody,), {
            p.name: field(default=MISSING if p.default is not NoValue else p.default) for p in body_params})
        class_.__annotations__ = {param.name: param.annotated_schema for param in body_params}
        # noinspection PyTypeChecker
        return dataclasses.dataclass(kw_only=True)(class_)

    @cached_property
    def request_body(self) -> RequestBody | None:
        if (body_parameter_type := self.body_parameter_type) and (body_content_schema := self.body_content_schema):
            mime_type = body_parameter_type.body_mime_type
            content_types = mime_type if is_iterable(mime_type) else [mime_type]
            key_to_additional_param_info = {}
            if mime_type in (ContentType.MultipartForm, ContentType.FormUrlEncoded):
                for body_param in self.body_parameters:
                    if isinstance(body_param, MultipartParameter) and body_param.content_type or body_param.headers:
                        header_to_schema: dict[str, typing.Annotated | type] = {}
                        for body_part_header_param in body_param.headers:
                            param_name = body_part_header_param.name
                            param_schema = body_part_header_param.schema
                            raise_if((present_schema := header_to_schema.get(param_name, param_schema)) != param_schema,
                                     KeyError(f"Multipart header {param_name} is used with different schema in "
                                              f"multiple header parameters: {param_schema=}, {present_schema=}"))
                            header_to_schema[param_name] = param_schema

                        key_to_additional_param_info[body_param.name] = BodyPartDescription(
                            explode=body_param.explode, style=body_param.style,
                            content_type=body_param.allowed_content_type, headers=header_to_schema)
            return RequestBody({
                content_type: ContentSchema(body_content_schema, schema_key_to_description=key_to_additional_param_info)
                for content_type in content_types
            })

    # @cached_property
    # def code_to_answer(self) -> dict[http_tools.HttpStatusCode, type[http_tools.Answer]]:
    #     result = {self.answer_type.get_class_status_code(): self.answer_type}
    #     result |= {ans_t.get_class_status_code(): ans_t for ans_t in self.exception_type_to_answer_type.values()}
    #     if self.unhandled_exception__answer_type is not None:
    #         unhandled_exception__answer_code = self.unhandled_exception__answer_type.get_class_status_code()
    #         result[unhandled_exception__answer_code] = self.unhandled_exception__answer_type
    #     return result

    @property
    def _code_to_response(self) -> Mapping[int | Literal['default'], Response]:
        answer_types = [self.answer_type] + list(self.exception_type_to_answer_type.values())
        code_to_answer = {a.get_class_status_code(): a for a in answer_types}
        code_to_answer |= {'default': default_ans} if (default_ans := self.unhandled_exception__answer_type) else {}
        return frozendict({code: get_response_from_answer(answer) for code, answer in code_to_answer.items()})

    def as_operation(self) -> Operation:
        return Operation(
            path=self.path,
            method=self.method,
            operation_id=self.operation_id,
            securities=self.securities,
            parameters=[param for param in self.non_body_parameters if isinstance(param, SpecParameter)],
            request_body=self.request_body,
            code_to_response=self._code_to_response,
            summary=self.summary,
            description=self.description,
            tags=self.tags,
            deprecated=self.deprecated,
        )

    def gen_handler(self) -> Callable[[IncomingRequest], Awaitable[Answer | StreamAnswer]]:
        # TODO: check argument fullness
        if not self.securities:
            warnings.warn('Endpoints without securities will be prohibited')

        raise_if(
            inspect.isgeneratorfunction(self.func),
            NotImplementedError("Synchronous generator functions are not supported")
        )
        raise_if(
            inspect.isasyncgenfunction(self.func) and not issubclass(self.answer_type, StreamAnswer),
            TypeError("Endpoints, based on asynchronous generator function, require StreamResponse answer type")
        )

        async def handler(request: IncomingRequest) -> Answer | StreamAnswer:
            security_kwargs = None
            cache_token = resolver_cache.set({})
            security_key_to_exception = {}
            if self.securities:
                for security_requirement in self.securities:
                    try:
                        security_kwargs = await security_requirement.evaluate(request)
                        break
                    except* (Unauthorized, Forbidden) as exc_group:
                        if security_requirement.do_log:
                            security_key_to_exception[security_requirement.string_key] = ';'.join(
                                repr(e) for e in exc_group.exceptions)
            else:
                security_kwargs = {}
            resolver_cache.reset(cache_token)

            if security_kwargs is None:
                is_forbidden = any(isinstance(exc, Forbidden) for exc in security_key_to_exception.values())
                status_code = HttpStatusCode.Forbidden if is_forbidden else HttpStatusCode.Unauthorized
                # noinspection PyTypeChecker
                return JsonableAnswer(ErrorDescription(security_key_to_exception, "ErrorGroup"), status_code)

            try:
                param_kwargs = {
                    name: value for name, param in self.arg_name_to_param.items()
                    if (value := await param.get(request, security_kwargs)) is not NoValue
                }
                if inspect.isasyncgenfunction(self.func):
                    # noinspection PyTypeChecker
                    async_gen: collections.abc.AsyncGenerator = self.func(**(security_kwargs | param_kwargs))
                    first_result = await anext(async_gen)  # most possible exceptions will be raised here
                    stream_answer: StreamAnswer = self.answer_factory(first_result)
                    await stream_answer.prepare(request.request)
                    try:
                        async for value in async_gen:
                            await stream_answer.write(value)
                    except Exception as e:
                        logger.exception(e)
                    return stream_answer
                else:
                    call_result = await acall(self.func(**(security_kwargs | param_kwargs)))
                    return self.answer_factory(call_result)
            except Exception as e:
                for exception_type, answer_type in self.exception_type_to_answer_type.items():
                    if isinstance(e, exception_type):
                        return answer_type(e)

                if self.unhandled_exception__answer_type is not None:
                    traceback.print_exc()
                    logger.exception(e)
                    return self.unhandled_exception__answer_type(e)
                raise

        return handler

    # @functools.cache
    def _get_repr_parts(self) -> tuple[str, ...]:
        parts = [
            f'func={self.func!r}', f'method={self.method!r}', f'path={self.path!r}',
            f'operation_id={self.operation_id!r}'
        ]
        parts += [f'securities={self.securities!r}'] if self.securities else []
        parts += [f'arg_name_to_param={dict(self.arg_name_to_param)!r}'] if self.arg_name_to_param else []

        parts += [f'request_body={self.request_body!r}'] if self.request_body is not None else []
        parts += [f'answer_type={repr_type(self.answer_type)}'] if self.answer_type else []
        if self.exception_type_to_answer_type:
            subparts = [f"{repr_type(e)}: {repr_type(a)}" for e, a in self.exception_type_to_answer_type.items()]
            parts += [f'exception_type_to_answer_type={{{", ".join(subparts)}}}']
        if self.exception_type_to_answer_type:
            parts += [f'unhandled_exception__answer_type={repr_type(self.unhandled_exception__answer_type)}']

        parts += [f'answer_factory={repr_type(self.answer_factory)}'] if self.answer_factory != self.answer_type else []
        parts += [f'summary={self.summary!r}'] if self.summary else []
        parts += [f'description={self.description!r}'] if self.description else []
        parts += [f'tags={list(self.tags)}'] if self.tags else []
        parts += [f'deprecated={self.deprecated}'] if self.deprecated else []
        parts += [f'additional_parameters={list(self.additional_parameters)}'] if self.additional_parameters else []
        return tuple(parts)

    def __repr__(self):
        return f'{self.__class__.__name__}({", ".join(self._get_repr_parts())})'

    __str__ = __repr__
