#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Awaitable, Iterator, Any

import aiohttp
from aiohttp import web
from http_tools import Answer, HttpStatusCode, HttpServer
from http_tools.mime_types import ContentType
from prometheus_client import generate_latest

from .monitored_components.monitored_client_session import MonitoredClientSession
from .metrics_registry import MetricsRegistry


class PrometheusController(MetricsRegistry):
    """Metrics registry that exposes its metrics over HTTP in the Prometheus text format.

    On construction it registers a handler on ``context.http_server`` that answers
    ``config.metrics_path`` with ``generate_latest(self._registry)``. It also offers helpers
    for instrumenting outgoing HTTP calls (``get_monitored_client_session``), incoming aiohttp
    requests (``incoming_http_requests_middleware``) and arbitrary code blocks (``track``).
    """

    @dataclass
    class Config(MetricsRegistry.Config):
        # URL path at which the metrics endpoint is registered on the HTTP server.
        # Annotated so it is a real dataclass field (overridable via the constructor);
        # without the annotation it was only a class attribute.
        metrics_path: str = '/metrics'

    @dataclass
    class Context:
        # Server on which the metrics endpoint handler is registered.
        http_server: HttpServer

    def __init__(self, config: Config, context: Context) -> None:
        """Initialise the registry and register the metrics endpoint on ``context.http_server``.

        :param config: registry configuration, including ``metrics_path``.
        :param context: runtime dependencies (the HTTP server to register the endpoint on).
        """
        super().__init__(config)
        self.config = config
        self.context = context
        # Every request to the metrics path is answered with a fresh text dump of the registry.
        self.context.http_server.register_handler(
            self.config.metrics_path,
            lambda _: Answer(generate_latest(self._registry), status=HttpStatusCode.OK, content_type=ContentType.Text),
        )
        # TODO register incoming_http_requests_middleware

    def get_monitored_client_session(self) -> MonitoredClientSession:
        """Return a new ``MonitoredClientSession`` that reports into this registry."""
        return MonitoredClientSession(self)

    @web.middleware
    async def incoming_http_requests_middleware(
            self, request: aiohttp.web.Request, handler: aiohttp.typedefs.Handler, *args, **kwargs
    ) -> web.StreamResponse:
        """aiohttp middleware that records metrics for every incoming request.

        Recorded metrics, labelled by ``method`` and ``path``:

        * ``http_server__accepted_requests`` — incremented once per request;
        * ``http_server__incoming_traffic`` — request ``Content-Length`` (0 when unknown);
        * ``http_server__progress_requests`` — requests currently being handled;
        * ``http_server__requests_latency`` — handler duration, additionally labelled by
          ``status_code`` (``None`` when the handler raised a non-HTTP exception);
        * ``http_server__outgoing_traffic`` — response ``content_length`` (0 when unknown),
          also labelled by ``status_code``; only recorded when the handler returns normally.

        :param request: incoming aiohttp request.
        :param handler: next handler in the aiohttp chain.
        :returns: the response produced by ``handler``.
        :raises: whatever ``handler`` raises, re-raised unchanged.
        """
        # NOTE(review): request.path is the raw path, so paths with ids produce unbounded label
        # cardinality; consider the matched route's canonical path instead.
        labels = {'method': request.method, 'path': request.path}
        self.counter('http_server__accepted_requests', labels).inc()
        self.counter('http_server__incoming_traffic', labels).inc(request.content_length or 0)

        # The progress tracker gets a snapshot: `labels` gains `status_code` below, and the
        # in-progress metric must see the same label set when entering and leaving.
        with self.track_progress('http_server__progress_requests', dict(labels)):
            labels['status_code'] = None
            # `labels` is updated in place after entering track_time so the status code ends up
            # on the latency observation — assumes track_time reads the dict on exit; confirm
            # against MetricsRegistry.
            with self.track_time('http_server__requests_latency', self.config.default_buckets, labels):
                try:
                    response = await handler(request)
                except web.HTTPException as exc:
                    # aiohttp handlers signal HTTP errors (404, 400, ...) by raising; keep their status.
                    labels['status_code'] = exc.status
                    raise
                labels['status_code'] = response.status

        self.counter('http_server__outgoing_traffic', labels).inc(response.content_length or 0)

        return response

    @contextmanager
    def track(
            self, task_name: str, labels: dict[str, Any] | None = None, *,
            except_labels: dict[str, Any] | None = None,  missing_label_value: Any = None
    ) -> Iterator[None]:
        """Instrument the body of a ``with`` block as task ``task_name``.

        Records ``task__<task_name>__progress`` (in-progress tracker), ``task__<task_name>__duration``
        (timer using ``config.default_buckets``) and the ``task__<task_name>`` counter, which is
        incremented exactly once whether the block succeeds or raises.

        :param task_name: name embedded into the metric names.
        :param labels: base labels for all metrics; the caller's dict is not modified.
        :param except_labels: labels merged in when the block raises an ``Exception``.
        :param missing_label_value: value given to ``except_labels`` keys that were not set on the
            current path (success, or a ``BaseException`` such as cancellation), so that every
            observation of the duration timer and counter carries the same label names.
        :raises: any exception raised inside the block, re-raised unchanged.
        """
        # Copy: the merges below must not leak into the caller's dict between calls.
        labels = dict(labels or {})
        except_labels = except_labels or {}
        # A concrete set computed now, so later mutation of `labels` does not change it.
        required_keys = labels.keys() | except_labels.keys()
        # Snapshot for the progress tracker so its entry/exit label sets always match.
        with self.track_progress(f'task__{task_name}__progress', dict(labels)):
            with self.track_time(f'task__{task_name}__duration', self.config.default_buckets, labels):
                try:
                    yield
                except Exception:
                    labels |= except_labels
                    raise
                finally:
                    # Runs before track_time exits, so the timer also sees the completed label set.
                    labels |= {key: missing_label_value for key in required_keys - labels.keys()}
                    self.counter(f'task__{task_name}', labels).inc()
