import logging
from dataclasses import dataclass

import sqlalchemy
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine

from async_tools import AsyncInitable, AsyncDeinitable


from .abstract_database_connector import AbstractDatabaseConnector
from ..entity_helpers.fields import get_value_tuple_by_columns, get_primary_key_columns
from ..entity_helpers.sqlalchemy_base import sqlalchemy_mapper_registry


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class DatabaseConnector(AbstractDatabaseConnector, AsyncInitable, AsyncDeinitable):
    """Owns an async SQLAlchemy engine built from a :class:`Config`.

    The engine is created synchronously in ``__init__``. ``_async_init``
    optionally creates the tables (and inserts missing initial values) and
    ``_async_deinit`` disposes the engine.
    """

    @dataclass
    class Config:
        """Connection settings passed through to ``create_async_engine``."""

        address: str
        auto_table_creation_enabled: bool = True
        pool_size: int = 5  # Pool size is the maximum number of permanent connections to keep
        max_overflow: int = 3  # Temporarily exceeds the set pool_size if no connections are available
        pool_recycle: int = -1  # Maximum number of seconds a connection can persist. It defaults to -1 (no timeout).
        pool_timeout: int = 30  # Maximum number of seconds to wait when retrieving a new connection from the pool
        echo: bool = False  # log all executed sql commands
        pool_pre_ping: bool = True  # forwarded as-is to create_async_engine(pool_pre_ping=...)

    # Plain scheme name -> scheme with the async driver this connector uses.
    _ASYNC_DRIVER_BY_SCHEME = {
        "sqlite": "sqlite+aiosqlite",
        'postgresql': 'postgresql+asyncpg',
    }

    def __init__(self, config: Config):
        """Store *config* and create the async engine for ``config.address``.

        A leading ``sqlite://`` / ``postgresql://`` scheme is rewritten to the
        matching async driver scheme before the engine is created.
        """
        AsyncInitable.__init__(self)
        AsyncDeinitable.__init__(self)
        self._config = config
        logger.info("%s inited", type(self).__name__)

        # Never write credentials to the log: render URLs with the password masked.
        logger.info('database address: %s', self._hide_password(self._config.address))
        address = self._to_async_url(self._config.address)
        # Log the *prepared* URL (the original logged the unmodified config address here).
        logger.debug('prepared database connection url: %s', self._hide_password(address))

        self.engine = create_async_engine(
            address,
            echo=self._config.echo,
            pool_size=self._config.pool_size,
            max_overflow=self._config.max_overflow,
            pool_recycle=self._config.pool_recycle,
            pool_timeout=self._config.pool_timeout,
            pool_pre_ping=self._config.pool_pre_ping,
        )

    @classmethod
    def _to_async_url(cls, address: str) -> str:
        """Return *address* with a known plain scheme replaced by its async-driver scheme.

        Only the leading ``<scheme>://`` is rewritten; any other occurrence of
        that text inside the URL (path, password, query) is left untouched.
        Addresses with an unknown or already driver-qualified scheme are
        returned unchanged.
        """
        for scheme, async_scheme in cls._ASYNC_DRIVER_BY_SCHEME.items():
            if address.startswith(f'{scheme}://'):
                return async_scheme + address[len(scheme):]
        return address

    @staticmethod
    def _hide_password(address: str) -> str:
        """Return *address* rendered for logging with the password masked."""
        try:
            return sqlalchemy.engine.make_url(address).render_as_string(hide_password=True)
        except sqlalchemy.exc.ArgumentError:
            # Do not fall back to the raw string: it may contain credentials.
            return '<unparsable database url>'

    async def _async_init(self):
        """Create tables and initial values when auto table creation is enabled."""
        if self._config.auto_table_creation_enabled:
            await self.prepare_db(sqlalchemy_mapper_registry)

    async def _async_deinit(self):
        """Dispose the engine."""
        await self.engine.dispose()

    def get_session(self) -> AsyncSession:
        """Return a new ``AsyncSession`` bound to the engine.

        The session is created with ``expire_on_commit=False`` and
        ``autoflush=False``.
        """
        return AsyncSession(self.engine, expire_on_commit=False, autoflush=False)

    def get_engine(self) -> AsyncEngine:
        """Return the underlying ``AsyncEngine``."""
        return self.engine

    async def prepare_db(self, mapper_registry):
        """Create all tables of *mapper_registry* and insert missing initial values.

        If the registry has an ``initial_values`` attribute (iterated via
        ``.items()`` as entity type -> instances), each instance whose primary
        key is not yet present in the database is added, and everything is
        committed in a single transaction.

        Raises:
            ValueError: if an entity type listed in ``initial_values`` does not
                have exactly one primary key column.
        """
        base = mapper_registry.generate_base()
        async with self.engine.begin() as conn:
            await conn.run_sync(base.metadata.create_all)

        logger.info('Database created')
        if not hasattr(mapper_registry, 'initial_values'):
            return

        total_initial_values_amount = added_initial_values_amount = 0
        async with self.get_session() as session:
            for type_, instances in mapper_registry.initial_values.items():
                total_initial_values_amount += len(instances)
                primary_key_columns = get_primary_key_columns(type_)
                required_primary_keys = {get_value_tuple_by_columns(inst, primary_key_columns) for inst in instances}
                logger.debug('preparing initial values for %s, required_primary_keys: %s',
                             type_, required_primary_keys)

                if len(primary_key_columns) != 1:
                    raise ValueError('initial_values preparation failed: complex primary keys are not supported - '
                                     'asyncpg does not support IN syntax')
                pk_column = primary_key_columns[0]
                request = sqlalchemy.select(pk_column).filter(
                    pk_column.in_([keys[0] for keys in required_primary_keys])
                )
                rows = (await session.execute(request)).all()
                # Convert result rows to plain tuples so the set difference below does not
                # depend on Row/tuple hash and equality interop.
                # Assumes get_value_tuple_by_columns returns tuples - TODO confirm.
                existent_primary_keys = {tuple(row) for row in rows}
                logger.debug('preparing initial values for %s, existent_primary_keys: %s',
                             type_, existent_primary_keys)
                missing_primary_keys = required_primary_keys - existent_primary_keys
                logger.debug('preparing initial values for %s, missing_primary_keys: %s',
                             type_, missing_primary_keys)

                session.add_all([
                    inst for inst in instances
                    if get_value_tuple_by_columns(inst, primary_key_columns) in missing_primary_keys
                ])
                added_initial_values_amount += len(missing_primary_keys)
            await session.commit()
        if total_initial_values_amount:
            logger.info('Initial values checked: %s, was missing: %s',
                        total_initial_values_amount, added_initial_values_amount)


