#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>

from dataclasses import dataclass
from typing import Awaitable

import aiohttp
from aiohttp import web
from http_tools import Answer, HttpStatusCode, HttpServer
from http_tools.mime_types import ContentType

from .monitored_components.monitored_client_session import MonitoredClientSession
from .metrics_registry import MetricsRegistry


class PrometheusController(MetricsRegistry):
    """Metrics registry that exposes itself through an ``HttpServer``.

    On construction it:
      * registers a handler at ``config.metrics_path`` that answers with
        ``self.render_metrics()`` as plain text, and
      * inserts ``incoming_http_requests_middleware`` at the front of the
        server's middleware list to record per-request metrics.
    """

    @dataclass
    class Config(MetricsRegistry.Config):
        # URL path at which the rendered metrics are served. Annotated so that
        # it is a real dataclass field (settable via the constructor); without
        # the annotation it was only a class attribute and could not be passed in.
        metrics_path: str = '/metrics'

    @dataclass
    class Context:
        # Server on which the metrics endpoint and the request middleware are installed.
        http_server: HttpServer

    def __init__(self, config: Config, context: Context) -> None:
        """Initialise the registry and hook it into ``context.http_server``.

        :param config: registry configuration, including ``metrics_path``.
        :param context: runtime dependencies; only ``http_server`` is used.
        """
        super().__init__(config)
        self.config = config
        self.context = context
        # The request argument is ignored: the endpoint always returns the
        # current rendering of every registered metric.
        self.context.http_server.register_handler(
            self.config.metrics_path,
            lambda _: Answer(self.render_metrics(), status=HttpStatusCode.OK, content_type=ContentType.Text)
        )
        # Inserted at index 0 so it is placed before any previously registered
        # middleware (presumably making it the outermost wrapper — confirm
        # against HttpServer's middleware ordering).
        self.context.http_server.middlewares.insert(0, self.incoming_http_requests_middleware)

    def get_monitored_client_session(self) -> MonitoredClientSession:
        """Return a new ``MonitoredClientSession`` bound to this registry."""
        return MonitoredClientSession(self)

    @web.middleware
    async def incoming_http_requests_middleware(
            self, request: aiohttp.web.Request, handler: aiohttp.typedefs.Handler, *args, **kwargs
    ) -> web.StreamResponse:
        """Record timing and traffic metrics for one incoming request.

        The handler call is wrapped in ``self.track('http_server__requests', ...)``,
        and the resulting HTTP status is stored in the tracker's ``status_code``
        label. After that, ``http_server__incoming_traffic`` is incremented by the
        request's Content-Length and ``http_server__outgoing_traffic`` by the
        response's Content-Length (0 when either is unknown).

        aiohttp reports error statuses (e.g. 404 for unmatched routes) by raising
        ``web.HTTPException``. Their status is recorded and the traffic counters
        are still updated before the exception is re-raised unchanged. Any other
        exception propagates without touching the traffic counters, as before.

        ``*args``/``**kwargs`` are accepted for signature compatibility and ignored.
        """
        # NOTE(review): the raw request.path as a label can blow up metric
        # cardinality for parametrised routes; consider the matched route's
        # canonical path instead.
        labels = {'method': request.method, 'path': request.path, 'status_code': None}
        replied = False  # True once a status (response or HTTP error) is known
        outgoing_length = None
        try:
            with self.track('http_server__requests', labels) as tracker:
                try:
                    response = await handler(request)
                except web.HTTPException as exc:
                    # Error statuses arrive as exceptions; label them before re-raising.
                    tracker.labels['status_code'] = exc.status
                    outgoing_length = getattr(exc, 'content_length', None)
                    replied = True
                    raise
                tracker.labels['status_code'] = response.status
                outgoing_length = response.content_length
                replied = True
        finally:
            if replied:
                self.counter('http_server__incoming_traffic', labels).inc(request.content_length or 0)
                self.counter('http_server__outgoing_traffic', labels).inc(outgoing_length or 0)
        return response
