#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import functools
import logging
from dataclasses import dataclass
from json import JSONDecodeError
from typing import Awaitable, Iterable, Callable, Union, Protocol, ClassVar

import aiohttp
import aiohttp.typedefs
import aiohttp.web_request
import aiohttp.web_response
from aiohttp import web
from aiohttp.web_runner import AppRunner, TCPSite
from aiohttp.web_urldispatcher import UrlDispatcher
from async_tools import acall, AsyncInitable, AsyncDeinitable

from .answer import Answer, ErrorAnswer
from .error_middleware import error_middleware
from .http_status_codes import HttpStatusCode
from .request import IncomingRequest, request_from_web_request
from .ws_handler import WebSocketWrapper, ws_from_web_request

# Module-level logger, named after this module's dotted path.
logger = logging.getLogger(__name__)

# An HTTP request handler: takes an IncomingRequest and produces an Answer,
# either synchronously or as an awaitable (both forms are accepted).
Handler = Callable[[IncomingRequest], Awaitable[Answer]] | Callable[[IncomingRequest], Answer]

# A WebSocket handler: takes a WebSocketWrapper and produces the
# web.WebSocketResponse, either synchronously or as an awaitable.
WsHandler = (
    Callable[[WebSocketWrapper], Awaitable[web.WebSocketResponse]]
    | Callable[[WebSocketWrapper], web.WebSocketResponse]
)


class Middleware(Protocol):
    """Structural type of a middleware in :class:`HttpServer`'s chain.

    Compared with a plain aiohttp middleware ``(request, handler)``, it also
    receives the application's URL ``router``. ``handler`` is the next callable
    in the chain; the result may be a response or an awaitable of one.
    """

    def __call__(
            self,
            request: web.Request,
            handler: aiohttp.typedefs.Handler,
            router: UrlDispatcher,
    ) -> web.StreamResponse | Awaitable[web.StreamResponse]:
        """Handle ``request``, typically by delegating to ``handler(request)``."""


