#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import functools
import logging
import sys
from dataclasses import dataclass
from json import JSONDecodeError
from typing import Awaitable, List, Iterable, Optional
from typing import Callable, Union

import aiohttp
import aiohttp.web_request
import aiohttp.web_response
from aiohttp import web
from aiohttp.web_runner import AppRunner, TCPSite
from aiohttp.web_urldispatcher import UrlDispatcher
from async_tools import acall, AsyncOnStop, AsyncOnStart
from async_tools.async_deinitable import AsyncDeinitable
from async_tools.async_initable import AsyncInitable

from .answer import Answer, ErrorAnswer
from .error_middleware import error_middleware
from .http_status_codes import HttpStatusCode
from .request import IncomingRequest, request_from_web_request
from .ws_handler import WebSocketWrapper, ws_from_web_request

# Module-level logger.
logger = logging.getLogger(__name__)

# An HTTP route handler: receives an IncomingRequest and produces an Answer,
# either directly (sync) or as an awaitable (async).
Handler = Union[
    Callable[[IncomingRequest], Awaitable[Answer]],
    Callable[[IncomingRequest], Answer],
]

# A WebSocket handler: receives a WebSocketWrapper and produces a
# web.WebSocketResponse, either directly (sync) or as an awaitable (async).
WsHandler = Union[
    Callable[[WebSocketWrapper], Awaitable[web.WebSocketResponse]],
    Callable[[WebSocketWrapper], web.WebSocketResponse],
]


if sys.version_info >= (3, 8):
    from typing import Protocol

    class Middleware(Protocol):
        """Router-aware aiohttp middleware.

        HttpServer invokes it with the request, the next handler in the chain
        and the application's router (passed as the ``router`` keyword). aiohttp
        awaits whatever a middleware returns, so implementations are expected to
        be coroutine functions producing a response.
        """
        def __call__(
            self,
            request: web.Request,
            handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
            router: UrlDispatcher,
        ) -> Awaitable[web.StreamResponse]: ...
else:
    # typing.Protocol is unavailable before Python 3.8: fall back to a plain
    # Callable alias with the same (request, handler, router) -> awaitable shape.
    Middleware = Callable[
        [web.Request, Callable[[web.Request], Awaitable[web.StreamResponse]], UrlDispatcher],
        Awaitable[web.StreamResponse],
    ]


