from abc import ABC
from typing import Sequence, Iterable

from prometheus_client import Counter, Gauge, Histogram, Summary
from prometheus_client.metrics import MetricWrapperBase
from prometheus_client.utils import INF


class WrappedMetricsCollector(MetricWrapperBase, ABC):
    """Mixin adding a ``label_names`` accessor to prometheus_client metric wrappers."""

    @property
    def label_names(self) -> set[str]:
        """Label names of this metric, copied from ``self._labelnames``.

        A new set is built on every access, so callers may mutate the result
        without affecting the metric. Label order is not preserved.
        """
        names = self._labelnames
        return {*names}


class WrappedCounter(Counter, WrappedMetricsCollector):
    """``Counter`` that also exposes ``label_names`` via ``WrappedMetricsCollector``."""


class WrappedGauge(Gauge, WrappedMetricsCollector):
    """``Gauge`` that also exposes ``label_names`` via ``WrappedMetricsCollector``."""


class WrappedSummary(Summary, WrappedMetricsCollector):
    """``Summary`` that also exposes ``label_names`` via ``WrappedMetricsCollector``."""


class WrappedHistogram(Histogram, WrappedMetricsCollector):
    """``Histogram`` that also exposes ``label_names`` and its bucket upper bounds."""

    @property
    def buckets(self) -> set[float]:
        """Bucket upper bounds of this histogram, copied from ``self._upper_bounds``.

        A new set is built on every access; bucket order is not preserved.
        """
        bounds = self._upper_bounds
        return {*bounds}


def check_mismatched_label_names(collector: WrappedMetricsCollector, required_label_names: Iterable[str]) -> None:
    """Ensure *collector* declares exactly the label names in *required_label_names*.

    The comparison is set-based, so order and duplicates are ignored.

    :param collector: the wrapped metric whose label names are checked.
    :param required_label_names: the expected label names.
    :raises ValueError: if the two sets differ; the message lists every name
        present on only one side.
    """
    actual = collector.label_names
    mismatched = actual.symmetric_difference(required_label_names)
    if not mismatched:
        return
    raise ValueError(f'mismatched label_names: {mismatched}')


def check_mismatched_buckets(collector: WrappedHistogram, required_buckets: Sequence[float]) -> None:
    """Ensure *collector*'s bucket upper bounds match *required_buckets*.

    Every required bound is converted with ``float``, mirroring how
    prometheus_client's ``Histogram`` normalizes the buckets it is given.
    ``+Inf`` is added whenever at least one bound is given, so callers may pass
    the same bucket list they would hand to the ``Histogram`` constructor, with
    or without a trailing ``+Inf``. The comparison is set-based, so order and
    duplicates are ignored. An empty *required_buckets* gets no ``+Inf`` added
    and therefore mismatches any collector that has buckets.

    :param collector: the histogram whose bucket bounds are checked.
    :param required_buckets: the expected upper bounds.
    :raises ValueError: if the two bucket sets differ; the message lists every
        bound present on only one side. Also raised by ``float`` for a bound
        that cannot be converted to a number.
    """
    # Build a separate set rather than rebinding the ``Sequence`` parameter to a
    # ``set``, which type checkers reject.
    expected = {float(bound) for bound in required_buckets}
    if expected:
        # The collector's bounds include +Inf, so the required side needs it too.
        expected.add(INF)
    if diff := collector.buckets ^ expected:
        raise ValueError(f'mismatched buckets: {diff}')