class HttpServer(AsyncInitable, AsyncDeinitable):
    """aiohttp-based HTTP server with a pluggable, router-aware middleware chain.

    Routes are added with :meth:`register_handler` / :meth:`register_ws_handler`.
    :meth:`_async_init` starts listening and :meth:`_async_deinit` stops the
    server. The instance can also be used as ``async with HttpServer(...) as srv:``.
    """

    # Middlewares used when the constructor receives ``middleware=None``.
    # Each instance copies this list, so the class-level default is never mutated.
    default_middlewares: ClassVar[list[Middleware]] = [error_middleware]

    @dataclass
    class Config:
        """Bind address and request-parsing limits."""
        interface: str = '0.0.0.0'  # address to bind the TCP site to
        port: int = 80
        client_max_size: int = 100 * (1024 ** 2)  # aiohttp `client_max_size`, in bytes (100 MiB)
        max_line_size: int = 16380  # forwarded to aiohttp via `handler_args`
        max_field_size: int = 16380  # forwarded to aiohttp via `handler_args`

    @dataclass
    class Context:
        """Identity data of this server instance."""
        instance_id: str  # sent as the "Instance" header on Answer responses

    async def _add_default_headers(
            self, request: web.Request, handler: aiohttp.typedefs.Handler, *args, **kwargs) -> web.StreamResponse:
        """Outermost middleware: add ``default_header_to_value`` to the response.

        Only :class:`Answer` responses are modified, and headers the handler has
        already set are left untouched. Extra arguments (e.g. ``router``) passed
        by :meth:`_handle_middlewares` are accepted and ignored.
        """
        response = await handler(request)
        if isinstance(response, Answer):
            for header, value in self.default_header_to_value.items():
                if header not in response.headers:
                    response.headers[header] = value
        return response

    def __init__(self, config: Config, context: Context, *,
                 middleware: Middleware | Iterable[Middleware] | None = None):
        """Build the aiohttp application. No socket is opened until :meth:`_async_init`.

        :param config: bind address and request-size limits.
        :param context: instance identity; ``context.instance_id`` becomes the
            ``Instance`` header on :class:`Answer` responses.
        :param middleware: a single middleware, an iterable of middlewares, or
            ``None`` for :attr:`default_middlewares`. They run inside the
            built-in default-headers middleware, with the first entry outermost.
        """
        if middleware is None:
            middleware = self.default_middlewares
        # A middleware is callable; a collection of middlewares is not. Checking
        # callability (instead of ``isinstance(..., list)``) also accepts tuples and
        # other iterables. Previously these were wrapped as one uncallable
        # "middleware" that failed at request time.
        extra_middlewares = [middleware] if callable(middleware) else list(middleware)

        AsyncInitable.__init__(self)
        AsyncDeinitable.__init__(self)
        self.config = config
        self.default_header_to_value = {"Instance": context.instance_id}
        # _add_default_headers is first, i.e. outermost, so it sees the response
        # produced by every inner middleware.
        self.middlewares = [self._add_default_headers, *extra_middlewares]

        # aiohttp itself only knows one middleware, which runs the custom chain
        # (whose members also take a `router` argument).
        self.runner = AppRunner(
            web.Application(
                middlewares=[self._handle_middlewares], client_max_size=config.client_max_size,
                handler_args={'max_field_size': config.max_field_size, 'max_line_size': config.max_line_size}
            )
        )

        router: UrlDispatcher = self.runner.app.router
        # Answer OPTIONS on every path (see `options`).
        router.add_options('/{tail:.*}', self.options)
        logger.info(f"{type(self).__name__} inited")

    def register_handler(self, path: str, handler: Handler, methods: Iterable[str] | None = None) -> None:
        """Route HTTP ``methods`` on ``path`` to ``handler`` (sync or async).

        :param path: aiohttp route path.
        :param handler: receives an :class:`IncomingRequest` and returns an :class:`Answer`.
        :param methods: HTTP method names; defaults to GET and POST. A single
            method may be passed as a plain string, e.g. ``"PUT"``.
        """
        logger.info(f"register handler: {path}, {handler}")
        wrapped_handler = self._wrap_handler(handler)
        router: UrlDispatcher = self.runner.app.router
        if methods is None:
            methods = {'GET', 'POST'}
        elif isinstance(methods, str):
            # Iterating a str would register one route per character ('P', 'U', 'T').
            methods = [methods]
        for method in methods:
            router.add_route(method, path, wrapped_handler)

    def register_ws_handler(self, path: str, handler: WsHandler) -> None:
        """Route GET and POST on ``path`` to a WebSocket ``handler`` (sync or async)."""
        logger.info(f"register ws handler: {path}, {handler}")
        wrapped_handler = self._wrap_ws_handler(handler)
        router: UrlDispatcher = self.runner.app.router
        router.add_get(path, wrapped_handler)
        # NOTE(review): a WebSocket handshake is a GET request; confirm the POST route is needed.
        router.add_post(path, wrapped_handler)

    @staticmethod
    async def options(_: aiohttp.web_request.Request) -> Answer:
        """Reply to any OPTIONS request with 200 and a fixed Allow / CORS header set."""
        return Answer(None, HttpStatusCode.OK, headers={
            "Allow": "GET,OPTIONS,POST",
            "Access-Control-Allow-Headers": "Content-Type"
        })

    async def _async_init(self):
        """Set up the aiohttp runner and start listening on ``config.interface:config.port``."""
        logger.debug(f"start http server on: http://{self.config.interface}:{self.config.port}")
        await self.runner.setup()
        await TCPSite(self.runner, self.config.interface, self.config.port).start()
        logger.info(f"http server started: http://{self.config.interface}:{self.config.port}")

    async def _async_deinit(self):
        """Shut the server down by cleaning up the aiohttp runner."""
        logger.debug("stop http server")
        await self.runner.cleanup()
        logger.debug("http server stopped")

    async def __aenter__(self):
        """Start the server and return it, so ``async with ... as srv`` binds the instance.

        ``async_init`` comes from AsyncInitable (presumably it dispatches to
        :meth:`_async_init` — confirm in async_tools).
        """
        await self.async_init()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the server. Returns None, so exceptions from the ``with`` body propagate."""
        await self.async_deinit()

    def _wrap_handler(self, raw_handler: Handler
                      ) -> Callable[[aiohttp.web_request.Request], Awaitable[aiohttp.web_response.Response]]:
        """Adapt a :data:`Handler` to aiohttp's ``(web.Request) -> response`` coroutine form.

        The aiohttp request is converted to an :class:`IncomingRequest`. ``acall``
        lets sync and async handlers be used alike. A ``JSONDecodeError`` raised
        while converting the request *or* inside the handler becomes a 400
        :class:`ErrorAnswer`.
        """
        async def inner(web_request: aiohttp.web_request.Request):
            try:
                return await acall(raw_handler(await request_from_web_request(web_request)))
            except JSONDecodeError as e:
                return ErrorAnswer(error=str(e), status=HttpStatusCode.BadRequest, error_type=type(e).__name__)
        return inner

    def _wrap_ws_handler(self, raw_handler: WsHandler
                         ) -> Callable[[aiohttp.web_request.Request], Awaitable[aiohttp.web.WebSocketResponse]]:
        """Adapt a :data:`WsHandler` to aiohttp's handler form.

        Works like :meth:`_wrap_handler`, but builds the argument with
        ``ws_from_web_request``. A ``JSONDecodeError`` becomes a 400
        :class:`ErrorAnswer`.
        """
        async def inner(web_request: aiohttp.web_request.Request):
            try:
                return await acall(raw_handler(await ws_from_web_request(web_request)))
            except JSONDecodeError as e:
                return ErrorAnswer(error=str(e), status=HttpStatusCode.BadRequest, error_type=type(e).__name__)
        return inner

    @web.middleware
    async def _handle_middlewares(
            self, request: web.Request, handler: aiohttp.typedefs.Handler
    ) -> web.StreamResponse:
        """The only aiohttp-level middleware: runs ``self.middlewares`` as a nested chain.

        The chain is built innermost-first by iterating in reverse, so
        ``self.middlewares[0]`` ends up outermost. Each middleware is bound via
        ``functools.partial`` to the next link (``handler``) and to the app's
        ``router``. ``update_wrapper`` copies the inner link's metadata
        (``__name__``, ``__doc__``, ...) onto the partial.
        """
        for middleware in reversed(self.middlewares):
            handler = functools.update_wrapper(
                functools.partial(middleware, handler=handler, router=self.runner.app.router), handler
            )
        return await handler(request)
