#  Copyright (C) 2025
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mikhail Mamrov <m.mamrov@abm-jsc.ru>
#
import logging
from dataclasses import dataclass
from typing import ClassVar, TypeAlias
import jwt
from jwt import DecodeError, ExpiredSignatureError

from init_helpers.dict_to_dataclass import dict_to_dataclass

from auth_server_connector.auth_server_connector import AuthServerConnector, JWTPayload, TokenType

logger = logging.getLogger(__name__)


class InvalidTokenError(Exception):
    """Raised by ``TokenValidator`` when an access token is rejected.

    Covers a missing token, a token that cannot be decoded, a token of the
    wrong type, and a token that belongs to a different origin.
    """


class ExpiredTokenError(Exception):
    """Raised by ``TokenValidator`` when an access token's signature has expired."""


class TokenValidator:
    """Validate RS256-signed access tokens against the auth server's public key.

    The public key is requested from ``Context.auth_server_connector`` on the
    first validation and cached on the instance for subsequent calls.
    """

    _SIGNATURE_ALGORITHM: ClassVar[str] = "RS256"

    @dataclass
    class Context:
        # Provides ``get_public_key()``, used to verify token signatures.
        auth_server_connector: AuthServerConnector

    def __init__(self, context: Context) -> None:
        self.context = context
        # Filled lazily by validate_access_token and reused afterwards.
        # NOTE(review): nothing in this class ever resets it, so a rotated
        # auth-server key is not picked up until a new instance is created —
        # confirm this is acceptable.
        self._public_key = None
        logger.info(f"{type(self).__name__} inited")

    async def validate_access_token(self, token: str | None, origin: str) -> JWTPayload:
        """Decode ``token`` and check that it is an access token for ``origin``.

        Args:
            token: The encoded JWT. ``None`` or an empty string is rejected
                before the public key is fetched.
            origin: The portal origin the token is required to belong to.

        Returns:
            The decoded claims converted to ``JWTPayload``.

        Raises:
            ExpiredTokenError: PyJWT reports the token as expired.
            InvalidTokenError: The token is missing, cannot be decoded, fails
                any other PyJWT validation check, is not an access token, or
                belongs to a different origin.
        """
        # An empty token can never be valid; don't contact the auth server for it.
        if not token:
            raise InvalidTokenError("No access token")

        if self._public_key is None:
            self._public_key = await self.context.auth_server_connector.get_public_key()
        # Lazy %-formatting: the key is only rendered when DEBUG is enabled.
        logger.debug("Decoding jwt using public key: %s", self._public_key)
        try:
            payload = jwt.decode(token, self._public_key, algorithms=[self._SIGNATURE_ALGORITHM])
        except DecodeError as er:
            raise InvalidTokenError(f"Invalid token, parent Exception: {repr(er)}") from er
        except ExpiredSignatureError as er:
            raise ExpiredTokenError(f"{TokenType.ACCESS_TOKEN} lifetime ended") from er
        except jwt.InvalidTokenError as er:
            # PyJWT's base class for the remaining token failures
            # (ImmatureSignatureError, InvalidAudienceError, InvalidIssuedAtError,
            # MissingRequiredClaimError, InvalidAlgorithmError, ...). Without this
            # clause they leak out as third-party exceptions that callers catching
            # this module's InvalidTokenError would miss. Must stay after the
            # ExpiredSignatureError clause, which is also a subclass of it.
            raise InvalidTokenError(f"Invalid token, parent Exception: {repr(er)}") from er

        jwt_payload = dict_to_dataclass(payload, JWTPayload)

        if jwt_payload.type != TokenType.ACCESS_TOKEN:
            raise InvalidTokenError(f"Wrong token type, expected {TokenType.ACCESS_TOKEN}, got {jwt_payload.type}")

        if jwt_payload.portal.origin != origin:
            raise InvalidTokenError(f"Token doesn't belong to origin == {origin}")

        return jwt_payload
