#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import functools
from abc import ABC
from dataclasses import dataclass
from typing import Mapping, Iterable, ClassVar, Any

from dict_caster.extras import first
from frozendict import frozendict
from init_helpers import Jsonable
from init_helpers.dict_to_dataclass import NoValue

from openapi_tools.spec.spec_resource import SpecResource, SpecRef
from .base_schema import BaseSchema


@dataclass(frozen=True, slots=True)
class CompositeSchema(BaseSchema, ABC):
    """Schema that combines several variant schemas under a single composition keyword.

    Concrete subclasses define ``composition_keyword``; ``get_spec_dict`` emits the list of
    variant references under that key. Optionally a discriminator can be attached:
    ``discriminator_property_name`` names the property, and ``discriminator_value_to_schema``
    maps each discriminator value to one of ``variants``.
    """

    # Key under which the variant references are emitted; provided by concrete subclasses.
    composition_keyword: ClassVar[str]
    # Member schemas; normalised to a tuple in ``__post_init__``.
    variants: Iterable[BaseSchema]
    # Discriminator property name; an empty string means "no discriminator".
    discriminator_property_name: str = ''
    # Discriminator value -> variant schema. Required when ``discriminator_property_name`` is set.
    # Normalised to a ``frozendict`` in ``__post_init__`` so the instance remains hashable.
    discriminator_value_to_schema: Mapping[Jsonable, BaseSchema] | None = None

    def __post_init__(self):
        """Normalise the container fields into immutable, hashable forms."""
        # NOTE(review): does not chain to a base ``__post_init__``; confirm BaseSchema defines none.
        # The dataclass is frozen, so fields must be set through ``object.__setattr__``.
        object.__setattr__(self, 'variants', tuple(self.variants))
        # The generated ``__hash__`` hashes every field value, and the ``functools.cache``d
        # methods below hash ``self``; a plain ``dict`` here would make them raise ``TypeError``.
        if self.discriminator_value_to_schema is not None:
            object.__setattr__(self, 'discriminator_value_to_schema', frozendict(self.discriminator_value_to_schema))

    @property
    def has_default(self) -> bool:
        """Return whether this schema itself, or any of its variants, declares a default value."""
        return self.default is not NoValue or any(variant.default is not NoValue for variant in self.variants)

    # NOTE(review): ``functools.cache`` on a method keeps every instance it is called with alive
    # for the lifetime of the process (ruff B019). ``functools.cached_property`` needs an
    # instance ``__dict__``, which a ``slots=True`` dataclass does not have, so the cache is kept.
    @functools.cache
    def get_spec_dependencies(self) -> frozenset['SpecResource']:
        """Return the base schema's spec dependencies plus every variant schema."""
        # Explicit base-class call instead of zero-argument ``super()``; presumably because
        # ``slots=True`` recreates the class and breaks ``super()``'s ``__class__`` cell.
        return frozenset(BaseSchema.get_spec_dependencies(self) | set(self.variants))

    def get_spec_dict(self, dependency_to_ref: Mapping['SpecResource', SpecRef | dict]) -> frozendict[str, Jsonable]:
        """Build the spec dictionary for this composite schema.

        :param dependency_to_ref: maps a spec dependency to the reference (or inline dict) that
            stands in for it; it must contain every variant, otherwise ``KeyError`` is raised.
        :return: the base schema's spec dict extended with ``{composition_keyword: [...]}`` and,
            when a discriminator property is configured, a ``discriminator`` entry.
        :raises ValueError: if a discriminator property is set but ``discriminator_value_to_schema``
            is missing, or its values are not exactly the set of ``variants``.
        """
        result: dict[str, Any] = {
            self.composition_keyword: [dependency_to_ref[variant] for variant in self.variants],
        }
        if self.discriminator_property_name:
            value_to_schema = self.discriminator_value_to_schema
            if value_to_schema is None:
                raise ValueError(f'{self.discriminator_property_name=} requires discriminator_value_to_schema')
            if set(value_to_schema.values()) != set(self.variants):
                raise ValueError(f'{self.discriminator_value_to_schema=} must have all values of {self.variants=}')
            result['discriminator'] = {
                'propertyName': self.discriminator_property_name,
                # ``.items()`` is required: iterating a Mapping directly yields only its keys.
                'mapping': {value: dependency_to_ref[schema] for value, schema in value_to_schema.items()},
            }
        return frozendict(BaseSchema.get_spec_dict(self, dependency_to_ref) | result)

    @functools.cache
    def _get_repr_parts(self) -> tuple[str, ...]:
        """Return the comma-separated pieces that ``__repr__`` joins inside the parentheses."""
        parts = list(BaseSchema._get_repr_parts(self))
        if self.variants:
            parts.append(repr(self.variants))
        if self.discriminator_property_name:
            parts.append(f'discriminator_property_name={self.discriminator_property_name!r}')
        return tuple(parts)

    def __repr__(self):
        return f'{self.__class__.__name__}({", ".join(self._get_repr_parts())})'

    __str__ = __repr__
