#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>
import copy
from contextlib import contextmanager, asynccontextmanager
from dataclasses import dataclass
from threading import Lock
from typing import Sequence, Any, AsyncIterator, Iterator, Callable, Awaitable, Iterable

from async_tools import acall
from more_itertools import first
from prometheus_client import Counter, Gauge, Summary, Histogram, Metric, generate_latest
from prometheus_client.metrics import MetricWrapperBase
from prometheus_client.utils import INF

from .tools.call import call
from .tools.timer import Timer


class BaseMetricsRegistry:
    """Lazily populated, lock-guarded collection of Prometheus collectors.

    Collectors are created on first request with ``registry=None`` and kept in
    a private mapping owned by this object; :meth:`render_metrics` serialises
    whatever that mapping currently holds.
    """

    @dataclass
    class Config:
        # Bucket upper bounds used by ``histogram`` when the caller passes none.
        default_buckets: tuple[float, ...] = (
            .001, .003, .01, .03, .1, .3, 1., 3., 10., 30., 100., 300., 1000., INF
        )

    def __init__(self, config: Config) -> None:
        """Store *config* and start with an empty collector mapping."""
        self.config = config
        # Guards every read and write of ``_name_to_collector``.
        self._lock = Lock()
        # Keyed by ``(metric name, *sorted label names)``.
        self._name_to_collector: dict[tuple[str, ...], MetricWrapperBase] = {}

    def render_metrics(self) -> bytes:
        """Return ``generate_latest`` output for the collectors held here."""
        # ``generate_latest`` is handed a throwaway object whose only member is
        # ``collect()``, delegating to this registry's own ``_collect``.
        return generate_latest(type("registry", (), {'collect': lambda _: self._collect()})())

    def counter(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Counter:
        """Return the ``Counter`` called *name*, creating it on first use."""
        return self._ensure_collector(Counter, name, labels, doc)

    def gauge(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Gauge:
        """Return the ``Gauge`` called *name*, creating it on first use."""
        return self._ensure_collector(Gauge, name, labels, doc)

    def summary(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Summary:
        """Return the ``Summary`` called *name*, creating it on first use."""
        return self._ensure_collector(Summary, name, labels, doc)

    def histogram(
            self, name: str, buckets: Sequence[float] | None = None, labels: dict[str, Any] | None = None, doc: str = ''
    ) -> Histogram:
        """Return the ``Histogram`` called *name*, creating it on first use.

        A falsy *buckets* (``None`` or empty) selects ``config.default_buckets``.
        Buckets are only passed when the collector is first created.
        """
        return self._ensure_collector(Histogram, name, labels, doc, buckets=buckets or self.config.default_buckets)

    def _ensure_collector(
            self, required_type: type, name: str, labels: dict[str, Any] | None = None, doc: str = '', **kwargs
    ) -> Counter | Gauge | Summary | Histogram:
        """Fetch or create the collector for *name* and the given label names.

        :param required_type: ``Counter``, ``Gauge``, ``Summary``, ``Histogram``
            or a subclass of one of them.
        :param name: metric name.
        :param labels: label name -> value mapping; the names select the
            collector, the values select the child returned.
        :param doc: documentation string, used only on creation.
        :param kwargs: extra constructor arguments, used only on creation.
        :returns: the collector itself when *labels* is empty, otherwise the
            child returned by ``collector.labels(**labels)``.
        :raises TypeError: if *required_type* is not a supported metric type.
        :raises ValueError: if the same name and label names are already
            registered with a different metric type.
        """
        if not issubclass(required_type, (Counter, Gauge, Summary, Histogram)):
            raise TypeError(f'Inappropriate type: {required_type}')

        labels = labels or {}
        # Sorted so that the cache key does not depend on the order in which
        # the caller happened to list the labels. Iterating a ``set`` here
        # could yield different orders for the same names and register the
        # same metric twice.
        label_keys = tuple(sorted(labels))
        collector_name = (name, *label_keys)

        with self._lock:
            collector = self._name_to_collector.get(collector_name)
            if collector is None:
                collector = required_type(name, doc, label_keys, registry=None, **kwargs)
                self._name_to_collector[collector_name] = collector
            elif not isinstance(collector, required_type):
                raise ValueError(f"The name '{name}' is already taken by the {type(collector)}")

            if labels:
                collector = collector.labels(**labels)

        return collector

    def _collect(self) -> Iterable[Metric]:
        """Yield the metrics of every collector registered so far."""
        # Snapshot under the lock, then iterate without holding it so that
        # collection does not block concurrent registration.
        with self._lock:
            collectors = list(self._name_to_collector.values())
        for collector in collectors:
            yield from collector.collect()


class MetricsRegistry(BaseMetricsRegistry):
    """Registry with context managers for instrumenting a block of code."""

    @dataclass
    class Tracker:
        # Working label mapping of the tracked block. It is the same dict
        # object that is used to label the counter and the timing histogram
        # when the block ends, so changes made through it are picked up there.
        labels: dict[str, Any]

    @contextmanager
    def track_progress(self, name: str, labels: dict[str, Any] | None = None, doc: str = '') -> Iterator[None]:
        """Run the block inside ``track_inprogress()`` of the gauge *name*."""
        with self.gauge(name, labels, doc).track_inprogress():
            yield

    @contextmanager
    def track_time(
            self, name: str, buckets: Sequence[float] | None = None, labels: dict[str, Any] | None = None, doc: str = ''
    ) -> Iterator[None]:
        """Observe ``timer.duration`` of the block in the histogram *name*.

        The histogram is looked up when the block ends, whether or not it
        raised, so *labels* is read at that moment rather than on entry.
        """
        try:
            with Timer() as timer:
                yield
        finally:
            self.histogram(name, buckets, labels, doc).observe(timer.duration)

    @contextmanager
    def track(
            self,
            name: str,
            labels: dict[str, Any] | None = None,
            *,
            except_labels: dict[str, Any] | None = None,
            missing_label_value: Callable[[], Any] | Any | None = None
    ) -> Iterator[Tracker]:
        """Track in-progress count, duration and completions of a block.

        :param name: base metric name; ``{name}__progress`` and
            ``{name}__spent_time`` are derived from it.
        :param labels: initial labels; the mapping passed in is not modified.
        :param except_labels: labels merged in when the block raises an
            ``Exception``.
        :param missing_label_value: passed to ``call`` after the block
            finishes without raising (presumably invoked when callable --
            confirm against ``tools.call``); the result fills every label name
            that is required but still unset.
        :returns: context manager yielding a :class:`Tracker`.
        """
        missing_label_values = []
        with self._track(missing_label_values, name, labels, except_labels) as tracker:
            yield tracker
            # Reached only on success; ``_track`` reads the list after this
            # ``with`` body has finished.
            missing_label_values.append(call(missing_label_value))

    @asynccontextmanager
    async def atrack(
            self,
            name: str,
            labels: dict[str, Any] | None = None,
            *,
            except_labels: dict[str, Any] | None = None,
            missing_label_value: Callable[[], Awaitable[Any] | Any] | Any | None = None
    ) -> AsyncIterator[Tracker]:
        """Async counterpart of :meth:`track`.

        Takes the same arguments, but *missing_label_value* is resolved with
        ``await acall(...)`` instead of ``call(...)``.
        """
        missing_label_values = []
        with self._track(missing_label_values, name, labels, except_labels) as tracker:
            yield tracker
            # Reached only on success; ``_track`` reads the list after this
            # ``with`` body has finished.
            missing_label_values.append(await acall(missing_label_value))

    @contextmanager
    def _track(
            self,
            missing_label_values: list[Any],
            name: str,
            labels: dict[str, Any] | None = None,
            except_labels: dict[str, Any] | None = None
    ) -> Iterator[Tracker]:
        """Shared implementation behind :meth:`track` and :meth:`atrack`.

        :param missing_label_values: list the caller appends the fallback
            label value to once its block succeeded; its first element (or
            ``None`` when empty) fills label names that are still unset.
        :param name: base metric name.
        :param labels: initial labels; copied, never modified in place.
        :param except_labels: labels merged in when the block raises an
            ``Exception``.
        """
        # Work on a private copy: the mapping is updated in place below, and
        # writing into the caller's dict would leak this call's outcome labels
        # into any later call that reuses the same dict.
        labels = dict(labels) if labels else {}
        except_labels = except_labels or {}
        # Label names the final counter is expected to carry on every outcome.
        required_keys = labels.keys() | except_labels.keys()
        with self.track_progress(f'{name}__progress', labels):
            # ``track_time`` receives the same dict object, so its histogram is
            # labelled with whatever ``labels`` holds once the block has ended.
            with self.track_time(f'{name}__spent_time', self.config.default_buckets, labels):
                try:
                    yield self.Tracker(labels)
                    # Success: give every still-missing required label the
                    # fallback value supplied by the caller.
                    labels |= {key: first(missing_label_values, None) for key in required_keys - labels.keys()}
                except Exception:
                    # NOTE(review): BaseException subclasses (for example
                    # asyncio.CancelledError) bypass this handler, so the
                    # counter below is then labelled without ``except_labels``.
                    labels |= except_labels
                    raise
                finally:
                    self.counter(name, labels).inc()
