#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>
from contextlib import contextmanager
from dataclasses import dataclass
from threading import Lock
from typing import Sequence, Any, Iterator

from prometheus_client import CollectorRegistry, Counter, Gauge, Summary, Histogram
from prometheus_client.utils import INF

from .tools.timer import Timer


class MetricsRegistry:
    """Get-or-create facade over a private prometheus ``CollectorRegistry``.

    Each accessor looks a metric up by name and creates it on first use, so callers may request the
    same metric repeatedly without tracking whether it was already declared. When ``labels`` are given,
    the child series bound to those label values is returned instead of the parent collector.
    """

    @dataclass
    class Config:
        # Bucket upper bounds used by ``histogram``/``track_time`` when no (or empty) ``buckets`` are passed.
        default_buckets: tuple[float, ...] = (.001, .003, .01, .03, .1, .3, 1., 3., 10., 30., 100., 300., 1000., INF)

    def __init__(self, config: Config) -> None:
        self.config = config
        # Guards the lookup-then-create sequence in ``_ensure_collector`` so concurrent callers
        # cannot both try to register the same name.
        self._lock = Lock()
        self._registry = CollectorRegistry(auto_describe=True)

    def counter(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Counter:
        """Return counter ``name`` (created on first use), bound to ``labels`` if any are given."""
        return self._ensure_collector(Counter, name, labels, doc)

    def gauge(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Gauge:
        """Return gauge ``name`` (created on first use), bound to ``labels`` if any are given."""
        return self._ensure_collector(Gauge, name, labels, doc)

    def summary(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Summary:
        """Return summary ``name`` (created on first use), bound to ``labels`` if any are given."""
        return self._ensure_collector(Summary, name, labels, doc)

    def histogram(
            self, name: str, buckets: Sequence[float] | None = None, labels: dict[str, Any] | None = None, doc: str = ''
    ) -> Histogram:
        """Return histogram ``name`` (created on first use), bound to ``labels`` if any are given.

        ``buckets`` falls back to ``config.default_buckets`` when ``None`` or empty. Like ``doc``, it only
        takes effect when the histogram is created; an existing histogram keeps its original buckets.
        """
        return self._ensure_collector(Histogram, name, labels, doc, buckets=buckets or self.config.default_buckets)

    @contextmanager
    def track_progress(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Iterator[None]:
        """Mark the ``with`` body as in progress on gauge ``name`` using ``Gauge.track_inprogress``."""
        with self.gauge(name, labels, doc).track_inprogress():
            yield

    @contextmanager
    def track_time(
            self, name: str, buckets: Sequence[float] | None = None, labels: dict[str, Any] | None = None, doc: str = ''
    ) -> Iterator[None]:
        """Time the ``with`` body and observe ``Timer.duration`` into histogram ``name``.

        The observation happens in ``finally``, so failed runs are recorded too. Units are whatever
        ``tools.timer.Timer`` reports (presumably seconds, given the default buckets; confirm there).
        """
        try:
            with Timer() as timer:
                yield
        finally:
            # Observe after the Timer has exited, once its duration is final.
            self.histogram(name, buckets, labels, doc).observe(timer.duration)

    def _ensure_collector(
            self, required_type: type, name: str, labels: dict[str, Any] | None = None, doc: str = '', **kwargs
    ) -> Counter | Gauge | Summary | Histogram:
        """Fetch the collector registered under ``name`` or create it, then bind ``labels`` if given.

        :param required_type: ``Counter``, ``Gauge``, ``Summary``, ``Histogram`` or a subclass of one.
        :param name: metric name to look up in (or register with) the private registry.
        :param labels: label names mapped to values. On creation the keys become the label names; the
            values select the child series that is returned. ``None`` or empty returns the parent.
        :param doc: help text; only used when the collector is created.
        :param kwargs: extra constructor arguments (e.g. ``buckets``); only used on creation.
        :raises TypeError: if ``required_type`` is not one of the supported metric classes.
        :raises ValueError: if ``name`` is already taken by a collector of a different type. prometheus_client
            also raises it when ``labels`` keys do not match the existing collector's label names.
        """
        if not issubclass(required_type, (Counter, Gauge, Summary, Histogram)):
            raise TypeError(f'Inappropriate type: {required_type}')

        with self._lock:
            labels = labels or {}

            collector = self._registry._names_to_collectors.get(name)  # bad library interface
            if collector is None:
                collector = required_type(name, doc, labels.keys(), registry=self._registry, **kwargs)
            elif not isinstance(collector, required_type):
                raise ValueError(f"Collector '{name}' with the '{type(collector)}' type already exists ")

            if labels:
                # Bind by name, not position. An existing collector's label order was fixed by the first
                # call's dict order, so positional binding would silently hit the wrong series whenever a
                # later call lists the same labels in a different order.
                collector = collector.labels(**labels)

        return collector
