#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors:
#  Mike Orlov <m.orlov@techokert.ru>
import collections.abc
import contextlib
import copy
import dataclasses
import inspect
import json
import logging
import pathlib
import traceback
import typing
import warnings
from dataclasses import dataclass, field, MISSING
from functools import cached_property
from typing import Callable, Any, Awaitable, TypeVar, Literal

import http_tools
import yarl
from async_tools import acall
from dict_caster.extras import first, is_iterable
from http_tools import Answer, HttpServer, HttpStatusCode
from http_tools.answer import ExceptionAnswer, ErrorDescription, JsonableAnswer, FileAnswer, BaseAnswer, StreamAnswer
from http_tools.mime_types import ContentType
from http_tools.request import IncomingRequest
from init_helpers import custom_dumps, raise_if
from init_helpers.dict_to_dataclass import NoValue

from .open_api_wrapper import OpenApiWrapper, Endpoint, RequestBody, Content, Tag, BodyPartDescription
from .parameter import CallParameter, SpecParameter, RawBodyParameter, BodyParameter, MultipartParameter
from .security import Unauthorized, Forbidden, Security, get_securities_arg_names, resolver_cache, SecurityScheme

logger = logging.getLogger(__name__)
ResultType = TypeVar('ResultType', bound=Any)
AnswerType = TypeVar('AnswerType', bound=BaseAnswer)


def _return_empty_dict(*_: Any, **__: Any):
    return {}


