#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Alexander Medvedev <a.medvedev@abm-jsc.ru>

import asyncio
import logging
from asyncio import Task, CancelledError
from dataclasses import dataclass
from typing import List, Optional

from aio_pika.abc import ExchangeType, AbstractExchange, AbstractRobustChannel
from async_tools import AsyncOnStart, AsyncOnStop

from .rabbitmq_connector import RabbitMQConnector
from .rabbitmq_consumer import RabbitMQConsumer


logger = logging.getLogger(__name__)


class RabbitMQExchanger(AsyncOnStart, AsyncOnStop):
    """Declare a RabbitMQ exchange and run the registered consumers against it.

    On start, a channel is obtained from the connector, the exchange described by
    :class:`Config` is declared, and every registered :class:`RabbitMQConsumer` is
    launched as its own asyncio task. On stop, all consumer tasks are cancelled.

    NOTE(review): ``_on_start`` / ``_on_stop`` are presumably invoked by the
    ``AsyncOnStart`` / ``AsyncOnStop`` base classes (defined elsewhere) — confirm.
    """

    @dataclass
    class Config:
        """Exchange parameters, forwarded as-is to ``channel.declare_exchange``."""
        name: str
        type: ExchangeType
        durable: bool = False
        auto_delete: bool = False
        internal: bool = False
        passive: bool = False
        timeout: Optional[float] = None
        robust: bool = True

    @dataclass
    class Context:
        """External dependencies of the exchanger."""
        rabbitmq_connector: RabbitMQConnector

    def __init__(self, config: Config, context: Context):
        self._config = config
        self._context = context

        self._queue_message_consumers: List[RabbitMQConsumer] = []
        self._running_consumers: List[Task] = []
        # Strong reference to the background task that supervises the consumers.
        # The event loop only keeps weak references to tasks, so an unreferenced
        # task may be garbage-collected before it finishes.
        self._consumers_supervisor: Optional[Task] = None

    def register(self, consumer: RabbitMQConsumer) -> None:
        """Add *consumer* to the set started by ``_on_start``.

        Consumers are launched only when ``_on_start`` runs. A consumer registered
        after that is not started until the next start.
        """
        self._queue_message_consumers.append(consumer)

    async def _declare_exchange(self, channel: AbstractRobustChannel) -> AbstractExchange:
        """Declare the configured exchange on *channel* and return it."""
        exchange = await channel.declare_exchange(
            name=self._config.name, type=self._config.type, durable=self._config.durable,
            auto_delete=self._config.auto_delete, internal=self._config.internal,
            passive=self._config.passive, timeout=self._config.timeout,
            robust=self._config.robust
        )
        logger.info('Exchange successfully declared')
        return exchange

    async def _stop_all_running_consumers(self) -> None:
        """Cancel every tracked consumer task and wait until they have finished."""
        # Detach the list first so tasks registered concurrently are not lost
        # and the same task is never cancelled/awaited twice.
        consumers, self._running_consumers = self._running_consumers, []
        for consumer in consumers:
            logger.info(f'Canceled a consumer: "{consumer}"')
            consumer.cancel()
        # Give the cancelled consumers a chance to run their cleanup. Their
        # outcomes (normally CancelledError) are collected instead of raised.
        await asyncio.gather(*consumers, return_exceptions=True)

    async def _run_queue_message_consumers(self, channel: AbstractRobustChannel, exchange: AbstractExchange) -> None:
        """Run all registered consumers concurrently until they complete.

        :raises CancelledError: if any consumer raises an ``Exception``. The
            remaining consumers are cancelled first, and the original error is
            chained as the cause.
        """
        try:
            tasks: List[Task] = []
            for consumer in self._queue_message_consumers:
                task = asyncio.create_task(consumer.run(channel, exchange))
                # Track each task immediately so that it can still be cancelled
                # if creating a later one fails.
                tasks.append(task)
                self._running_consumers.append(task)
            # Only wait on the tasks started here, not on tasks from earlier starts.
            await asyncio.gather(*tasks)
        except Exception as err:
            logger.error(f'All running consumers will be stopped due to an error: {repr(err)}')
            await self._stop_all_running_consumers()
            raise CancelledError(f'Error: {repr(err)}') from err

    async def _on_start(self) -> None:
        """Declare the exchange and launch the consumers in a background task."""
        channel = await self._context.rabbitmq_connector.get_channel()
        exchange = await self._declare_exchange(channel)

        self._consumers_supervisor = asyncio.create_task(self._run_queue_message_consumers(channel, exchange))

        logger.info('Start consume messages')

    async def _on_stop(self) -> None:
        """Cancel all running consumers and the task that supervises them."""
        if self._running_consumers:
            await self._stop_all_running_consumers()
            logger.info('All running consumers were stopped')

        supervisor, self._consumers_supervisor = self._consumers_supervisor, None
        if supervisor is not None and not supervisor.done():
            supervisor.cancel()
            # Wait for the supervisor to wind down. Its CancelledError is
            # returned as a result rather than propagated into this coroutine.
            await asyncio.gather(supervisor, return_exceptions=True)
