#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import json
import logging
from asyncio import Semaphore
from dataclasses import dataclass, field
from typing import Mapping, Optional, Any

import yarl
from aiohttp import ClientSession, BasicAuth, FormData, hdrs
from async_tools import AsyncOnStop, DummyAsyncContextManager
from init_helpers.custom_json import custom_dumps, Jsonable
from multidict import MultiMapping, CIMultiDict

from .http_client import HttpClient

# Logger named after this module.
logger = logging.getLogger(name=__name__)

# Scheme that HttpServerConnector.Config._construct_url prepends to URLs given without one.
DEFAULT_PROTOCOL: str = "http"
# Token separating a URL's scheme from the remainder.
PROTOCOL_SEPARATOR: str = "://"


class HttpServerConnector(AsyncOnStop):
    """HTTP connector bound to a single base URL.

    Every request method joins ``path`` onto ``Config.url``, applies the default
    timeout when the caller gives none, sets the User-Agent from
    ``Context.instance_id`` (when non-empty), limits concurrent requests with a
    semaphore of ``Config.max_connections`` slots, and delegates the call to
    ``HttpClient`` using the configured proxy.
    """

    @dataclass
    class Config:
        """Connector settings.

        Raises:
            ValueError: if both ``url`` and ``location`` are empty/None.
        """

        location: str = ""  # deprecated, use url instead
        url: str = ""  # use it
        timeout_sec: float = 100  # used when a request passes no (or a falsy) timeout_sec
        max_connections: int = 10  # concurrent request cap; a falsy value disables the semaphore
        proxy_url: Optional[str] = None  # None (or "") means no proxy

        def __post_init__(self):
            server_dirty_url = self.url or self.location
            # Both fields default to "", so an `is None` test would let an empty URL
            # through and turn it into the meaningless "http://".
            if not server_dirty_url:
                raise ValueError(f'required "url" or "location", got {server_dirty_url!r}')
            self.url = self._construct_url(server_dirty_url)
            self.location = self.url  # kept in sync for backward compatibility
            if self.proxy_url:
                self.proxy_url = self._construct_url(self.proxy_url)
            else:
                # An empty proxy setting means "no proxy", not the bogus proxy "http://".
                self.proxy_url = None

        @staticmethod
        def _construct_url(dirty_url: str) -> str:
            """Return ``dirty_url`` unchanged if it has a scheme, else prefix ``DEFAULT_PROTOCOL``."""
            if PROTOCOL_SEPARATOR in dirty_url:
                return dirty_url
            return f'{DEFAULT_PROTOCOL}{PROTOCOL_SEPARATOR}{dirty_url}'

    @dataclass
    class Context:
        """Runtime dependencies of the connector."""

        session: ClientSession  # passed to HttpClient
        instance_id: str = field(default="", kw_only=True)  # sent as User-Agent when non-empty

    def __init__(self, config: Config, context: Context):
        logger.info("initiating %s, %s", type(self).__name__, config.url)

        self.config = config
        self.context = context
        self._requester = HttpClient(context.session)
        # With a falsy max_connections the semaphore is replaced by DummyAsyncContextManager
        # (presumably a no-op, i.e. no concurrency cap — confirm in async_tools).
        self._semaphore = (
            Semaphore(config.max_connections) if config.max_connections else DummyAsyncContextManager()
        )

    async def _on_stop(self):
        logger.info("stopping %s, %s", type(self).__name__, self.config.url)
        await self._requester._on_stop()

    def _get_url(self, path: str) -> str:
        """Join ``path`` (leading slashes stripped) onto the configured base URL."""
        yarl_url = yarl.URL(self.config.url)
        if path:
            # yarl's `/` operator rejects segments that start with "/".
            yarl_url = yarl_url / path.lstrip('/')
        return str(yarl_url)

    @staticmethod
    def _to_jsonable(payload: Jsonable) -> Any:
        """Round-trip ``payload`` through ``custom_dumps``/``json.loads`` to get plain JSON types."""
        return json.loads(custom_dumps(payload))

    async def _send(self, send: Any, path: str, url_query: Optional[Mapping], headers: Optional[Mapping[str, str]],
                    auth: Optional[BasicAuth], timeout_sec: Optional[float], **body: Any) -> Any:
        """Shared request pipeline used by all public HTTP methods.

        Args:
            send: bound HttpClient method (``post``, ``put``, ``get``, ...).
            path: path joined onto the base URL.
            url_query: forwarded as ``params``.
            headers: extra headers; merged by ``_ensure_headers``.
            auth: forwarded as ``auth``.
            timeout_sec: per-call timeout; a falsy value falls back to ``Config.timeout_sec``.
            **body: ``json=`` or ``data=`` payload forwarded unchanged.

        Returns:
            Whatever the HttpClient method returns.
        """
        url = self._get_url(path)
        timeout = timeout_sec if timeout_sec else self.config.timeout_sec
        merged_headers = self._ensure_headers(headers)
        async with self._semaphore:
            return await send(url, params=url_query, headers=merged_headers, auth=auth,
                              proxy_url=self.config.proxy_url, timeout=timeout, **body)

    async def post_json(self, path: str, payload: Jsonable, url_query: Optional[MultiMapping] = None,
                        headers: Optional[Mapping[str, str]] = None, auth: Optional[BasicAuth] = None,
                        timeout_sec: Optional[float] = None) -> Any:
        """POST ``payload`` as JSON to ``path``."""
        return await self._send(self._requester.post, path, url_query, headers, auth, timeout_sec,
                                json=self._to_jsonable(payload))

    async def put_json(self, path: str, payload: Jsonable, url_query: Optional[MultiMapping] = None,
                       headers: Optional[Mapping[str, str]] = None, auth: Optional[BasicAuth] = None,
                       timeout_sec: Optional[float] = None) -> Any:
        """PUT ``payload`` as JSON to ``path``."""
        return await self._send(self._requester.put, path, url_query, headers, auth, timeout_sec,
                                json=self._to_jsonable(payload))

    async def post(self, path: str, payload: bytes | FormData, url_query: Optional[Mapping] = None,
                   headers: Optional[Mapping[str, str]] = None, auth: Optional[BasicAuth] = None,
                   timeout_sec: Optional[float] = None) -> Any:
        """POST a raw ``bytes``/``FormData`` body (sent as ``data=``) to ``path``."""
        return await self._send(self._requester.post, path, url_query, headers, auth, timeout_sec,
                                data=payload)

    async def delete(self, path: str, payload: Optional[bytes] = None, url_query: Optional[Mapping] = None,
                     headers: Optional[Mapping[str, str]] = None, auth: Optional[BasicAuth] = None,
                     timeout_sec: Optional[float] = None) -> Any:
        """DELETE ``path``, optionally with a raw body (sent as ``data=``)."""
        return await self._send(self._requester.delete, path, url_query, headers, auth, timeout_sec,
                                data=payload)

    async def get(self, path: str, url_query: Optional[Mapping] = None, headers: Optional[Mapping[str, str]] = None,
                  auth: Optional[BasicAuth] = None,
                  timeout_sec: Optional[float] = None) -> Any:
        """GET ``path``."""
        return await self._send(self._requester.get, path, url_query, headers, auth, timeout_sec)

    async def patch_json(self, path: str, payload: Jsonable, url_query: Optional[MultiMapping] = None,
                         headers: Optional[Mapping[str, str]] = None, auth: Optional[BasicAuth] = None,
                         timeout_sec: Optional[float] = None) -> Any:
        """PATCH ``payload`` as JSON to ``path``."""
        return await self._send(self._requester.patch, path, url_query, headers, auth, timeout_sec,
                                json=self._to_jsonable(payload))

    def _ensure_headers(self, required_headers: Mapping[str, str] | None) -> CIMultiDict[str, str]:
        """Copy ``required_headers`` into a case-insensitive dict.

        When ``Context.instance_id`` is non-empty it is written as User-Agent,
        overriding any caller-supplied User-Agent.
        """
        headers = CIMultiDict()
        headers.update(required_headers or {})
        if self.context.instance_id:
            headers[hdrs.USER_AGENT] = self.context.instance_id
        return headers
