from typing import Iterable, Sequence

from prometheus_client import CollectorRegistry

from .collectors.wrapped_metrics_collectors import check_mismatched_label_names, WrappedCounter, \
    WrappedGauge, WrappedHistogram, check_mismatched_buckets, WrappedSummary


class MetricsRegistry(CollectorRegistry):
    """Collector registry with get-or-create ``ensure_*`` factory methods.

    Every ``ensure_*`` method first looks up ``name`` in this registry. When a
    collector of the requested wrapped type is already registered, it is
    validated against the requested definition and returned; otherwise a new
    wrapped collector is constructed with ``registry=self``.

    NOTE(review): the registry lock is only held during the lookup, not while
    the new collector is constructed, so two threads ensuring the same
    not-yet-registered name at the same time may both try to create it.
    Confirm whether callers need that case guarded.
    """

    def ensure_counter(self, name: str, documentation: str = '', labelnames: Iterable[str] = ()) -> WrappedCounter:
        """Return the counter registered under ``name``, creating it if absent.

        Args:
            name: Name used for the lookup and, when absent, for the new counter.
            documentation: Help text; only used when a new counter is created.
            labelnames: Label names; passed to ``check_mismatched_label_names``
                for an existing counter, or used to create the new one.

        Returns:
            The existing or newly created ``WrappedCounter``.

        Raises:
            ValueError: If ``name`` is registered to a collector that is not a
                ``WrappedCounter``.
        """
        counter = self._get_collector(name, WrappedCounter)
        # Compare against None explicitly: a registered collector must never be
        # mistaken for "missing" just because it evaluates as falsy.
        if counter is not None:
            check_mismatched_label_names(counter, labelnames)
            return counter
        return WrappedCounter(name, documentation, labelnames, registry=self)

    def ensure_gauge(self, name: str, description: str = '', labelnames: Iterable[str] = ()) -> WrappedGauge:
        """Return the gauge registered under ``name``, creating it if absent.

        Args:
            name: Name used for the lookup and, when absent, for the new gauge.
            description: Help text; only used when a new gauge is created. It is
                passed in the same position the sibling methods pass their
                ``documentation`` argument; the parameter name differs from
                those methods and is kept for existing keyword callers.
            labelnames: Label names; passed to ``check_mismatched_label_names``
                for an existing gauge, or used to create the new one.

        Returns:
            The existing or newly created ``WrappedGauge``.

        Raises:
            ValueError: If ``name`` is registered to a collector that is not a
                ``WrappedGauge``.
        """
        gauge = self._get_collector(name, WrappedGauge)
        if gauge is not None:
            check_mismatched_label_names(gauge, labelnames)
            return gauge
        return WrappedGauge(name, description, labelnames, registry=self)

    def ensure_histogram(self,
                         name: str,
                         documentation: str = '',
                         labelnames: Iterable[str] = (),
                         buckets: Sequence[float] = WrappedHistogram.DEFAULT_BUCKETS,
                         ) -> WrappedHistogram:
        """Return the histogram registered under ``name``, creating it if absent.

        Args:
            name: Name used for the lookup and, when absent, for the new histogram.
            documentation: Help text; only used when a new histogram is created.
            labelnames: Label names; passed to ``check_mismatched_label_names``
                for an existing histogram, or used to create the new one.
            buckets: Bucket boundaries; passed to ``check_mismatched_buckets``
                for an existing histogram, or used to create the new one.
                Defaults to ``WrappedHistogram.DEFAULT_BUCKETS``.

        Returns:
            The existing or newly created ``WrappedHistogram``.

        Raises:
            ValueError: If ``name`` is registered to a collector that is not a
                ``WrappedHistogram``.
        """
        histogram = self._get_collector(name, WrappedHistogram)
        if histogram is not None:
            check_mismatched_label_names(histogram, labelnames)
            check_mismatched_buckets(histogram, buckets)
            return histogram
        return WrappedHistogram(name, documentation, labelnames, buckets=buckets, registry=self)

    def ensure_summary(self, name: str, documentation: str = '', labelnames: Iterable[str] = ()) -> WrappedSummary:
        """Return the summary registered under ``name``, creating it if absent.

        Args:
            name: Name used for the lookup and, when absent, for the new summary.
            documentation: Help text; only used when a new summary is created.
            labelnames: Label names; passed to ``check_mismatched_label_names``
                for an existing summary, or used to create the new one.

        Returns:
            The existing or newly created ``WrappedSummary``.

        Raises:
            ValueError: If ``name`` is registered to a collector that is not a
                ``WrappedSummary``.
        """
        summary = self._get_collector(name, WrappedSummary)
        if summary is not None:
            check_mismatched_label_names(summary, labelnames)
            return summary
        return WrappedSummary(name, documentation, labelnames, registry=self)

    def _get_collector(self,
                       name: str,
                       required_type: type,
                       ) -> WrappedCounter | WrappedGauge | WrappedHistogram | WrappedSummary | None:
        """Look up the collector registered under ``name`` and check its type.

        Args:
            name: Key to look up in the registry's name-to-collector mapping.
            required_type: Class the registered collector must be an instance of.

        Returns:
            The registered collector, or ``None`` when nothing is registered
            under ``name``.

        Raises:
            ValueError: If a collector is registered under ``name`` but is not
                an instance of ``required_type``.
        """
        # Only the mapping read needs the registry lock.
        with self._lock:
            collector = self._names_to_collectors.get(name)
        if collector is None:
            return None
        # Validate every registered collector, including ones that evaluate as
        # falsy, so a type clash is always reported as such.
        if not isinstance(collector, required_type):
            raise ValueError(f"Collector '{name}' with the '{type(collector)}' type already exists ")
        return collector


# Module-level registry instance. ``MetricsRegistry`` defines no ``__init__``,
# so ``auto_describe=True`` is handled by the ``CollectorRegistry`` constructor.
METRICS_REGISTRY = MetricsRegistry(auto_describe=True)
