#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>
import re

import yarl
from aiohttp import ClientSession, ClientResponse
from aiohttp.typedefs import StrOrURL

from ..metrics_registry import MetricsRegistry
from ..request_id_exemplar import get_request_id_exemplar


class MonitoredClientSession(ClientSession):
    """aiohttp ``ClientSession`` that records metrics for every request it sends.

    Per request it increments ``http_client__submitted_requests``, tracks
    ``http_client__requests`` (labelled with the response ``status_code`` or the
    ``error`` exception class name) and, when the request succeeds, adds the
    response ``content_length`` to ``http_client__input_traffic``.

    The ``url`` label is built by ``_prepare_url``: query string, fragment and any
    embedded user credentials are dropped, and the result may be collapsed by the
    configured ``url_pattern_to_replacer`` templates.
    """

    def __init__(
            self, metrics: MetricsRegistry, url_pattern_to_replacer: dict[re.Pattern | str, str] | None = None,
            *args, **kwargs
    ) -> None:
        """Create the session.

        :param metrics: registry used to create the counters and the request tracker.
        :param url_pattern_to_replacer: mapping from a regex (compiled or as a string) to a
            ``str.format`` template. The first pattern, in mapping order, that fully matches the
            normalised URL wins; its named groups are substituted into the template to produce
            the ``url`` label.
        :param args: forwarded positionally to ``ClientSession.__init__``.
        :param kwargs: forwarded to ``ClientSession.__init__``.
        """
        self._url_pattern_to_replacer: dict[re.Pattern, str] = {
            pattern if isinstance(pattern, re.Pattern) else re.compile(pattern): replacer
            for pattern, replacer in (url_pattern_to_replacer or {}).items()
        }
        self._metrics = metrics
        super().__init__(*args, **kwargs)

    def _prepare_url(self, str_or_url: StrOrURL) -> str:
        """Return the value of the ``url`` metric label for *str_or_url*.

        The query string and the fragment are removed, and so are user/password
        credentials, so secrets embedded in the URL never end up in metric labels.
        The resulting string is then rewritten by the first configured pattern that
        fully matches it; if none matches it is returned as is.
        """
        url = yarl.URL(str_or_url).with_query({}).with_fragment(None)
        has_credentials = url.user is not None or url.password is not None
        # ``with_user`` raises ValueError for relative URLs (possible when the session
        # has a ``base_url``); only touch absolute URLs that actually carry credentials.
        if has_credentials and url.is_absolute():
            url = url.with_user(None)  # None clears both user and password
        url_str = str(url)
        for pattern, replacer in self._url_pattern_to_replacer.items():
            if match := pattern.fullmatch(url_str):
                return replacer.format(**match.groupdict())
        return url_str

    async def _request(
            self, method: str, str_or_url: StrOrURL, extended_metrics_labels: dict[str, str] | None = None, **kwargs
    ) -> ClientResponse:
        """Send the request via ``ClientSession._request`` while recording metrics.

        :param method: HTTP method; upper-cased for the ``method`` label.
        :param str_or_url: request URL; normalised by ``_prepare_url`` for the ``url`` label.
        :param extended_metrics_labels: extra labels merged over (and able to override)
            ``method``/``url`` for every metric of this request.
        :param kwargs: forwarded unchanged to ``ClientSession._request``.
        :return: the response produced by ``ClientSession._request``.
        :raises: whatever ``ClientSession._request`` raises; its class name is stored in the
            tracker's ``error`` label before being re-raised.
        """
        exemplar = get_request_id_exemplar()
        labels = {'method': method.upper(), 'url': self._prepare_url(str_or_url)}
        labels |= extended_metrics_labels or {}
        self._metrics.counter('http_client__submitted_requests', labels).inc(exemplar=exemplar)
        with self._metrics.track('http_client__requests', labels, exemplar=exemplar) as tracker:
            try:
                response = await super()._request(method, str_or_url, **kwargs)
                tracker.labels['status_code'] = response.status
            except BaseException as er:
                # BaseException (not Exception) so that task cancellation is labelled too:
                # asyncio.CancelledError derives from BaseException since Python 3.8.
                # Nothing is swallowed - the exception is always re-raised.
                tracker.labels['error'] = type(er).__name__
                raise
        # ``content_length`` may be None when the size is unknown (presumably no
        # Content-Length header); it is counted as 0 in that case.
        self._metrics.counter('http_client__input_traffic', labels).inc(response.content_length or 0, exemplar=exemplar)
        return response
