#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import abc
import contextvars
import itertools
import typing
from dataclasses import dataclass, field
from enum import StrEnum
from functools import cached_property
from types import MappingProxyType
from typing import Callable, Any, Iterable, TypeVar, Awaitable, TypeAlias, Literal, Tuple

from async_tools import acall
from dict_caster.extras import to_list
from http_tools import IncomingRequest

from .extras import raise_if, DataclassProtocol
from .parameter import ParameterLocation, QueryParameter, PathParameter, HeaderParameter, SpecParameter

# Raw credential string that a SecurityScheme extracts from a request and hands to its resolver.
AuthToken: TypeAlias = str
# Type of the object a resolver returns to describe the authenticated principal.
AuthInfo = TypeVar("AuthInfo", bound=dict | DataclassProtocol)

# Per-context memo of resolver outcomes (result tuple or raised exception), keyed by SecurityScheme;
# read and filled by SecurityScheme.evaluate.
# NOTE(review): never set in this module; presumably installed per request by the caller — confirm.
resolver_cache: contextvars.ContextVar[dict] = contextvars.ContextVar("resolver_cache")


class Unauthorized(Exception):
    """Raised when the credentials cannot be taken from the request (missing or malformed)."""


class Forbidden(Exception):
    """Raised when authenticated credentials lack some of the required scopes."""


class AuthInfoException(Exception):
    """Raised when a getter fails to derive an endpoint argument from the resolved auth info."""


class SecuritySchemeType(StrEnum):
    api_key = "apiKey"
    http = "http"
    oauth2 = "oauth2"
    open_id_connect = "openIdConnect"
    mutual_tls = "mutualTLS"


@dataclass(frozen=True)
class SecurityScheme(abc.ABC, typing.Generic[AuthInfo]):
    """Base class for OpenAPI security schemes.

    Subclasses define how the token is taken from a request (``_extract_token``), how the scheme is
    rendered into the spec (``to_spec``) and its ``key``. ``resolver`` turns the token into
    ``(auth_info, allowed_scopes)``, returned either directly or as an awaitable.
    """
    type_: SecuritySchemeType
    resolver: Callable[[AuthToken], Awaitable[Tuple[AuthInfo, Iterable[str]]] | Tuple[AuthInfo, Iterable[str]]]
    # NOTE(review): not read in this module; presumably consulted by a logging layer — confirm.
    do_log: bool = field(default=True, kw_only=True)

    @abc.abstractmethod
    def to_spec(self) -> dict:
        """Return the OpenAPI security scheme object; subclasses extend this base ``{'type': ...}`` dict."""
        return {'type': self.type_}

    @abc.abstractmethod
    def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Return the raw token carried by ``incoming_request``."""

    async def evaluate(self, incoming_request: IncomingRequest, required_scopes: frozenset[str]) -> AuthInfo:
        """Authenticate ``incoming_request`` and check that it is granted ``required_scopes``.

        The resolver outcome (result or raised exception) is memoized in ``resolver_cache`` under this
        scheme, so repeated evaluations of the same scheme in the same context resolve the token once.

        :returns: the auth info produced by ``resolver``.
        :raises Forbidden: if any of ``required_scopes`` is not among the allowed scopes.
        :raises Exception: whatever ``_extract_token`` or the resolver raised (resolver errors are
            re-raised from the cache on later calls).
        """
        # Without a cache installed in the current context a throwaway dict is used, i.e. no memoization.
        cache = resolver_cache.get({})
        if self not in cache:
            auth_token = self._extract_token(incoming_request)
            try:
                cache[self] = await acall(self.resolver(auth_token))
            except Exception as e:
                # Failures are cached too, so a rejected token is not re-resolved within the same context.
                cache[self] = e
        result = cache[self]
        if isinstance(result, Exception):
            raise result
        auth_info, allowed_scopes = result
        # frozenset.difference accepts any iterable, so allowed_scopes needs no conversion to a set.
        missing_scopes = required_scopes.difference(allowed_scopes)
        # Sorted so the error message is deterministic.
        raise_if(missing_scopes, Forbidden(f'missing: {",".join(sorted(missing_scopes))}'))
        return auth_info

    @property
    @abc.abstractmethod
    def key(self) -> str:
        """Identifier of this scheme; subclasses derive it from their configuration."""

    def has(self, *args: str, **kwargs: type[Literal] | Callable[[AuthInfo], Any]) -> 'Security':
        """Shortcut for ``Security(self, scopes=args).use(**kwargs)``."""
        return Security(self, scopes=args).use(**kwargs)


@dataclass(frozen=True)
class ApiKeySecurityScheme(SecurityScheme, typing.Generic[AuthInfo]):
    """OpenAPI ``apiKey`` security scheme: the token is the value of one named request parameter."""
    type_: SecuritySchemeType = field(init=False, default_factory=lambda: SecuritySchemeType.api_key)
    location: ParameterLocation
    name: str

    def __post_init__(self):
        # Build the parameter eagerly so an unsupported location fails at construction time.
        # Plain access instead of `assert`: `python -O` strips asserts together with the check.
        _ = self._parameter

    def to_spec(self) -> dict:
        return SecurityScheme.to_spec(self) | {'in': self.location, 'name': self.name}

    @cached_property
    def key(self) -> str:
        return f'{self.type_.name}_in_{self.location.name}_{self.name}'

    @cached_property
    def _parameter(self) -> SpecParameter:
        """Return the string-typed spec parameter the key is read from.

        :raises NotImplementedError: for locations other than query, path and header.
        """
        if self.location == ParameterLocation.query:
            return QueryParameter(name=self.name, schema=str)
        if self.location == ParameterLocation.path:
            return PathParameter(name=self.name, schema=str)
        if self.location == ParameterLocation.header:
            return HeaderParameter(name=self.name, schema=str)
        # TODO: support the cookie location once a CookieParameter exists:
        #     if self.location == ParameterLocation.cookie:
        #         return CookieParameter(name=self.name, schema=str)
        raise NotImplementedError(f"ApiKeySecurityScheme does not support location {self.location!r} ")

    def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Return the parameter value; a ``KeyError`` from the lookup becomes ``Unauthorized``."""
        try:
            return self._parameter.get(incoming_request)
        except KeyError:
            raise Unauthorized from None