@dataclass(frozen=True)
class RpcEndpoint(typing.Generic[ResultType, AnswerType]):
    """Binds a handler function to an HTTP method/path.

    An instance can produce a request handler (``gen_handler``) and describe itself for the
    OpenAPI specification (``as_endpoint``). ``__init__`` is written by hand (the dataclass
    decorator keeps an explicitly defined ``__init__``) so arguments are normalized before
    being stored on the frozen instance.
    """
    func: Callable[..., ResultType]  # request handler
    method: Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'HEAD', 'OPTIONS', 'TRACE',
                    'get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace']   # HTTP method
    path: str  # HTTP path
    securities: list[Security]  # possible ways to check auth, empty means "no auth"
    # maps handler parameter names to parameter objects, capable of producing value from incoming request
    arg_name_to_param: dict[str, CallParameter]
    answer_type: type[AnswerType]  # HTTP answer type
    # exception mapping to HTTP answer type, used to process <func> exceptions and make spec
    exception_type_to_answer_type: dict[type[Exception], type[ExceptionAnswer]]
    # HTTP answer type for unexpected exceptions
    unhandled_exception__answer_type: type[Answer]
    # main endpoint identification
    operation_id: str
    # Callable to convert <func> call result to an HTTP answer instance
    answer_factory: Callable[[ResultType], AnswerType]
    summary: str  # spec description
    description: str  # spec description
    tags: list[str]  # mark endpoint with tags, spec only
    deprecated: bool  # mark endpoint as deprecated, spec only
    additional_parameters: tuple[SpecParameter, ...]  # add to spec, no other processing

    def __init__(
            self,
            func: Callable[..., ResultType],
            method: Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'HEAD', 'OPTIONS', 'TRACE',
                            'get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace'],
            path: str,
            securities: list[Security | SecurityScheme],
            arg_name_to_param: dict[str, CallParameter | type[RawBodyParameter]],
            answer_type: type[AnswerType],
            exception_type_to_answer_type: dict[type[Exception], type[ExceptionAnswer]] | None = None,
            unhandled_exception__answer_type: type[Answer] | None = ExceptionAnswer[HttpStatusCode.InternalServerError],
            operation_id: str = '',
            answer_factory: Callable[[ResultType], AnswerType] | None = None,
            summary: str | None = None,
            description: str | None = None,
            tags: list[str | Tag] | None = None,
            deprecated: bool = False,
            additional_parameters: tuple[SpecParameter, ...] | None = None,  # add to spec, no other processing
    ) -> None:
        """Normalize and store the endpoint configuration.

        :param securities: ``SecurityScheme`` items are converted with ``scheme.has()``.
        :param arg_name_to_param: parameter classes (e.g. ``RawBodyParameter``) are instantiated with no arguments.
        :param unhandled_exception__answer_type: ``None`` lets unexpected exceptions propagate out of the handler.
        :param operation_id: defaults to ``func.__name__``.
        :param answer_factory: defaults to ``answer_type`` itself.
        :param description: defaults to ``func.__doc__`` (or an empty string).
        """
        # frozen dataclass: attributes must be set through object.__setattr__
        object.__setattr__(self, 'func', func)
        object.__setattr__(self, 'method', method)
        object.__setattr__(self, 'path', path)
        object.__setattr__(self, 'securities', [s if isinstance(s, Security) else s.has() for s in securities])
        object.__setattr__(self, 'arg_name_to_param', {
            a: p() if isinstance(p, type) else p for a, p in arg_name_to_param.items()})
        object.__setattr__(self, 'answer_type', answer_type)
        object.__setattr__(self, 'exception_type_to_answer_type', exception_type_to_answer_type or {})
        object.__setattr__(self, 'unhandled_exception__answer_type', unhandled_exception__answer_type)
        object.__setattr__(self, 'operation_id', operation_id or func.__name__)
        object.__setattr__(self, 'answer_factory', answer_factory or answer_type)
        object.__setattr__(self, 'summary', summary or '')
        object.__setattr__(self, 'description', description or func.__doc__ or '')
        object.__setattr__(self, 'tags', tags or [])
        object.__setattr__(self, 'deprecated', deprecated)
        object.__setattr__(self, 'additional_parameters', additional_parameters or tuple())
        # result is discarded; presumably called to validate the securities early (confirm in .security)
        get_securities_arg_names(self.securities)

    @cached_property
    def parameters(self) -> set[CallParameter]:
        """All spec parameters from ``arg_name_to_param`` and ``additional_parameters``, deduplicated."""
        result: set[CallParameter] = set()
        for parameter in (*self.arg_name_to_param.values(), *self.additional_parameters):
            result.update(parameter.get_spec_parameters())
        return result

    @cached_property
    def body_parameters(self) -> list[BodyParameter]:
        """Body parameters, required ones (no default) first, then ordered by name for a stable spec.

        Fails with AssertionError if body parameters of different types or mime types are mixed.
        """
        body_params: list[BodyParameter] = [param for param in self.parameters if isinstance(param, BodyParameter)]
        assert len(param_types := {type(param) for param in body_params}) <= 1, f"Can not combine {param_types} in body"
        assert len(param_mimes := {param.body_mime_type for param in body_params}) <= 1, \
            f"Can not combine {param_mimes} in body"
        # self.parameters is a set, so add the name to the key to make the order independent of hashing
        return sorted(body_params, key=lambda param: (param.default is not NoValue, param.name))

    @cached_property
    def body_parameter_type(self) -> type[BodyParameter] | None:
        """Class shared by all body parameters, or ``None`` when the endpoint has no body parameters."""
        body_params = self.body_parameters
        return type(body_params[0]) if body_params else None

    @cached_property
    def non_body_parameters(self) -> list[CallParameter]:
        """Spec parameters that are not part of the request body."""
        return [param for param in self.parameters if not isinstance(param, BodyParameter)]

    @cached_property
    def body_content_schema(self) -> type | None:
        """Schema of the request body for the spec.

        :returns: ``None`` without body parameters, ``bytes`` for raw bodies, otherwise a keyword-only
            dataclass named after ``operation_id`` with one field per body parameter.
        """
        if not (body_params := self.body_parameters):
            return None

        if self.body_parameter_type and issubclass(self.body_parameter_type, RawBodyParameter):
            return bytes

        def make_field(param: BodyParameter) -> dataclasses.Field:
            default = param.default
            if default is NoValue:
                return field()  # no default -> required field
            if type(default).__hash__ is None:
                # dataclasses refuse unhashable (mutable) defaults; give each instance its own copy instead
                return field(default_factory=lambda value=default: copy.deepcopy(value))
            return field(default=default)

        class_: type = type(self.operation_id, tuple(), {param.name: make_field(param) for param in body_params})
        class_.__annotations__ = {param.name: param.annotated_schema for param in body_params}
        # noinspection PyTypeChecker
        return dataclasses.dataclass(kw_only=True)(class_)

    @cached_property
    def request_body(self) -> RequestBody | None:
        """Spec request body, or ``None`` when the endpoint takes no body.

        For multipart / url-encoded bodies, multipart parts that declare a content type or headers get
        an encoding entry.

        :raises KeyError: the same part header is declared with different schemas.
        """
        body_parameter_type = self.body_parameter_type
        if not body_parameter_type or not (body_content_schema := self.body_content_schema):
            return None

        mime_type = body_parameter_type.body_mime_type
        content_types = mime_type if is_iterable(mime_type) else [mime_type]
        key_to_additional_param_info = {}
        if mime_type in (ContentType.MultipartForm, ContentType.FormUrlEncoded):
            for body_param in self.body_parameters:
                # content_type / headers / allowed_content_type are only read from multipart parameters
                if not isinstance(body_param, MultipartParameter):
                    continue
                if not (body_param.content_type or body_param.headers):
                    continue
                header_to_schema: dict[str, typing.Annotated | type] = {}
                for body_part_header_param in body_param.headers:
                    param_name = body_part_header_param.name
                    param_schema = body_part_header_param.schema
                    present_schema = header_to_schema.setdefault(param_name, param_schema)
                    raise_if(present_schema != param_schema,
                             KeyError(f"Multipart header {param_name} is used with different schema in "
                                      f"multiple header parameters: {param_schema=}, {present_schema=}"))

                key_to_additional_param_info[body_param.name] = BodyPartDescription(
                    content_type=body_param.allowed_content_type, headers=header_to_schema)
        return RequestBody({
            content_type: Content(body_content_schema, encoding=key_to_additional_param_info)
            for content_type in content_types
        })

    @cached_property
    def code_to_answer(self) -> dict[http_tools.HttpStatusCode, type[http_tools.Answer]]:
        """Status code -> answer type for the spec; later entries override earlier ones with the same code."""
        result = {self.answer_type.get_class_status_code(): self.answer_type}
        result |= {ans_t.get_class_status_code(): ans_t for ans_t in self.exception_type_to_answer_type.values()}
        if self.unhandled_exception__answer_type is not None:
            unhandled_exception__answer_code = self.unhandled_exception__answer_type.get_class_status_code()
            result[unhandled_exception__answer_code] = self.unhandled_exception__answer_type
        return result

    def as_endpoint(self) -> Endpoint:
        """Describe this endpoint for the OpenAPI wrapper."""
        return Endpoint(
            path=self.path,
            method=self.method,
            operation_id=self.operation_id,
            securities=self.securities,
            parameters=[param for param in self.non_body_parameters if isinstance(param, SpecParameter)],
            request_body=self.request_body,
            code_to_answer=self.code_to_answer,
            summary=self.summary,
            description=self.description,
            tags=self.tags,
            deprecated=self.deprecated,
        )

    async def _authorize(self, request: IncomingRequest) -> tuple[dict[str, Any] | None, dict[str, str], bool]:
        """Try ``securities`` in order until one succeeds.

        :returns: ``(security_kwargs, security_key_to_error, is_forbidden)``. ``security_kwargs`` is ``None``
            when no security accepted the request. ``security_key_to_error`` holds ``repr`` of failures for
            schemes with ``do_log`` set. ``is_forbidden`` is true when any failure was ``Forbidden``.
        """
        if not self.securities:
            return {}, {}, False
        security_key_to_error: dict[str, str] = {}
        is_forbidden = False
        cache_token = resolver_cache.set({})
        try:
            for security in self.securities:
                try:
                    return await security.evaluate(request), security_key_to_error, is_forbidden
                except (Unauthorized, Forbidden) as e:
                    # the status code must not depend on whether the scheme logs its failures
                    is_forbidden = is_forbidden or isinstance(e, Forbidden)
                    if security.scheme.do_log:
                        security_key = f'{security.scheme.key}[{",".join(sorted(security.scopes))}]'
                        security_key_to_error[security_key] = repr(e)
        finally:
            # always restore the resolver cache, even if evaluate() raised something unexpected
            resolver_cache.reset(cache_token)
        return None, security_key_to_error, is_forbidden

    async def _stream(self, request: IncomingRequest, call_kwargs: dict[str, Any]) -> StreamAnswer:
        """Run the async-generator ``func``: the first item builds the answer, the rest are written to it."""
        # aclosing() finalizes the generator even when prepare()/write() fails
        async with contextlib.aclosing(self.func(**call_kwargs)) as async_gen:
            first_result = await anext(async_gen)  # most possible exceptions will be raised here
            stream_answer: StreamAnswer = self.answer_factory(first_result)
            await stream_answer.prepare(request.request)
            try:
                async for value in async_gen:
                    await stream_answer.write(value)
            except Exception as e:
                # the answer is already prepared, so the failure can only be logged
                logger.exception(e)
        return stream_answer

    def gen_handler(self) -> Callable[[IncomingRequest], Awaitable[Answer | StreamAnswer]]:
        """Build the request handler for this endpoint.

        The handler authorizes the request, collects ``func`` arguments from ``arg_name_to_param``
        (``NoValue`` results are omitted), calls ``func`` and converts the result with ``answer_factory``.
        Failed authorization yields a 403 answer if any security raised ``Forbidden``, otherwise 401.
        Exceptions from argument collection or the call are mapped via ``exception_type_to_answer_type``
        (first matching entry wins), then ``unhandled_exception__answer_type``; when that is ``None`` they propagate.

        :raises NotImplementedError: ``func`` is a synchronous generator function.
        :raises TypeError: ``func`` is an async generator function but ``answer_type`` is not a ``StreamAnswer``.
        """
        # TODO: check argument fullness
        if not self.securities:
            warnings.warn('Endpoints without securities will be prohibited')

        raise_if(
            inspect.isgeneratorfunction(self.func),
            NotImplementedError("Synchronous generator functions are not supported")
        )
        raise_if(
            inspect.isasyncgenfunction(self.func) and not issubclass(self.answer_type, StreamAnswer),
            TypeError("Endpoints, based on asynchronous generator function, require StreamResponse answer type")
        )
        is_streaming = inspect.isasyncgenfunction(self.func)

        async def handler(request: IncomingRequest) -> Answer | StreamAnswer:
            security_kwargs, security_key_to_error, is_forbidden = await self._authorize(request)
            if security_kwargs is None:
                status_code = HttpStatusCode.Forbidden if is_forbidden else HttpStatusCode.Unauthorized
                # noinspection PyTypeChecker
                return JsonableAnswer(ErrorDescription(security_key_to_error, "ErrorGroup"), status_code)

            try:
                param_kwargs = {
                    name: value for name, param in self.arg_name_to_param.items()
                    if (value := await param.get(request, security_kwargs)) is not NoValue
                }
                call_kwargs = security_kwargs | param_kwargs
                if is_streaming:
                    return await self._stream(request, call_kwargs)
                call_result = await acall(self.func(**call_kwargs))
                return self.answer_factory(call_result)
            except Exception as e:
                for exception_type, exception_answer_type in self.exception_type_to_answer_type.items():
                    if isinstance(e, exception_type):
                        return exception_answer_type(e)

                if self.unhandled_exception__answer_type is None:
                    raise
                traceback.print_exc()
                logger.exception(e)
                return self.unhandled_exception__answer_type(e)

        return handler


