#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import logging
from asyncio import TimeoutError, Semaphore
from enum import unique, StrEnum
from typing import Union, Optional, Any, Mapping

from aiohttp import ClientResponse, ClientSession, BasicAuth, ClientTimeout, hdrs
from aiohttp.web_exceptions import HTTPBadRequest, HTTPUnauthorized, HTTPForbidden, HTTPNotFound, \
    HTTPUnprocessableEntity, HTTPBadGateway, HTTPServiceUnavailable, HTTPGatewayTimeout, HTTPInternalServerError, \
    HTTPError
from async_tools import AsyncOnStop, DummyAsyncContextManager

from http_tools.digest_auth import DigestAuth
from http_tools.decoder import Decoder
from http_tools.decoder_register import DecoderRegister
from http_tools.header_utils import parse_content_type_header
from http_tools.http_status_codes import HttpStatusCode
from http_tools.mime_types import ContentType

logger = logging.getLogger(__name__)


class HttpClient(AsyncOnStop):
    """Asynchronous HTTP client built on an ``aiohttp.ClientSession``.

    Every request reads the whole response body and decodes it. Text is decoded
    by the ``charset`` from ``Content-Type``. The result is then passed through
    either a caller-forced decoder or one picked from ``decoder_register`` by
    content type. Statuses other than 200/201/202/204 are raised as
    ``aiohttp.web_exceptions`` HTTP errors, with the processed body as ``reason``.
    """

    @unique
    class Method(StrEnum):
        """HTTP methods accepted by :meth:`HttpClient.request`."""
        GET = 'GET'
        POST = 'POST'
        PUT = 'PUT'
        DELETE = 'DELETE'
        PATCH = 'PATCH'

    # Statuses whose (processed) payload is returned to the caller.
    # NOTE(review): other 2xx codes (e.g. 206 Partial Content) currently fall
    # through to HTTPInternalServerError — confirm this is intended.
    _SUCCESS_STATUSES = frozenset({HttpStatusCode.OK, HttpStatusCode.Created, HttpStatusCode.Accepted})

    # Non-success statuses mapped to the exception raised for them; any status
    # not listed here (and not a success) is raised as HTTPInternalServerError.
    _ERRORS_BY_STATUS = {
        HttpStatusCode.BadRequest: HTTPBadRequest,
        HttpStatusCode.Unauthorized: HTTPUnauthorized,
        HttpStatusCode.Forbidden: HTTPForbidden,
        HttpStatusCode.NotFound: HTTPNotFound,
        HttpStatusCode.UnprocessableEntity: HTTPUnprocessableEntity,
        HttpStatusCode.BadGateway: HTTPBadGateway,
        HttpStatusCode.ServiceUnavailable: HTTPServiceUnavailable,
        HttpStatusCode.GatewayTimeout: HTTPGatewayTimeout,
    }

    def __init__(self, session: Optional[ClientSession], default_timeout_sec: float = 100):
        """
        :param session: session to send requests through. May be None; a new
            ``ClientSession`` is then created on first use (see :attr:`session`).
        :param default_timeout_sec: total timeout, in seconds, for requests that
            do not pass their own ``timeout``.
        """
        self._session = session
        # NOTE: despite its name this attribute holds a ClientTimeout, not a float.
        self._default_timeout_sec = ClientTimeout(default_timeout_sec)
        self.decoder_register = DecoderRegister()

    async def _on_stop(self) -> None:
        """Close the underlying session, if one has been created."""
        # The session is created lazily, so it is still None when the client
        # was constructed without one and never sent a request.
        if self._session is not None:
            await self._session.close()

    async def get(self,
                  url: str,
                  *,
                  params: Optional[dict] = None,
                  headers: Optional[Mapping[str, str]] = None,
                  auth: Optional[BasicAuth] = None,
                  timeout: Optional[float] = None,
                  force_decoder: Optional[Decoder] = None,
                  proxy_url: Optional[str] = None
                  ) -> Any:
        """Send a GET request; see :meth:`request` for parameters and errors."""
        return await self.request(self.Method.GET, url, params=params, timeout=timeout, auth=auth,
                                  headers=headers, force_decoder=force_decoder, proxy_url=proxy_url)

    async def post(self,
                   url: str,
                   *,
                   params: Optional[dict] = None,
                   json: Union[dict, list, None] = None,
                   data: Optional[bytes] = None,
                   headers: Optional[Mapping[str, str]] = None,
                   auth: Optional[BasicAuth] = None,
                   timeout: Optional[float] = None,
                   force_decoder: Optional[Decoder] = None,
                   proxy_url: Optional[str] = None
                   ) -> Any:
        """Send a POST request; see :meth:`request` for parameters and errors."""
        return await self.request(self.Method.POST, url, params=params, json=json, data=data, timeout=timeout,
                                  auth=auth, headers=headers, force_decoder=force_decoder, proxy_url=proxy_url)

    async def put(self,
                  url: str,
                  *,
                  params: Optional[dict] = None,
                  json: Union[dict, list, None] = None,
                  data: Optional[bytes] = None,
                  headers: Optional[Mapping[str, str]] = None,
                  auth: Optional[BasicAuth] = None,
                  timeout: Optional[float] = None,
                  force_decoder: Optional[Decoder] = None,
                  proxy_url: Optional[str] = None
                  ) -> Any:
        """Send a PUT request; see :meth:`request` for parameters and errors."""
        return await self.request(self.Method.PUT, url, params=params, json=json, data=data, timeout=timeout,
                                  auth=auth, headers=headers, force_decoder=force_decoder, proxy_url=proxy_url)

    async def delete(self,
                     url: str,
                     *,
                     params: Optional[dict] = None,
                     json: Union[dict, list, None] = None,
                     data: Optional[bytes] = None,
                     headers: Optional[Mapping[str, str]] = None,
                     auth: Optional[BasicAuth] = None,
                     timeout: Optional[float] = None,
                     force_decoder: Optional[Decoder] = None,
                     proxy_url: Optional[str] = None
                     ) -> Any:
        """Send a DELETE request; see :meth:`request` for parameters and errors."""
        return await self.request(self.Method.DELETE, url, params=params, json=json, data=data, timeout=timeout,
                                  auth=auth, headers=headers, force_decoder=force_decoder, proxy_url=proxy_url)

    async def patch(self,
                    url: str,
                    *,
                    params: Optional[dict] = None,
                    json: Union[dict, list, None] = None,
                    data: Optional[bytes] = None,
                    headers: Optional[Mapping[str, str]] = None,
                    auth: Optional[BasicAuth] = None,
                    timeout: Optional[float] = None,
                    force_decoder: Optional[Decoder] = None,
                    proxy_url: Optional[str] = None
                    ) -> Any:
        """Send a PATCH request; see :meth:`request` for parameters and errors."""
        return await self.request(self.Method.PATCH, url, params=params, json=json, data=data, timeout=timeout,
                                  auth=auth, headers=headers, force_decoder=force_decoder, proxy_url=proxy_url)

    async def request(self,
                      method: Method,
                      url: str,
                      params: Optional[dict] = None,
                      json: Union[dict, list, None] = None,
                      data: Optional[bytes] = None,
                      headers: Optional[Mapping[str, str]] = None,
                      auth: Optional[BasicAuth] = None,
                      timeout: Optional[float] = None,
                      force_decoder: Optional[Decoder] = None,
                      proxy_url: Optional[str] = None
                      ) -> Any:
        """Send an HTTP request and return the processed response payload.

        :param method: HTTP method.
        :param url: target URL.
        :param params: query-string parameters.
        :param json: JSON-serialisable request body.
        :param data: raw request body.
        :param headers: extra request headers.
        :param auth: credentials. A ``DigestAuth`` gets one automatic retry after
            a 401 on its first request.
        :param timeout: total timeout in seconds. Falsy values (None or 0) use
            the client's default timeout.
        :param force_decoder: decoder to apply instead of the one registered
            for the response's content type.
        :param proxy_url: proxy to route the request through.
        :return: None for 204; otherwise the processed body (see
            :meth:`_process_response_payload`).
        :raises TimeoutError: the request exceeded the timeout (logged as error).
        :raises HTTPError: a non-success status was received (logged as warning,
            except for 401).
        """
        client_timeout = ClientTimeout(timeout) if timeout else self._default_timeout_sec
        try:
            async with self.session.request(method, url, params=params, json=json, data=data, timeout=client_timeout,
                                            auth=auth, headers=headers, proxy=proxy_url) as response:
                return await self._process_response(response, force_decoder)
        except TimeoutError as er:
            logger.error(f"{method} {url} request exceeded the deadline = {client_timeout.total}. URL: {repr(er)}")
            raise er
        except HTTPUnauthorized as er:
            if not isinstance(auth, DigestAuth) or not auth.is_first_request:
                raise er
            # The 401 is raised by _process_response, so `response` is bound to
            # the challenge response here. Hand it to DigestAuth to build the
            # next request's credentials.
            auth.set_response_context(response)
            # Retry with every original argument. Keywords are used so that
            # proxy_url is forwarded too; the positional call used to drop it.
            return await self.request(method, url, params=params, json=json, data=data, headers=headers,
                                      auth=auth, timeout=timeout, force_decoder=force_decoder,
                                      proxy_url=proxy_url)
        except HTTPError as er:
            logger.warning(f"{method} {url} request failed. Description: {repr(er)}")
            raise er

    @property
    def session(self) -> ClientSession:
        """Return the current session, creating a new one if it is missing or closed."""
        if self._session is None or self._session.closed:
            self._session = ClientSession()
        return self._session

    async def _process_response_payload(self, response: ClientResponse,
                                        force_decoder: Optional[Decoder] = None) -> Any:
        """Read the whole response body and decode it.

        If ``Content-Type`` carries a ``charset``, the bytes are decoded to
        ``str`` with it. The result is then passed to ``force_decoder``, or to
        the decoder registered for the content type, if any. A missing
        ``Content-Type`` is treated as ``ContentType.Octet``. An empty body is
        returned as-is (empty bytes).
        """
        payload = await response.read()
        if payload:
            content_type_header = response.headers.get(hdrs.CONTENT_TYPE, ContentType.Octet)
            content_type, content_type_info = parse_content_type_header(content_type_header)
            charset = content_type_info.get('charset')
            if charset:
                payload = payload.decode(charset)

            decoder = force_decoder if force_decoder is not None else self.decoder_register.pick_decoder(content_type)
            if decoder is not None:
                payload = decoder(payload)

        return payload

    async def _process_response(self, response: ClientResponse, force_decoder: Optional[Decoder] = None) -> Any:
        """Turn a response into a return value or an HTTP exception.

        Returns None for 204 (without reading the body). Returns the processed
        payload for 200/201/202. Otherwise raises the mapped ``HTTPError``
        subclass, or ``HTTPInternalServerError`` for unmapped statuses (logged).
        The processed payload is passed as ``reason`` in every case.
        """
        if response.status in {HttpStatusCode.NoContent}:
            return

        payload = await self._process_response_payload(response, force_decoder)

        if response.status in self._SUCCESS_STATUSES:
            return payload

        error_cls = self._ERRORS_BY_STATUS.get(response.status)
        if error_cls is None:
            logger.warning(f"Not processed response status code = {response.status}. body: {payload}")
            error_cls = HTTPInternalServerError
        raise error_cls(reason=payload)