@dataclass(frozen=True)
class HttpSecurityScheme(SecurityScheme, abc.ABC, typing.Generic[AuthInfo]):
    """OpenAPI ``http`` security scheme reading credentials from the ``Authorization`` header.

    ``header_scheme`` is the expected prefix of the header value including its separator space
    (e.g. ``'Bearer '``); concrete subclasses fix it.
    """
    type_: SecuritySchemeType = field(init=False, default_factory=lambda: SecuritySchemeType.http)
    header_scheme: str

    def to_spec(self) -> dict:
        # The spec wants the bare auth-scheme name, without the separator space kept in header_scheme.
        return SecurityScheme.to_spec(self) | {'scheme': self.header_scheme.strip()}

    @cached_property
    def key(self) -> str:
        # strip() drops the separator space carried by header_scheme (e.g. 'http_Bearer ' -> 'http_Bearer').
        return f'{self.type_.name}_{self.header_scheme}'.strip()

    def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Return the credentials that follow ``header_scheme`` in the ``Authorization`` header.

        :raises Unauthorized: if the header is absent / not a string, or uses another auth scheme.
        """
        # NOTE(review): assumes header_name_to_value handles header-name casing — confirm in http_tools.
        dirty_token: str | None = incoming_request.metadata.header_name_to_value.get("Authorization")
        # A missing header yields None; reject anything that is not a string before using str methods.
        raise_if(not isinstance(dirty_token, str), Unauthorized("Failed to get token from Authorization header"))
        prefix_length = len(self.header_scheme)
        # HTTP auth-scheme names are case-insensitive, so "bearer x" must match 'Bearer '.
        raise_if(dirty_token[:prefix_length].lower() != self.header_scheme.lower(),
                 Unauthorized("Wrong authorization scheme"))
        return dirty_token[prefix_length:]


@dataclass(frozen=True)
class HttpBearerSecurityScheme(HttpSecurityScheme, typing.Generic[AuthInfo]):
    """HTTP ``Bearer`` scheme: the token follows ``'Bearer '`` in the ``Authorization`` header."""
    # Fixed by this subclass, hence excluded from __init__.
    header_scheme: str = field(default='Bearer ', init=False)


@dataclass(frozen=True)
class HttpBasicSecurityScheme(HttpSecurityScheme, typing.Generic[AuthInfo]):
    """HTTP ``Basic`` scheme: the credentials follow ``'Basic '`` in the ``Authorization`` header."""
    # Fixed by this subclass, hence excluded from __init__.
    header_scheme: str = field(default='Basic ', init=False)


@dataclass(frozen=True, slots=True)
class Security:
    """A security requirement: a scheme, the scopes it must grant, and getters for endpoint arguments.

    Instances are immutable; :meth:`use` returns a new ``Security`` with additional getters.
    """
    scheme: SecurityScheme
    scopes: frozenset[str] = field(default_factory=frozenset)
    # Excluded from hashing: a dict-backed MappingProxyType is not hashable.
    argument_name_to_getter: MappingProxyType[str, Callable[[AuthInfo], Awaitable | Any]] = field(default_factory=dict, hash=False)

    def __init__(self, scheme: SecurityScheme, scopes: Iterable[str] = tuple(),
                 argument_name_to_getter: dict[str, Callable[[AuthInfo], Awaitable | Any]] | None = None):
        # Frozen dataclass: fields must be set through object.__setattr__.
        object.__setattr__(self, 'scheme', scheme)
        object.__setattr__(self, 'scopes', frozenset(scopes))
        # Copy before wrapping so later mutation of the caller's dict cannot alter this instance.
        object.__setattr__(self, 'argument_name_to_getter', MappingProxyType(dict(argument_name_to_getter or {})))

    async def evaluate(self, incoming_request: IncomingRequest) -> dict:
        """Authenticate the request via ``scheme`` and compute the getter-provided arguments.

        :returns: mapping of argument name to its getter's result (passed through ``acall``).
        :raises AuthInfoException: if a getter raises ``KeyError``, ``TypeError`` or ``ValueError``.
        """
        auth_info = await self.scheme.evaluate(incoming_request, self.scopes)
        try:
            return {name: await acall(getter(auth_info)) for name, getter in self.argument_name_to_getter.items()}
        except (KeyError, TypeError, ValueError) as e:
            raise AuthInfoException(*e.args) from e

    def use(self, **kwargs: type[Literal] | Callable[[AuthInfo], Any]) -> 'Security':
        """Return a copy of this security with additional argument getters.

        Each keyword maps an argument name either to a single-value ``Literal[...]`` (the argument
        always receives that value) or to a callable applied to the resolved auth info. Getters
        already registered under the same name are replaced.

        :raises TypeError: for a multi-value ``Literal`` or a value that is neither ``Literal`` nor callable.
        """
        argument_name_to_getter = self.argument_name_to_getter.copy()
        for key, value in kwargs.items():
            # Checked before callable(): Literal[...] aliases are themselves callable.
            if typing.get_origin(value) is Literal:
                literal_args = typing.get_args(value)
                raise_if(len(literal_args) != 1, TypeError(f'Only single literal values allowed, got: {literal_args!r}'))
                # Bind the constant as a default argument: a closure over the loop variable would be
                # late-bound, making every Literal getter return the value of the last one.
                argument_name_to_getter[key] = lambda _, constant=literal_args[0]: constant
            elif callable(value):
                argument_name_to_getter[key] = value
            else:
                raise TypeError(f'Got unexpected value: {value}, Allowed: jmespath/Literal/Callable[[AuthInfo], Any]')
        return Security(scheme=self.scheme, scopes=self.scopes, argument_name_to_getter=argument_name_to_getter)


def get_securities_arg_names(securities: list[Security | SecurityScheme]) -> set[str]:
    """Return the set of argument names provided by every entry of ``securities``.

    A bare ``SecurityScheme`` provides no arguments. All entries must provide the same names,
    otherwise ``ValueError`` is raised. An empty input yields an empty set.
    """
    expected: set[str] | None = None
    for item in securities:
        provided = set(item.argument_name_to_getter) if isinstance(item, Security) else set()
        if expected is None:
            expected = provided
            continue
        raise_if(provided != expected, ValueError(f'All {securities=} must provide same arguments: {provided} != {expected}'))
    return expected or set()