class OpenApiServer:
    """Serves the OpenAPI specification and the bundled Swagger UI, and registers ``RpcEndpoint`` handlers."""

    @dataclass
    class Context(OpenApiWrapper.Context):
        http_server: HttpServer

    @dataclass
    class Config(OpenApiWrapper.Config):
        write_spec_to_file: str = ""  # empty means skip
        http_prefix: str = '/open_api_specification/'
        spec_http_path: str = 'get'  # empty means disable
        ui_http_path: str = 'ui'  # empty means disable

        def __post_init__(self):
            """Validate the path settings.

            :raises ValueError: a non-empty ``http_prefix`` does not start and end with '/'.
            :raises NotImplementedError: ``spec_http_path`` or ``ui_http_path`` contains '/'.
            """
            if prefix := self.http_prefix:
                raise_if(not prefix.startswith('/'), ValueError(f"Http prefix MUST start with '/', got: {prefix}"))
                raise_if(not prefix.endswith('/'), ValueError(f"Http prefix MUST end with '/', got: {prefix}"))
            if spec_path := self.spec_http_path:
                raise_if('/' in spec_path, NotImplementedError(f"spec_http_path with '/' unsupported: {spec_path}"))
            if ui_path := self.ui_http_path:
                raise_if('/' in ui_path, NotImplementedError(f"ui_http_path with '/' unsupported: {ui_path}"))

    @property
    def spec_path(self) -> str | None:
        """Absolute HTTP path of the specification endpoint, or ``None`` when it is disabled."""
        if not self.config.spec_http_path:
            return None
        return yarl.URL(self.config.http_prefix).join(yarl.URL(self.config.spec_http_path)).path

    @property
    def ui_path(self) -> str | None:
        """Absolute HTTP path of the Swagger UI page, or ``None`` when it is disabled."""
        if not self.config.ui_http_path:
            return None
        return yarl.URL(self.config.http_prefix).join(yarl.URL(self.config.ui_http_path)).path

    def __init__(
            self, config: Config, context: Context,
            wrapper_factory: Callable[[Config, Context], OpenApiWrapper] = OpenApiWrapper
    ):
        """Create the spec wrapper and register the spec / UI handlers enabled in ``config``.

        The UI is only served when the spec endpoint is enabled (it needs the spec to display).
        """
        self.context = context
        self.config = config
        self.wrapper = wrapper_factory(config, context)
        self._spec_path = None
        if self.spec_path:
            self.context.http_server.register_handler(self.spec_path, self.get_specification)
            if self.ui_path:
                self.context.http_server.register_handler(self.ui_path, self.get_ui)
                # ui_path has no trailing '/', so this join replaces its last segment: assets are served from
                # '<http_prefix>{file_name}', matching relative links on the page served at ui_path.
                # NOTE(review): the spec route is registered first and presumably wins over '{file_name}'
                # for '<http_prefix><spec_http_path>' — confirm with HttpServer's route matching.
                ui_sub_path = yarl.URL(self.ui_path).join(yarl.URL('{file_name}')).path
                self.context.http_server.register_handler(ui_sub_path, self.get_ui)
        elif self.ui_path:
            logger.warning('Swagger UI disabled due to disabled spec endpoint')

    async def get_ui(self, request: IncomingRequest) -> FileAnswer:
        """Serve a file from the bundled 'swagger-ui' directory (``index.html`` when no name is given).

        In ``swagger-initializer.js`` the demo petstore URL is replaced with the relative spec path.

        :raises ValueError: the requested name is not a bare file name (e.g. contains a path separator).
        :raises TypeError: the file extension is not supported.
        """
        file_name = request.key_value_params.get('file_name', 'index.html')
        # only bare file names are allowed, so the request can never resolve outside 'swagger-ui'
        raise_if(pathlib.PurePath(file_name).name != file_name, ValueError(f"Invalid file name: {file_name}"))

        suffix_to_type = {  # suffix -> (content type, is binary)
            '.html': (ContentType.HTML, False),
            '.css': (ContentType.Css, False),
            '.js': (ContentType.JavaScript, False),
            '.png': (ContentType.PNG, True),
        }
        for suffix, (content_type, is_binary) in suffix_to_type.items():
            if file_name.endswith(suffix):
                break
        else:
            raise TypeError(f"Unsupported file type: {file_name}")

        target = pathlib.Path(__file__).parent / 'swagger-ui' / file_name
        # explicit encoding: the default one depends on the platform locale
        content = target.read_bytes() if is_binary else target.read_text(encoding='utf-8')
        if file_name == 'swagger-initializer.js':
            content = content.replace('https://petstore.swagger.io/v2/swagger.json', f'./{self.config.spec_http_path}')
        return FileAnswer(content, content_type=content_type)

    async def get_specification(self, _: IncomingRequest) -> FileAnswer:
        """Serve the current specification as 'openapi_spec.json'."""
        content = custom_dumps(self.wrapper.get_spec_dict(), indent=2)
        return FileAnswer(content, file_name='openapi_spec.json', content_type=ContentType.Json)

    def _write_spec_to_file(self):
        """Dump the current specification to ``config.write_spec_to_file``; no-op when that setting is empty."""
        if not self.config.write_spec_to_file:
            return  # don't build the spec at all when nothing is going to be written
        spec = self.wrapper.get_spec_dict()
        pathlib.Path(self.config.write_spec_to_file).write_text(json.dumps(spec, indent=2), encoding='utf-8')

    def register_endpoint(self, rpc_endpoint: RpcEndpoint) -> None:
        """Register a single endpoint (see ``register``)."""
        self.register(rpc_endpoint)

    def register(self, *rpc_endpoints: RpcEndpoint) -> None:
        """Register handlers on the HTTP server and add the endpoints to the specification."""
        for endpoint in rpc_endpoints:
            self.context.http_server.register_handler(endpoint.path, endpoint.gen_handler(), [endpoint.method])
            self.wrapper.register(endpoint.as_endpoint())
        if rpc_endpoints:
            # the file reflects the final spec, so write it once per call rather than once per endpoint
            self._write_spec_to_file()
