#  Copyright (C) 2024
#  ABM JSC, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Albakov Ruslan <r.albakov@abm-jsc.ru>
import asyncio
import logging
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import AsyncIterator, Hashable

from .redis_connector import RedisConnector
from .abstract_semaphore import AbstractSemaphore

# Logger named after this module so its output can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)

class RedisSemaphore(AbstractSemaphore):
    """Distributed counting semaphore backed by a Redis hash.

    Each holder of a slot is represented by one field in the Redis hash named
    ``str(key)``. A caller registers its field; if the hash then holds more
    than ``concurrency_limit`` fields, it withdraws the field and retries after
    ``Config.sleep_time_s`` seconds. Leaving ``restrict`` deletes the field.

    NOTE(review): waiters are not queued -- they all poll at the same fixed
    interval, so there is no fairness guarantee, and callers retrying at the
    same moment can keep rejecting each other. Fields of crashed holders only
    disappear if ``hset_and_get_hlen`` expires them via its ``seconds``
    argument -- TODO confirm against ``RedisConnector``.
    """

    @dataclass
    class Context:
        """Runtime dependencies of the semaphore."""

        # Connector used for every hash operation (set+count, delete).
        redis: RedisConnector

    @dataclass
    class Config:
        """Tunable parameters of the semaphore."""

        # Passed as ``seconds`` to ``RedisConnector.hset_and_get_hlen``;
        # presumably the expiry applied to the hash -- TODO confirm.
        key_time_to_live_s: int = 120
        # Delay between two acquisition attempts while all slots are taken.
        sleep_time_s: float = 3

    def __init__(self, config: Config, context: Context) -> None:
        self.config = config
        self.context = context

    @asynccontextmanager
    async def restrict(
        self, key: Hashable, concurrency_limit: int = 1
    ) -> AsyncIterator[None]:
        """
        Limit the number of simultaneous holders of ``key`` to ``concurrency_limit``.

        On entry a field unique to this call is added to the Redis hash named
        ``str(key)``. If that exceeds the limit, the caller waits, polling every
        ``Config.sleep_time_s`` seconds, until a slot is free. On exit (normal
        or via an exception) the field is deleted, releasing the slot.

        :param key: Hashable, identifies the shared resource; ``str(key)`` is
            used as the name of the Redis hash.
        :param concurrency_limit: int, maximum number of simultaneous holders
            (must be >= 1).
        :return: None; use as ``async with semaphore.restrict(key): ...``.
        :raises ValueError: if ``concurrency_limit`` is less than 1. Such a
            limit can never be satisfied and would otherwise wait forever.
        """
        if concurrency_limit < 1:
            raise ValueError(f"concurrency_limit must be >= 1, got {concurrency_limit}")
        redis_hash = str(key)
        # A timestamp alone is not unique. Two holders in the same clock tick
        # would share one field, so the hash would undercount them and one
        # holder's release would free the other's slot. The uuid suffix rules
        # that out; the timestamp is kept to keep keys readable.
        redis_key = f'{redis_hash}_{time.time()}_{uuid.uuid4().hex}'
        try:
            await self.acquire(redis_hash, redis_key, concurrency_limit)
        except BaseException:
            # Acquisition failed or was cancelled (e.g. asyncio.CancelledError),
            # possibly right after our field was written. Remove the field on
            # a best-effort basis so it does not keep occupying a slot, then
            # propagate the original error rather than any cleanup error.
            try:
                await self.release(redis_hash, redis_key)
            except Exception:
                logger.warning(
                    "%s cleanup after failed acquire failed", redis_key, exc_info=True
                )
            raise
        logger.debug("%s connection acquired", redis_key)
        try:
            yield
        finally:
            logger.debug("%s connection released", redis_key)
            await self.release(redis_hash, redis_key)

    async def acquire(self, hash_name: str, key: str, concurrency_limit: int) -> bool:
        """Wait until field ``key`` holds a slot in hash ``hash_name``.

        :param hash_name: str, name of the Redis hash guarding the resource.
        :param key: str, field identifying this holder.
        :param concurrency_limit: int, maximum number of fields allowed.
        :return: True once the slot is held. It is False only if ``locked``
            (e.g. in a subclass) reports the resource as still locked.
        """
        return not await self.locked(hash_name, key, concurrency_limit)

    async def release(self, hash_name: str, key: str) -> None:
        """Delete field ``key`` from hash ``hash_name``, freeing its slot."""
        await self.context.redis.hdel(hash_name, key)

    async def locked(self, hash_name: str, key: str, concurrency_limit: int) -> bool:
        """Block until ``key`` fits into ``hash_name`` without exceeding the limit.

        Despite its name, this coroutine waits. While registering ``key``
        pushes the field count above ``concurrency_limit``, it removes the
        field again and retries after ``Config.sleep_time_s`` seconds. When it
        returns, the field is present in the hash, i.e. the slot is held.

        :return: always False (the resource is no longer locked for this caller).
        """
        while await self._set_and_check_count(hash_name, key, concurrency_limit):
            # Withdraw our field so that waiting does not itself consume a slot.
            await self.context.redis.hdel(hash_name, key)
            await asyncio.sleep(self.config.sleep_time_s)
        return False

    async def _set_and_check_count(
        self, hash_name: str, key: str, concurrency_limit: int
    ) -> bool:
        """Register ``key`` in ``hash_name`` and report whether the limit is exceeded.

        :return: True if the field count reported by ``hset_and_get_hlen``
            (presumably HLEN after the HSET) is greater than ``concurrency_limit``.
        """
        field_count = await self.context.redis.hset_and_get_hlen(
            hash_name, key, seconds=self.config.key_time_to_live_s
        )
        return field_count > concurrency_limit
