#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import functools
from dataclasses import field, dataclass
from typing import Mapping, Callable, Awaitable, Any, Literal, get_origin, get_args, Iterable

from async_tools import acall
from dict_caster.extras import first
from frozendict import frozendict
from http_tools import IncomingRequest
from init_helpers import raise_if

from .security_requirement import SecurityRequirement
from ..exceptions import AuthInfoException
from ..scheme import SecurityScheme, AuthInfo


@dataclass(frozen=True, slots=True)
class MonoSecurity(SecurityRequirement):
    """Non OpenAPI specification element, used for security requirements evaluation.

    Holds exactly one security scheme (with its scopes) plus a mapping of argument
    names to getters. ``evaluate`` runs the scheme against an incoming request and
    feeds the resulting ``AuthInfo`` to every getter, returning the produced values
    keyed by argument name.
    """
    # Getters are excluded from the generated hash (hash=False).
    argument_name_to_getter: frozendict[str, Callable[[AuthInfo], Awaitable | Any]] = field(
        default_factory=frozendict, hash=False)

    def __init__(self, scheme_to_scopes: Mapping[SecurityScheme, Iterable[str]] | None = None,
                 argument_name_to_getter: Mapping[str, Callable[[AuthInfo], Awaitable | Any]] | None = None):
        """Create a requirement bound to a single security scheme.

        :param scheme_to_scopes: mapping with exactly one ``scheme -> scopes`` entry.
        :param argument_name_to_getter: optional mapping of argument name to a getter
            called with ``AuthInfo``; stored as a ``frozendict``.
        :raises ValueError: if ``scheme_to_scopes`` is missing or does not contain
            exactly one entry.
        """
        # The default ``None`` must be rejected explicitly: ``len(None)`` would raise
        # an unrelated TypeError instead of the intended ValueError.
        if scheme_to_scopes is None or len(scheme_to_scopes) != 1:
            raise ValueError(f'MonoSecurity expects one scheme_to_scopes, got {scheme_to_scopes}')
        # NOTE(review): explicit base-class call instead of super() - presumably because
        # zero-argument super() is unreliable in slots=True dataclasses; confirm.
        SecurityRequirement.__init__(self, scheme_to_scopes)
        # The dataclass is frozen, so regular attribute assignment is blocked.
        object.__setattr__(self, 'argument_name_to_getter', frozendict(argument_name_to_getter or {}))

    # NOTE: ``scheme``/``scopes`` are intentionally not wrapped in functools.cache:
    # a cache on a method is keyed by ``self`` and keeps every instance alive for
    # the lifetime of the process, while the lookup itself is trivial.
    @property
    def scheme(self) -> SecurityScheme:
        """Security scheme of this requirement (first key of ``scheme_to_scopes``)."""
        return first(self.scheme_to_scopes.keys())

    @property
    def scopes(self) -> frozenset[str]:
        """Scopes required from the scheme (first value of ``scheme_to_scopes``)."""
        return first(self.scheme_to_scopes.values())

    async def evaluate(self, incoming_request: IncomingRequest) -> dict:
        """Evaluate the scheme for the request and compute all configured arguments.

        :param incoming_request: request passed to the scheme's ``evaluate``.
        :return: dict of argument name to the value produced by its getter.
        :raises AuthInfoException: if a getter raises ``KeyError``, ``TypeError`` or
            ``ValueError``; the original exception is chained as the cause.
        """
        # Scheme evaluation is outside the ``try``: its errors propagate unchanged.
        auth_info = await self.scheme.evaluate(incoming_request, self.scopes)
        try:
            # Getters may return plain values or awaitables (see the field annotation);
            # presumably ``acall`` normalizes both - TODO confirm against async_tools.
            return {name: await acall(getter(auth_info)) for name, getter in self.argument_name_to_getter.items()}
        except (KeyError, TypeError, ValueError) as e:
            raise AuthInfoException(*e.args) from e

    def use(self, *args: str, **kwargs: type[Literal] | Callable[[AuthInfo], Any]) -> 'MonoSecurity':
        """Alias of :meth:`has`."""
        return self.has(*args, **kwargs)

    def has(self, *args: str, **kwargs: type[Literal] | Callable[[AuthInfo], Any]) -> 'MonoSecurity':
        """Return a new ``MonoSecurity`` with extra scopes and argument getters.

        :param args: scopes added to the current ones.
        :param kwargs: argument name mapped either to ``Literal[value]`` (the argument
            always resolves to ``value``) or to a callable receiving ``AuthInfo``.
            Entries override existing getters with the same name.
        :return: a new instance; ``self`` is left untouched.
        :raises TypeError: if a ``Literal`` carries more than one value or a keyword
            value is neither a ``Literal`` nor a callable.
        """
        argument_name_to_getter = dict(self.argument_name_to_getter)
        for key, value in kwargs.items():
            if get_origin(value) is Literal:
                literal_values = get_args(value)
                raise_if(len(literal_values) != 1,
                         TypeError(f'Only single literal values allowed, got: {literal_values!r}'))
                # Bind the constant through a default argument: a plain closure would
                # capture the loop variable and every literal getter would return the
                # value of the last Literal processed.
                argument_name_to_getter[key] = lambda _, constant=literal_values[0]: constant
            elif callable(value):
                argument_name_to_getter[key] = value
            else:
                raise TypeError(f'Got unexpected value: {value}, Allowed: Literal/Callable[[AuthInfo], Any]')
        return MonoSecurity(scheme_to_scopes={self.scheme: self.scopes | set(args)},
                            argument_name_to_getter=argument_name_to_getter)

    def _repr(self) -> list[str]:
        """Return the ``name=value`` parts used by ``__repr__``.

        Extends the parts produced by ``SecurityRequirement._repr`` with the getters
        mapping, which is omitted when empty.
        """
        parts = SecurityRequirement._repr(self)
        parts += [f'argument_name_to_getter={dict(self.argument_name_to_getter)}'] if self.argument_name_to_getter else []
        return parts

    def __repr__(self):
        """Render as ``ClassName(part, part, ...)`` using :meth:`_repr`."""
        return f'{self.__class__.__name__}({", ".join(self._repr())})'

    __str__ = __repr__
