#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import abc
import contextvars
from dataclasses import dataclass, field
from typing import TypeAlias, TypeVar, Callable, Awaitable, Iterable, Literal, Any, Mapping, TYPE_CHECKING, Generic, \
    ClassVar

from async_tools import acall
from frozendict import frozendict
from http_tools import IncomingRequest

from init_helpers import DataclassProtocol, raise_if, Jsonable

from openapi_tools.spec import SpecResource, SpecRef, ReferencableResource
from ..exceptions import Forbidden
from .security_scheme_type import SecuritySchemeType

if TYPE_CHECKING:
    from ..requirement import SecurityRequirement

# Per-context memo used by SecurityScheme.evaluate: maps a scheme instance to
# the outcome of its resolver (the returned value or the exception raised).
# evaluate() reads it with an empty-dict default, so results are only shared
# across evaluations when a caller has put a dict into this variable.
resolver_cache: contextvars.ContextVar[dict] = contextvars.ContextVar('resolver_cache')

# Token string that a scheme's `_extract_token` pulls out of an incoming request
# and hands to the scheme's resolver.
AuthToken: TypeAlias = str
# Authentication payload a resolver returns alongside the granted scopes.
AuthInfo = TypeVar("AuthInfo", bound=dict | DataclassProtocol)


@dataclass(frozen=True, slots=True, repr=False)
class SecurityScheme(Generic[AuthInfo], ReferencableResource, abc.ABC):
    """Abstract base for security schemes.

    A concrete subclass defines how a token is taken from an incoming request
    (`_extract_token`) and how the scheme is described in the spec
    (`get_spec_dict`). The `resolver` turns that token into
    ``(auth_info, allowed_scopes)``; `evaluate` then checks the scopes a
    requirement needs against the scopes the resolver allowed.

    Being a frozen dataclass, an instance is hashable (given hashable field
    values), which is what lets `evaluate` use the scheme itself as a key in the
    `resolver_cache` dict.
    """
    can_be_inlined: ClassVar[bool] = False
    # NOTE(review): OpenAPI 3 stores scheme definitions under
    # components/securitySchemes ("security" is the requirement list) —
    # confirm how openapi_tools interprets this path.
    spec_path: ClassVar[tuple[str, ...]] = ('components', 'security')
    # Declared here, assigned by each concrete subclass; emitted as the spec's 'type'.
    type_: ClassVar[SecuritySchemeType]
    # Maps a token to (auth_info, allowed_scopes); may return the tuple directly
    # or an awaitable of it (the result is passed through `acall` in `evaluate`).
    resolver: Callable[[AuthToken], Awaitable[tuple[AuthInfo, Iterable[str]]] | tuple[AuthInfo, Iterable[str]]]
    do_log: bool = field(default=True, kw_only=True)

    @classmethod
    def get_spec_dependencies(cls) -> frozenset['SpecResource']:
        """Return the spec resources this scheme depends on (none)."""
        return frozenset()

    @abc.abstractmethod
    def get_spec_dict(self, dependency_to_ref: Mapping[SpecResource, SpecRef]) -> frozendict[str, Jsonable]:
        """Return the spec object describing this scheme.

        Abstract, so every subclass must override it; the base implementation
        supplies the common ``{'type': self.type_}`` part, which overrides can
        reuse via ``super()``.
        """
        return frozendict({'type': self.type_})

    def get_spec_ref(self, key: str = '') -> str:
        """Return `key` if it is non-empty, otherwise this resource's `get_key()`."""
        return key if key else self.get_key()

    @abc.abstractmethod
    async def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Pull this scheme's token out of `incoming_request` (subclass-specific)."""
        pass

    async def evaluate(self, incoming_request: IncomingRequest, required_scopes: frozenset[str]) -> AuthInfo:
        """Authenticate `incoming_request` with this scheme and check its scopes.

        The resolver outcome — ``(auth_info, allowed_scopes)`` or the exception
        it raised — is stored in the dict held by `resolver_cache`, keyed by
        this scheme, so repeated evaluations within the same context resolve
        the token only once. If no dict has been set in the context, a fresh
        one is used and nothing outlives this call.

        :param incoming_request: request to extract the token from.
        :param required_scopes: scopes that must all be allowed.
        :return: the ``auth_info`` produced by the resolver.
        :raises Forbidden: if any of `required_scopes` is not allowed.
        :raises Exception: whatever token extraction or the resolver raised
            (a resolver failure is cached and re-raised on later calls).
        """
        cache = resolver_cache.get({})
        if self not in cache:
            auth_token = await self._extract_token(incoming_request)
            try:
                auth_info, allowed_scopes = await acall(self.resolver(auth_token))
                # Materialise the scopes once: the resolver may hand back a
                # one-shot iterator (e.g. a generator), which would be exhausted
                # — and thus look like "no scopes" — on the next cached lookup.
                cache[self] = (auth_info, frozenset(allowed_scopes))
            except Exception as e:
                cache[self] = e
        result = cache[self]
        if isinstance(result, Exception):
            # NOTE(review): re-raising the same cached exception object prepends
            # the current frames to its __traceback__ on every raise.
            raise result
        auth_info, allowed_scopes = result
        raise_if(missing_scopes := required_scopes - allowed_scopes,
                 Forbidden(f'missing scopes: {", ".join(sorted(missing_scopes))}'))
        return auth_info

    def has(self, *args: str, **kwargs: type[Literal] | Callable[[AuthInfo], Any]) -> 'SecurityRequirement':
        """Build a SecurityRequirement on this scheme with scopes `args`.

        `kwargs` are forwarded to ``SecurityRequirement.use``. The import is
        local, presumably to avoid an import cycle with ``..requirement`` (it is
        only imported under TYPE_CHECKING at module level).
        """
        from ..requirement import SecurityRequirement
        return SecurityRequirement({self: args}).use(**kwargs)

    def _repr(self) -> list[str]:
        """Return the ``name=value`` parts shown by `__repr__`; subclasses may extend."""
        return [f"resolver={self.resolver!r}"]

    def __repr__(self):
        return f'{self.__class__.__name__}({", ".join(self._repr())})'

    __str__ = __repr__
