#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_tools.database_connector.abstract_database_connector import AbstractDatabaseConnector


# Session bound to the current execution context, or ``None`` when no session
# is active. ``DatabaseSessionMaker.ensure_session`` sets it for the lifetime of
# its outermost block (and resets it afterwards) so nested calls reuse the
# same session.
CONTEXT_VAR_SESSION: ContextVar[AsyncSession | None] = ContextVar(
    'context_var_session',
    default=None,
)


class DatabaseSessionMaker:
    """Provide async database sessions, reusing one already bound to the context.

    The outermost ``ensure_session`` block opens a session and publishes it in
    ``CONTEXT_VAR_SESSION``. Any ``ensure_session`` block nested inside it (in
    the same execution context) reuses that session instead of opening a new one.
    """

    @dataclass
    class Context:
        """Dependencies required by :class:`DatabaseSessionMaker`."""

        # Source of new sessions; ``get_session()`` is used below as an async
        # context manager that yields the session.
        database_connector: AbstractDatabaseConnector

    def __init__(self, context: Context) -> None:
        self.context = context

    @asynccontextmanager
    async def ensure_session(self) -> AsyncIterator[AsyncSession]:
        """Yield the session bound to the current context, opening one if needed.

        Nested use: if ``CONTEXT_VAR_SESSION`` already holds a session, it is
        yielded as-is. It is neither committed nor closed here; the block that
        opened it remains responsible for that.

        Outermost use: a new session is obtained from
        ``database_connector.get_session()`` and published in
        ``CONTEXT_VAR_SESSION`` for the duration of the block. When the block
        exits normally, the session is committed. If the block raises, the
        commit is skipped and the exception propagates; what then happens to
        the uncommitted work is up to ``get_session()`` (not visible here). The
        context variable is restored on every exit path.

        NOTE(review): asyncio tasks created inside the block (e.g. via
        ``asyncio.create_task``/``gather``) receive a copy of the current context
        and will therefore reuse this same session. Confirm callers never use it
        from several tasks concurrently.

        Yields:
            The active ``AsyncSession``.
        """
        existing_session = CONTEXT_VAR_SESSION.get()
        if existing_session is not None:
            # Nested call: the outer block owns commit and cleanup.
            yield existing_session
            return
        async with self.context.database_connector.get_session() as session:
            token = CONTEXT_VAR_SESSION.set(session)
            try:
                yield session
                # Only reached when the caller's block finished without raising.
                await session.commit()
            finally:
                # Reset while still inside ``async with`` so the variable never
                # points at a session whose context manager has already exited.
                CONTEXT_VAR_SESSION.reset(token)
