#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from itertools import chain
from typing import Hashable, Iterable

from dict_caster.extras import to_set

from ..tools.lock_wrapper import LockWrapper


class KeyLock:
    """Async mutual exclusion keyed by arbitrary hashable values.

    ``restrict`` holds the per-key ``LockWrapper`` for every requested key for
    the duration of an ``async with`` body. A lock is created lazily the first
    time a key is requested and removed from ``key_to_lock`` once its ``depth``
    drops to zero.
    """

    def __init__(self) -> None:
        # Per-key locks; an entry is kept only while its lock's depth is non-zero.
        self.key_to_lock: dict[Hashable, LockWrapper] = {}

    @asynccontextmanager
    async def restrict(
        self, keys: Iterable[Hashable] | Hashable
    ) -> AsyncIterator[None]:
        """Hold the locks for all of ``keys`` while the ``async with`` body runs.

        :param keys: a single hashable key or an iterable of keys; normalised
            with ``to_set`` so each distinct key is locked once.

        Keys with no entry in ``key_to_lock`` are acquired first, then keys
        already in use. Every lock acquired here is released on exit, including
        when an ``acquire`` is cancelled or raises partway through; in that case
        the exception propagates after the cleanup.
        """
        keys = to_set(keys)
        locked_keys = []
        not_locked_keys = []
        for key in keys:
            (locked_keys if key in self.key_to_lock else not_locked_keys).append(key)

        # NOTE(review): the acquisition order depends on set iteration, not on a
        # global ordering of keys, so overlapping multi-key restrict() calls are
        # not provably deadlock-free — confirm this is acceptable for callers.
        acquired: list[tuple[Hashable, LockWrapper]] = []
        try:
            for key in chain(not_locked_keys, locked_keys):
                lock = self.key_to_lock.setdefault(key, LockWrapper())
                try:
                    await lock.acquire()
                except BaseException:
                    # acquire() did not complete (e.g. the task was cancelled
                    # while waiting): don't leave an unused entry behind.
                    self._discard_if_idle(key, lock)
                    raise
                acquired.append((key, lock))
            yield
        finally:
            # Release exactly the locks this call acquired (the acquisition
            # loop may have been interrupted before reaching every key).
            for key, lock in reversed(acquired):
                lock.release()
                self._discard_if_idle(key, lock)

    def _discard_if_idle(self, key: Hashable, lock: "LockWrapper") -> None:
        """Remove ``key`` if it still maps to ``lock`` and ``lock.depth`` is 0."""
        if not lock.depth and self.key_to_lock.get(key) is lock:
            del self.key_to_lock[key]