class HttpServer(AsyncInitable, AsyncDeinitable):
    """aiohttp-based HTTP server with router-aware middlewares.

    Construction builds the aiohttp application and runner and registers a
    catch-all OPTIONS route. Listening starts in ``_async_init`` and stops in
    ``_async_deinit``; the instance can also be used as ``async with``.
    """

    @dataclass
    class Config:
        # Address the TCP site binds to.
        interface: str = '0.0.0.0'
        # TCP port the site listens on.
        port: int = 80
        # Forwarded to web.Application(client_max_size=...) (request body limit, bytes).
        client_max_size: int = 100 * (1024 ** 2)
        # Forwarded to aiohttp's request parser through handler_args.
        max_line_size: int = 16380
        max_field_size: int = 16380

    def __init__(self, config: Config, *, middleware: Union[Middleware, List[Middleware]] = error_middleware) -> None:
        """Build the aiohttp application and runner; does not start listening.

        Args:
            config: listening address and request-size limits.
            middleware: a single middleware or a list/tuple of them. Each one is
                called as ``middleware(request, handler, router=<app router>)``.
        """
        AsyncInitable.__init__(self)
        AsyncDeinitable.__init__(self)
        self.config = config

        def bind_router(middleware_: Middleware):
            # Adapt a router-aware middleware to aiohttp's new-style
            # (request, handler) middleware signature. The ``web.middleware``
            # marker must sit on the adapter aiohttp actually calls; otherwise
            # aiohttp treats the adapter as an old-style middleware factory.
            @web.middleware
            @functools.wraps(middleware_)
            def wrapper(*args, **kwargs):
                # Router is looked up at call time; self.runner exists by then.
                return middleware_(*args, **kwargs, router=self.runner.app.router)
            return wrapper

        if not isinstance(middleware, (list, tuple)):
            middleware = [middleware]

        self.runner = AppRunner(
            web.Application(
                middlewares=[bind_router(m) for m in middleware], client_max_size=config.client_max_size,
                handler_args={'max_field_size': config.max_field_size, 'max_line_size': config.max_line_size}
            )
        )

        router: UrlDispatcher = self.runner.app.router
        # Answer OPTIONS for every path.
        router.add_options('/{tail:.*}', self.options)
        logger.info(f"{type(self).__name__} inited")

    def register_handler(self, path: str, handler: Handler, methods: Optional[Iterable[str]] = None) -> None:
        """Register ``handler`` for ``path`` under each HTTP method in ``methods``.

        Args:
            path: aiohttp route path.
            handler: sync or async callable taking an IncomingRequest (see ``Handler``).
            methods: HTTP methods to bind; defaults to GET and POST. A single
                method may also be passed as a plain string.
        """
        logger.info(f"register handler: {path}, {handler}")
        wrapped_handler = _wrap_handler(handler)
        router: UrlDispatcher = self.runner.app.router
        if methods is None:
            methods = {'GET', 'POST'}
        elif isinstance(methods, str):
            # A bare string would otherwise be iterated character by character.
            methods = [methods]
        for method in methods:
            router.add_route(method, path, wrapped_handler)

    def register_ws_handler(self, path: str, handler: WsHandler) -> None:
        """Register a WebSocket ``handler`` for ``path`` on GET and POST."""
        logger.info(f"register ws handler: {path}, {handler}")
        wrapped_handler = _wrap_ws_handler(handler)
        router: UrlDispatcher = self.runner.app.router
        router.add_get(path, wrapped_handler)
        # NOTE(review): WebSocket opening handshakes are GET requests; the POST
        # route is kept for backward compatibility — confirm it is needed.
        router.add_post(path, wrapped_handler)

    @staticmethod
    async def options(_: aiohttp.web_request.Request) -> Answer:
        """Reply to an OPTIONS request (presumably CORS preflight) with allowed methods/headers."""
        return Answer(None, HttpStatusCode.OK, headers={
            "Allow": "GET,OPTIONS,POST",
            "Access-Control-Allow-Headers": "Content-Type"
        })

    async def _async_init(self):
        """Set up the runner and start a TCP site on the configured interface/port."""
        logger.debug(f"start http server on: http://{self.config.interface}:{self.config.port}")
        await self.runner.setup()
        await TCPSite(self.runner, self.config.interface, self.config.port).start()
        logger.info(f"http server started: http://{self.config.interface}:{self.config.port}")

    async def _async_deinit(self):
        """Clean up the runner, shutting down the server."""
        logger.debug("stop http server")
        await self.runner.cleanup()
        logger.debug("http server stopped")

    async def __aenter__(self):
        """Start the server and return it so ``async with ... as server`` binds it."""
        await self.async_init()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the server; exceptions from the ``async with`` body propagate."""
        await self.async_deinit()


def _wrap_handler(raw_handler: Handler
                  ) -> Callable[[aiohttp.web_request.Request], Awaitable[aiohttp.web_response.Response]]:
    """Adapt a project-level ``Handler`` into an aiohttp route coroutine.

    The returned coroutine converts the aiohttp request with
    ``request_from_web_request``, calls ``raw_handler`` on it and passes the
    result through ``acall`` (presumably awaiting it when it is awaitable, so
    both sync and async handlers fit ``Handler``). Any JSONDecodeError raised
    along the way becomes a 400 Bad Request ErrorAnswer.
    """
    async def handle_web_request(web_request: aiohttp.web_request.Request):
        try:
            incoming = await request_from_web_request(web_request)
            outcome = raw_handler(incoming)
            return await acall(outcome)
        except JSONDecodeError as decode_error:
            return ErrorAnswer(
                error=str(decode_error),
                status=HttpStatusCode.BadRequest,
                error_type=type(decode_error).__name__,
            )

    return handle_web_request


def _wrap_ws_handler(raw_handler: WsHandler
                     ) -> Callable[[aiohttp.web_request.Request], Awaitable[aiohttp.web.WebSocketResponse]]:
    """Adapt a project-level ``WsHandler`` into an aiohttp route coroutine.

    The returned coroutine builds a WebSocketWrapper with
    ``ws_from_web_request``, calls ``raw_handler`` on it and passes the result
    through ``acall`` (presumably awaiting it when it is awaitable). Any
    JSONDecodeError raised along the way becomes a 400 Bad Request ErrorAnswer.
    """
    async def handle_ws_request(web_request: aiohttp.web_request.Request):
        try:
            socket_wrapper = await ws_from_web_request(web_request)
            outcome = raw_handler(socket_wrapper)
            return await acall(outcome)
        except JSONDecodeError as decode_error:
            return ErrorAnswer(
                error=str(decode_error),
                status=HttpStatusCode.BadRequest,
                error_type=type(decode_error).__name__,
            )

    return handle_ws_request
