from typing import Callable, Iterable, Sequence

from prometheus_client import CollectorRegistry

from .collectors.wrapped_metrics_collectors import check_mismatched_label_names, CounterMetric, \
    GaugeMetric, HistogramMetric, check_mismatched_buckets, SummaryMetric


class MetricsRegistry(CollectorRegistry):
    """``CollectorRegistry`` with idempotent get-or-create helpers for metrics.

    Each ``ensure_*`` method returns the metric already registered under
    ``name`` when there is one. Before returning it, the method hands it to
    the ``check_mismatched_*`` helpers together with the requested label
    names (and buckets, for histograms). Otherwise it creates a new metric
    registered in this registry.
    """

    def ensure_counter(self, name: str, documentation: str = '', labelnames: Iterable[str] = ()) -> CounterMetric:
        """Return the counter registered as ``name``, creating it if needed.

        Args:
            name: Metric name to look up or create.
            documentation: Help text; only used when a new counter is created.
            labelnames: Label names requested for the counter.

        Raises:
            ValueError: If a collector that is not a ``CounterMetric`` is
                already registered as ``name``.

        An existing counter is passed to ``check_mismatched_label_names``,
        presumably to reject label names that differ from ``labelnames``.
        """
        # Materialise once: on the race fallback path the label names are used
        # both by the constructor and by the mismatch check.
        labelnames = tuple(labelnames)
        return self._ensure(
            name,
            CounterMetric,
            lambda: CounterMetric(name, documentation, labelnames, registry=self),
            lambda counter: check_mismatched_label_names(counter, labelnames),
        )

    def ensure_gauge(self, name: str, description: str = '', labelnames: Iterable[str] = ()) -> GaugeMetric:
        """Return the gauge registered as ``name``, creating it if needed.

        Args:
            name: Metric name to look up or create.
            description: Help text; only used when a new gauge is created.
                (Named ``description`` rather than ``documentation`` for
                backward compatibility with existing keyword callers.)
            labelnames: Label names requested for the gauge.

        Raises:
            ValueError: If a collector that is not a ``GaugeMetric`` is
                already registered as ``name``.

        An existing gauge is passed to ``check_mismatched_label_names``.
        """
        labelnames = tuple(labelnames)
        return self._ensure(
            name,
            GaugeMetric,
            lambda: GaugeMetric(name, description, labelnames, registry=self),
            lambda gauge: check_mismatched_label_names(gauge, labelnames),
        )

    def ensure_histogram(self,
                         name: str,
                         documentation: str = '',
                         labelnames: Iterable[str] = (),
                         buckets: Sequence[float] = HistogramMetric.DEFAULT_BUCKETS,
                         ) -> HistogramMetric:
        """Return the histogram registered as ``name``, creating it if needed.

        Args:
            name: Metric name to look up or create.
            documentation: Help text; only used when a new histogram is created.
            labelnames: Label names requested for the histogram.
            buckets: Bucket bounds requested for the histogram.

        Raises:
            ValueError: If a collector that is not a ``HistogramMetric`` is
                already registered as ``name``.

        An existing histogram is passed to ``check_mismatched_label_names``
        and ``check_mismatched_buckets``.
        """
        labelnames = tuple(labelnames)

        def validate(histogram: HistogramMetric) -> None:
            check_mismatched_label_names(histogram, labelnames)
            check_mismatched_buckets(histogram, buckets)

        return self._ensure(
            name,
            HistogramMetric,
            lambda: HistogramMetric(name, documentation, labelnames, buckets=buckets, registry=self),
            validate,
        )

    def ensure_summary(self, name: str, documentation: str = '', labelnames: Iterable[str] = ()) -> SummaryMetric:
        """Return the summary registered as ``name``, creating it if needed.

        Args:
            name: Metric name to look up or create.
            documentation: Help text; only used when a new summary is created.
            labelnames: Label names requested for the summary.

        Raises:
            ValueError: If a collector that is not a ``SummaryMetric`` is
                already registered as ``name``.

        An existing summary is passed to ``check_mismatched_label_names``.
        """
        labelnames = tuple(labelnames)
        return self._ensure(
            name,
            SummaryMetric,
            lambda: SummaryMetric(name, documentation, labelnames, registry=self),
            lambda summary: check_mismatched_label_names(summary, labelnames),
        )

    def _ensure(self,
                name: str,
                metric_type: type,
                create: Callable[[], CounterMetric | GaugeMetric | HistogramMetric | SummaryMetric],
                validate: Callable[[CounterMetric | GaugeMetric | HistogramMetric | SummaryMetric], None],
                ) -> CounterMetric | GaugeMetric | HistogramMetric | SummaryMetric:
        """Get-or-create core shared by the ``ensure_*`` methods.

        Args:
            name: Metric name to look up.
            metric_type: Class an existing collector must be an instance of.
            create: Zero-argument factory that builds the metric and registers
                it in this registry.
            validate: Called with an existing metric to check it matches the
                requested configuration.

        Returns:
            The existing metric if one is registered as ``name``, otherwise
            the one produced by ``create``.
        """
        existing = self._get_collector(name, metric_type)
        if existing is not None:
            validate(existing)
            return existing
        try:
            return create()
        except ValueError:
            # The lookup above and the registration inside ``create`` are not
            # atomic. The lock cannot be held across ``create`` because
            # registering takes the same non-reentrant ``_lock``. Another
            # thread may therefore have registered this metric in between,
            # and ``CollectorRegistry.register`` rejects duplicates with
            # ValueError. If something is now registered as ``name``, use it.
            # Otherwise the error was genuine and is re-raised.
            existing = self._get_collector(name, metric_type)
            if existing is None:
                raise
            validate(existing)
            return existing

    def _get_collector(self,
                       name: str,
                       required_type: type,
                       ) -> CounterMetric | GaugeMetric | HistogramMetric | SummaryMetric | None:
        """Return the collector registered under ``name``, or ``None`` if there is none.

        The ``_names_to_collectors`` lookup is done while holding the registry lock.

        Raises:
            ValueError: If the registered collector is not an instance of ``required_type``.
        """
        with self._lock:
            collector = self._names_to_collectors.get(name)
        # Compare against None explicitly so a collector that happens to be
        # falsy is not mistaken for "not registered".
        if collector is not None and not isinstance(collector, required_type):
            raise ValueError(
                f"Collector '{name}' with the '{type(collector).__name__}' type already exists"
                f" (expected '{required_type.__name__}')"
            )
        return collector


# Module-level registry instance shared by everything that imports this module.
# With auto_describe=True, prometheus_client falls back to a collector's
# collect() to discover its metric names when the collector has no describe().
METRICS_REGISTRY: MetricsRegistry = MetricsRegistry(
    auto_describe=True,
)
