from dataclasses import dataclass, field
from typing import List, Union

from aiohttp import web
from aiohttp.web_response import Response
from http_tools.error_middleware import Handler
from http_tools.http_server import WsHandler
from prometheus_client import CollectorRegistry, REGISTRY, Counter, Gauge, Histogram
from prometheus_client.utils import INF
from ..tools.timer import Timer


class IncomingHttpRequestsMetricsCollector:
    """Collect Prometheus metrics about HTTP requests served by an aiohttp application.

    Register the callable returned by ``create_middleware()`` as an aiohttp middleware.
    Every request whose path is not listed in ``Config.exclude_paths`` updates:

    * ``incoming_http_requests__amount`` - counter of received requests (method, path);
    * ``incoming_http_requests__current`` - gauge of requests in progress (method, path);
    * ``incoming_http_requests__latency`` - histogram of handling time (method, path, status_code);
    * ``incoming_http_requests__input_network_traffic`` - counter of declared request body
      bytes (method, path, status_code);
    * ``incoming_http_requests__output_network_traffic`` - counter of response body bytes
      (method, path, status_code).
    """

    class Label:
        """Names of the Prometheus labels attached to the metrics."""
        method = "method"
        path = "path"
        status_code = "status_code"

    @dataclass
    class Context:
        """Runtime dependencies of the collector."""
        # Registry the metrics are registered in; defaults to prometheus_client's global registry.
        registry: CollectorRegistry = REGISTRY

    @dataclass
    class Config:
        """Tunable settings of the collector."""
        # Paths (exact match against ``request.path``) that are passed through without any measurement.
        exclude_paths: List[str] = field(default_factory=lambda: ["/metrics"])
        # Upper bounds of the latency histogram buckets, in the units of ``Timer.duration``
        # (presumably seconds - TODO confirm against ..tools.timer.Timer).
        latency_buckets: List[float] = field(
            default_factory=lambda: [.05, .1, .25, .5, .75, 1.0, 2.5, 5, 7.5, 10.0, 30.0, 60.0, INF]
        )

    def __init__(self, config: Config, context: Context) -> None:
        """Create and register all metrics in ``context.registry``.

        prometheus_client rejects duplicate metric names within one registry, so only one
        collector instance can be created per registry.
        """
        self.config = config
        self.context = context

        self.incoming_http_requests__amount__counter = Counter(
            "incoming_http_requests__amount",
            "Total number of incoming HTTP requests",
            labelnames=[self.Label.method, self.Label.path],
            registry=self.context.registry,
        )

        self.incoming_http_requests__current__gauge = Gauge(
            "incoming_http_requests__current",
            "Number of current incoming HTTP requests",
            labelnames=[self.Label.method, self.Label.path],
            registry=self.context.registry
        )

        self.incoming_http_requests__latency__histogram = Histogram(
            "incoming_http_requests__latency",
            "Incoming HTTP request latency",
            labelnames=[self.Label.method, self.Label.path, self.Label.status_code],
            registry=self.context.registry,
            buckets=self.config.latency_buckets
        )

        self.incoming_http_requests__input_network_traffic__counter = Counter(
            "incoming_http_requests__input_network_traffic",
            "Input network traffic of incoming HTTP requests in bytes",
            labelnames=[self.Label.method, self.Label.path, self.Label.status_code],
            registry=self.context.registry,
        )

        self.incoming_http_requests__output_network_traffic__counter = Counter(
            "incoming_http_requests__output_network_traffic",
            "Output network traffic of incoming HTTP requests in bytes",
            labelnames=[self.Label.method, self.Label.path, self.Label.status_code],
            registry=self.context.registry,
        )

    def _observe_finished_request(
            self,
            request: web.Request,
            labels: dict,
            status_code: int,
            duration: float,
            response: Union[Response, None],
    ) -> None:
        """Record latency and traffic of a request that finished with ``status_code``.

        :param request: the handled request; its ``content_length`` feeds the input traffic counter.
        :param labels: method/path labels of the request; not modified.
        :param status_code: HTTP status to label the observations with.
        :param duration: measured handling time, observed into the latency histogram.
        :param response: object that carried the reply (a returned response or a raised
            ``HTTPException``), or ``None`` when there is none - then no output traffic is recorded.
        """
        finished_labels = {**labels, self.Label.status_code: status_code}
        self.incoming_http_requests__latency__histogram.labels(**finished_labels).observe(duration)
        # Falsy (None or 0) when the request declares no body length - nothing to add then.
        if input_network_traffic := request.content_length:
            self.incoming_http_requests__input_network_traffic__counter.labels(**finished_labels).inc(
                input_network_traffic
            )
        # getattr: a raised HTTPException is only guaranteed to expose content_length while it
        # subclasses Response; None also covers the "no response object" case.
        if output_network_traffic := getattr(response, "content_length", None):
            self.incoming_http_requests__output_network_traffic__counter.labels(**finished_labels).inc(
                output_network_traffic
            )

    def create_middleware(self) -> web.middleware:
        """Build an aiohttp middleware that records the metrics for every non-excluded request.

        Requests that end with an exception are measured too. A raised
        ``aiohttp.web.HTTPException`` (handlers raise it to answer e.g. 404 or a redirect) is
        recorded with its own status. Any other ``Exception`` is recorded as status 500. In both
        cases the exception is re-raised unchanged. ``asyncio.CancelledError`` is not an
        ``Exception`` subclass on Python 3.8+, so a cancelled request only affects the request
        counter and the in-progress gauge.
        """
        @web.middleware
        async def middleware(request: web.Request, handler: Union[Handler, WsHandler], *args, **kwargs) -> Response:
            if request.path in self.config.exclude_paths:
                return await handler(request)

            # NOTE(review): the raw path is used as a label value, so every distinct URL
            # (ids embedded in the path, scanner 404s, ...) creates new time series.
            labels = {self.Label.method: request.method, self.Label.path: request.path}
            timer = Timer()

            self.incoming_http_requests__amount__counter.labels(**labels).inc()

            # The except clauses run after the `with` has exited, so the gauge is already
            # decremented and the timer stopped. This assumes Timer records its duration on
            # exit even when an exception propagates - TODO confirm.
            try:
                with timer, self.incoming_http_requests__current__gauge.labels(**labels).track_inprogress():
                    response: Response = await handler(request)
            except web.HTTPException as exc:
                self._observe_finished_request(request, labels, exc.status, timer.duration, exc)
                raise
            except Exception:
                # aiohttp turns unhandled handler errors into "500 Internal Server Error".
                self._observe_finished_request(request, labels, 500, timer.duration, None)
                raise

            self._observe_finished_request(request, labels, response.status, timer.duration, response)
            return response

        return middleware
