#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import asyncio
import datetime
import logging
from dataclasses import dataclass, field
from typing import Callable, Awaitable, Iterable, ClassVar, Generic, Mapping

import aiohttp
from async_tools import Periodic
from authlib.jose import JsonWebToken, JWTClaims, JoseError, JsonWebKey, KeySet
from frozendict import frozendict
from http_tools import IncomingRequest
from init_helpers import raise_if, Jsonable, url_to_snake_case

from openapi_tools.spec import SpecResource, SpecRef
from ..exceptions import Unauthorized
from .security_scheme import SecurityScheme, AuthInfo, AuthToken
from .security_scheme_type import SecuritySchemeType

# Module-level logger named after this module; used for the debug/warning messages emitted below.
logger = logging.getLogger(__name__)


@dataclass(frozen=True, kw_only=True)
class OpenIdConnectSecurityScheme(SecurityScheme, Generic[AuthInfo]):
    """OpenID Connect security scheme that authenticates requests by a JWT from the Authorization header.

    ``url`` must point to the provider's discovery document (it has to contain ``/.well-known``).
    The part of ``url`` preceding ``/.well-known`` is used as the expected ``iss`` claim.
    The ``jwks_uri`` entry of the discovery document is requested to load the public keys
    that incoming tokens are verified against.

    The instance is frozen, so every derived attribute is stored through ``object.__setattr__``.
    """
    type_: ClassVar[SecuritySchemeType] = SecuritySchemeType.open_id_connect
    # Authorization header scheme expected in front of the token.
    header_scheme: ClassVar[str] = 'Bearer'
    # Period handed to ``Periodic`` for re-loading the provider's public keys.
    jwks_refresh_period: ClassVar[datetime.timedelta] = datetime.timedelta(seconds=60)
    # Total timeout (seconds) applied to every discovery / JWKS HTTP request.
    jwks_request_timeout_sec: ClassVar[float] = 10
    # Decoder restricted to RS256-signed tokens.
    jwt: ClassVar[JsonWebToken] = JsonWebToken(['RS256'])

    # URL of the OpenID Connect discovery document.
    url: str
    # HTTP session used for discovery/JWKS requests; created in __post_init__ when not supplied.
    session: aiohttp.ClientSession | None = None
    # The attributes below are derived in __post_init__ (or later) and never passed by the caller.
    resolver: Callable[[AuthToken], Awaitable[tuple[AuthInfo, Iterable[str]]] | tuple[AuthInfo, Iterable[str]]] = (
        field(init=False, repr=False, compare=False))
    expected_issuer: str = field(init=False, repr=False, compare=False)
    periodic_refresh_public_keys: Periodic = field(init=False, repr=False, compare=False)
    # Set by _refresh_public_keys; read through get_jwks().
    jwks: KeySet = field(init=False, repr=False, compare=False)
    # Set once the keys have been loaded for the first time.
    jwks_loaded_event: asyncio.Event = field(init=False, default_factory=asyncio.Event, repr=False, compare=False)

    def __post_init__(self):
        """Derive the expected issuer from ``url`` and set up the key refresher, the session and the resolver.

        A ``ValueError`` is passed to ``raise_if`` when ``url`` does not contain ``/.well-known``.
        """
        well_known_index = self.url.find('/.well-known')
        raise_if(well_known_index == -1, ValueError('OpenIdConnectSecurityScheme URL requires ".well-known" part'))
        object.__setattr__(self, 'expected_issuer', self.url[:well_known_index])
        # NOTE(review): the Periodic is only constructed here; whether it starts on its own
        # or has to be started by the owner of the scheme is not visible from this file.
        object.__setattr__(self, 'periodic_refresh_public_keys', Periodic(
            self._refresh_public_keys, self.jwks_refresh_period, first_at=datetime.datetime.now()))

        if self.session is None:
            logger.warning('OpenIdConnectSecurityScheme cannot await session.close(), '
                           'so do it manually or pass the session controlled from outer scope')
            object.__setattr__(self, 'session', aiohttp.ClientSession())

        async def resolver(jwt_value: str) -> tuple[AuthInfo, frozenset[str]]:
            """Decode and validate the token; return its claims together with the granted scopes."""
            try:
                claims: JWTClaims = self.jwt.decode(jwt_value, await self.get_jwks())
                self._check_claims(claims)
            except (JoseError, ValueError) as e:
                # Hide the decoding internals from the caller: only the message is forwarded.
                raise Unauthorized(str(e)) from None

            return claims, self._get_scopes(claims)

        object.__setattr__(self, 'resolver', resolver)

    @staticmethod
    def _get_scopes(claims: JWTClaims) -> frozenset[str]:
        """Return the set of scopes listed in the ``scope`` claim.

        Raises:
            Unauthorized: if the claim is missing, is not a string or does not contain ``openid``.
        """
        if (scope := claims.get('scope')) is None:
            raise Unauthorized('BadJWT: Missing scope')
        if not isinstance(scope, str):
            raise Unauthorized(f'BadJWT: Scope is not string: {type(scope).__name__}')
        # split() without a separator ignores repeated/leading/trailing spaces,
        # so no empty-string "scope" ends up in the result.
        scopes = frozenset(scope.split())
        if 'openid' not in scopes:
            raise Unauthorized(f'BadJWT: Scope MUST contain openid, got: {scope!r}')

        return scopes

    def _check_claims(self, claims: JWTClaims) -> None:
        """Run the claims' own validation and verify that ``iss`` equals the expected issuer.

        Raises:
            Unauthorized: if the ``iss`` claim differs from ``expected_issuer``.
        """
        claims.validate()
        if (issuer := claims.get('iss')) != self.expected_issuer:
            raise Unauthorized(f'BadJWT: Invalid {issuer=}, expected: {self.expected_issuer!r}')
        # TODO: add "aud" check
        # if 'your-client-id' not in (audience := claims.get('aud', [])):
        #     raise Unauthorized(f'BadJWT: Invalid {audience=}')

    async def _extract_token(self, incoming_request: IncomingRequest) -> AuthToken:
        """Return the token carried by the ``Authorization`` header of the request.

        The header must look like ``<scheme> <token>`` where ``<scheme>`` equals ``header_scheme``
        (compared case-insensitively).

        Raises:
            Unauthorized: if the header is absent, uses another scheme or carries no token.
        """
        header_value = incoming_request.metadata.header_name_to_value.get("Authorization")
        if not isinstance(header_value, str):
            raise Unauthorized("Failed to get token from Authorization header")

        # Split into scheme and credentials so that e.g. "BearerXYZ" is not mistaken for scheme "Bearer".
        parts = header_value.split(maxsplit=1)
        if not parts or parts[0].lower() != self.header_scheme.lower():
            raise Unauthorized("Wrong authorization scheme")
        if len(parts) < 2:
            raise Unauthorized("Failed to get token from Authorization header")
        return parts[1].strip()

    def get_key(self) -> str:
        """Return the scheme key: ``openid__`` followed by the snake-cased expected issuer."""
        return f'openid__{url_to_snake_case(self.expected_issuer)}'

    def get_spec_dict(self, dependency_to_ref: Mapping[SpecResource, SpecRef]) -> frozendict[str, Jsonable]:
        """Return the base scheme spec extended with the ``openIdConnectUrl`` entry."""
        return SecurityScheme.get_spec_dict(self, dependency_to_ref=dependency_to_ref) | {'openIdConnectUrl': self.url}

    async def _fetch_json(self, url: str) -> dict:
        """GET ``url`` with the configured timeout and return the decoded JSON body.

        An HTTP error status is raised via ``raise_for_status`` instead of being parsed as a payload.
        """
        timeout = aiohttp.ClientTimeout(total=self.jwks_request_timeout_sec)
        async with self.session.get(url, timeout=timeout) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def get_public_keys_url(self) -> str:
        """Return the ``jwks_uri`` from the discovery document, requesting it on first use only.

        Raises:
            RuntimeError: if the discovery document cannot be fetched or has no ``jwks_uri``.
        """
        # The value is cached as an ad-hoc attribute, hence the hasattr()/getattr() pair.
        if not hasattr(self, 'public_keys_url'):
            logger.debug(f'get_public_keys_url from {self.url}')
            try:
                configuration = await self._fetch_json(self.url)
                object.__setattr__(self, 'public_keys_url', configuration['jwks_uri'])
            except Exception as e:
                raise RuntimeError(f"Failed to get public keys url: {e!r}") from e

        return getattr(self, 'public_keys_url')

    async def _load_public_keys(self) -> list[dict]:
        """Fetch the JWKS document and return the content of its ``keys`` entry."""
        public_keys_url = await self.get_public_keys_url()
        logger.debug(f'_load_public_keys from {public_keys_url}')
        jwks = await self._fetch_json(public_keys_url)
        return jwks['keys']

    async def _refresh_public_keys(self) -> None:
        """Reload the key set and signal that keys are available."""
        logger.debug('_refresh_public_keys')
        object.__setattr__(self, 'jwks', JsonWebKey.import_key_set({'keys': await self._load_public_keys()}))
        self.jwks_loaded_event.set()

    async def get_jwks(self) -> KeySet:
        """Return the current key set, waiting (without a timeout) until it has been loaded at least once."""
        await self.jwks_loaded_event.wait()
        return self.jwks
