#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#

from typing import Mapping, Optional, Any, Dict, Union

from aiohttp import BasicAuth, FormData
from aiohttp.web_exceptions import HTTPServiceUnavailable
from init_helpers.custom_json import Jsonable
from multidict import CIMultiDict

from http_tools.http_server_connector import HttpServerConnector
from http_tools.request_tracing import CONTEXT_VAR_REQUEST_ID, X_REQUEST_ID


class AbmServiceConnector:
    """Client for ABM services built on top of :class:`HttpServerConnector`.

    On top of the raw connector it:

    * copies outgoing headers into a ``CIMultiDict`` and, when the current
      context holds a request id, sets the ``X_REQUEST_ID`` header from it;
    * treats dict answers as an envelope with ``done`` / ``error`` / ``result``
      keys and unwraps it (see :meth:`_parse_answer`).
    """
    # Aliases of the wrapped connector's types; ``__init__`` passes them straight through.
    Config = HttpServerConnector.Config
    Context = HttpServerConnector.Context

    def __init__(self, config: Config, context: Context) -> None:
        """Create the underlying :class:`HttpServerConnector` from ``config`` and ``context``."""
        self._connector = HttpServerConnector(config, context)

    async def get(self,
                  path: str,
                  url_query: Optional[Mapping[str, Any]] = None,
                  headers: Optional[Mapping[str, str]] = None,
                  auth: Optional[BasicAuth] = None,
                  timeout_sec: Optional[float] = None) -> Any:
        """Send a GET request through the connector and return the unwrapped answer.

        :param path: request path, forwarded to the connector.
        :param url_query: optional query parameters, forwarded to the connector.
        :param headers: optional extra headers; merged by :meth:`_prepare_headers`.
        :param auth: optional basic-auth credentials, forwarded to the connector.
        :param timeout_sec: optional timeout in seconds, forwarded to the connector.
        :return: the connector's answer processed by :meth:`_parse_answer`.
        :raises HTTPServiceUnavailable: if a dict answer reports failure or is malformed.
        """
        headers = self._prepare_headers(headers)
        answer = await self._connector.get(path, url_query, headers, auth, timeout_sec)
        return self._parse_answer(answer)

    async def post(self,
                   path: str,
                   payload: Union[bytes, FormData, Jsonable],
                   url_query: Optional[Mapping[str, Any]] = None,
                   headers: Optional[Mapping[str, str]] = None,
                   auth: Optional[BasicAuth] = None,
                   timeout_sec: Optional[float] = None) -> Any:
        """Send a POST request through the connector and return the unwrapped answer.

        :param path: request path, forwarded to the connector.
        :param payload: ``bytes`` and ``FormData`` are passed to the connector's ``post``;
            any other value is passed to its ``post_json``.
        :param url_query: optional query parameters, forwarded to the connector.
        :param headers: optional extra headers; merged by :meth:`_prepare_headers`.
        :param auth: optional basic-auth credentials, forwarded to the connector.
        :param timeout_sec: optional timeout in seconds, forwarded to the connector.
        :return: the connector's answer processed by :meth:`_parse_answer`.
        :raises HTTPServiceUnavailable: if a dict answer reports failure or is malformed.
        """
        headers = self._prepare_headers(headers)
        method = self._connector.post if isinstance(payload, (bytes, FormData)) else self._connector.post_json
        answer = await method(path, payload, url_query, headers, auth, timeout_sec)
        return self._parse_answer(answer)

    def _parse_answer(self, answer: Union[Dict[str, Any], bytes]) -> Any:
        """Unwrap a service answer.

        Non-dict answers (e.g. raw ``bytes``) are returned unchanged. A dict answer is
        treated as an envelope: when ``done`` is truthy its ``result`` is returned
        (``None`` if absent); otherwise ``HTTPServiceUnavailable`` is raised with
        ``error`` as the reason. A dict lacking ``done`` (or lacking ``error`` when
        ``done`` is falsy) is reported as an incorrect answer.

        :raises HTTPServiceUnavailable: on a failed or malformed envelope.
        """
        if not isinstance(answer, dict):
            return answer
        try:
            if not answer['done']:
                # The upstream error may be any JSON value and may span several lines;
                # it must become a single-line str before it is used as a reason phrase.
                raise HTTPServiceUnavailable(reason=self._to_reason(answer['error']))
            return answer.get('result')
        except KeyError as e:
            raise HTTPServiceUnavailable(
                reason=self._to_reason(f'Incorrect answer from {self._connector.config.url}: {answer}')
            ) from e

    @staticmethod
    def _to_reason(error: Any) -> Optional[str]:
        """Convert an arbitrary error value into a single-line HTTP reason phrase.

        ``None`` is passed through so aiohttp falls back to its default reason phrase.
        Any other value is converted with ``str`` and its line breaks are collapsed into
        spaces: the reason is written into the HTTP status line, and aiohttp >= 3.9
        refuses reasons that contain a newline (raising ``ValueError`` instead of
        producing the intended 503).
        """
        if error is None:
            return None
        return ' '.join(str(error).splitlines())

    @staticmethod
    def _prepare_headers(required_headers: Optional[Mapping[str, str]]) -> CIMultiDict:
        """Build the outgoing header set.

        Copies ``required_headers`` (if any) into a case-insensitive ``CIMultiDict``, then
        sets ``X_REQUEST_ID`` from ``CONTEXT_VAR_REQUEST_ID`` when that value is truthy.
        A caller-supplied header with the same name is replaced by the context value.
        """
        headers = CIMultiDict()
        if required_headers:
            for key, value in required_headers.items():
                headers[key] = value
        # NOTE(review): ContextVar.get() raises LookupError when the variable has no value
        # and was created without a default - confirm CONTEXT_VAR_REQUEST_ID declares one.
        if request_id := CONTEXT_VAR_REQUEST_ID.get():
            headers[X_REQUEST_ID] = request_id
        return headers
