#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasiliev Ivan <i.vasiliev@technokert.ru>


import asyncio
import logging
from collections import defaultdict
from dataclasses import dataclass

from typing import Callable, Dict, Optional, Awaitable, Union, NoReturn
from asyncio import Task
from aiokafka import AIOKafkaConsumer, ConsumerRecord
from async_tools import acall, retry
from kafka import TopicPartition

from .abstract_kafka_client import AbstractKafkaClient

# Module-level logger named after the module's import path, so it takes part in
# the standard dotted logger hierarchy; a __file__-based name is a filesystem
# path that hierarchical logging configuration cannot target.
logger = logging.getLogger(__name__)


# Per-topic record handler: either a coroutine function or a plain function
# taking a single ConsumerRecord. Its return value is passed to acall() by the
# consumer, so both variants are supported.
RecordCallback = Union[
    Callable[[ConsumerRecord], Awaitable[None]],
    Callable[[ConsumerRecord], None],
]


class Consumer(AbstractKafkaClient):
    """Kafka consumer that dispatches each record to a per-topic callback.

    Handlers (and the optional ``register_call_before_*`` hooks) must be
    registered before the client is connected. On start, a background task
    iterates the aiokafka consumer and awaits the callback registered for each
    record's topic.
    """

    @dataclass(frozen=True)
    class Config(AbstractKafkaClient.Config):
        # Selects aiokafka's auto_offset_reset: "latest" if True, else "earliest".
        topic_auto_offset_is_latest: bool = True
        # Consumer group id; commit() is a no-op and the committed-offset replay
        # is skipped when this is None.
        group_id: Optional[str] = None
        # Used only when group_id is set (default False: offsets are committed
        # explicitly through commit()). Without a group the consumer is always
        # created with enable_auto_commit=True.
        enable_auto_commit: bool = False

    def __init__(self, config: Config, exit_if_fetch_failed: bool = False) -> None:
        self.config = config
        # topic -> callback replayed over records below the committed offset on start.
        self._topic_to_before_last_committed_offset_callback: Dict[str, RecordCallback] = {}
        # topic -> callback for records received by the background fetch task.
        self._topic_to_callback: Dict[str, RecordCallback] = {}
        # topic -> hook called (no arguments) once before fetching starts.
        self._topic_to_call_before_fetch: Dict[str, Callable] = {}
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self._topic_to_do_before_fetch_task: Dict[str, asyncio.Future] = {}
        self._consumer: Optional[AIOKafkaConsumer] = None
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self._red_positions = defaultdict(int)
        self._fetch_data__task: Optional[Task] = None
        # NOTE(review): stored but not read in this class — confirm whether the
        # base class or a subclass acts on it.
        self._exit_if_fetch_failed: bool = exit_if_fetch_failed
        super(Consumer, self).__init__(config)

    def subscribe(self, topic: str, callback: RecordCallback) -> None:
        """Register ``callback`` as the handler for records of ``topic``.

        Must be called before the consumer is connected; each topic may have
        only one handler.
        """
        assert not self.is_connected(), "subscribe should be called before Consumer init"
        assert topic not in self._topic_to_callback, f"handler for topic {topic} already registered"
        self._topic_to_callback[topic] = callback
        # Logged after validation so a rejected registration is not reported as done.
        logger.info(f"registered handler on topic: {topic}")

    async def seek_partition_to_begin(self, topic: str) -> None:
        """Seek every partition of ``topic`` to its beginning offset."""
        for id_ in self._consumer.partitions_for_topic(topic):
            await self._consumer.seek_to_beginning(TopicPartition(topic, id_))
        logger.info(f'set topic {topic} offset to begin')

    def register_call_before_fetch(self, topic_name: str, call_before_fetch: Callable) -> None:
        """Register a hook run once, with no arguments, before fetching starts.

        The hook may be a plain function or a coroutine function. ``topic_name``
        must already have a handler registered via :meth:`subscribe`.
        """
        assert not self.is_connected(), "called after init"
        assert topic_name in self._topic_to_callback, f"no handler to topic: {topic_name}"
        self._topic_to_call_before_fetch[topic_name] = call_before_fetch

    def register_call_before_last_committed_offset(self, topic: str, callback: RecordCallback) -> None:
        """Register ``callback`` to replay ``topic`` records below the committed offset.

        The replay runs on start only when ``config.group_id`` is set.
        """
        assert not self.is_connected(), "called after init"
        self._topic_to_before_last_committed_offset_callback[topic] = callback

    async def get_topic_partition_to_committed_offset(self, topic_name: str) -> dict[TopicPartition, int | None]:
        """Return the group's committed offset for every partition of ``topic_name``.

        A value is ``None`` when nothing has been committed for that partition.
        """
        topic_partition_to_committed_offset: dict[TopicPartition, int | None] = {}
        for id_ in self._consumer.partitions_for_topic(topic_name):
            topic_partition = TopicPartition(topic_name, id_)
            topic_partition_to_committed_offset[topic_partition] = await self._consumer.committed(topic_partition)
        return topic_partition_to_committed_offset

    async def get_topic_partition_to_max_offset(self, topic_name: str) -> Dict[TopicPartition, int]:
        """Return the end offset (last available offset + 1) of every partition of ``topic_name``."""
        partitions = [
            TopicPartition(topic_name, id_)
            for id_ in self._consumer.partitions_for_topic(topic_name)
        ]
        return await self._consumer.end_offsets(partitions)

    async def _disconnect(self) -> None:
        """Stop the underlying aiokafka consumer."""
        await self._consumer.stop()

    async def _start(self) -> None:
        """Create and start the consumer, then launch the background fetch task.

        Before fetching, run the before-fetch hooks and, when a group is
        configured, replay records below the committed offsets.
        """
        self._consumer = await self._create_consumer()
        await self._consumer.start()
        await self._await_before_fetch()
        if self._topic_to_before_last_committed_offset_callback and self.config.group_id:
            await self._fetch_data_before_last_committed_offset()
            logger.info("consumer finished doing callbacks before last committed offset")
        logger.info(f'consumer start listen for {self._topic_to_callback.keys()}')
        self._fetch_data__task = asyncio.create_task(self.__fetch_data())
        self._fetch_data__task.add_done_callback(self._fetch_finished_callback)

    async def _fetch_data_before_last_committed_offset(self) -> None:
        """Replay every partition with a committed offset from its beginning up to that offset.

        Partitions are replayed concurrently; each one ends positioned at its
        committed offset.
        """
        logger.info("consumer starts doing callbacks before last committed offset")
        tasks = []
        for topic, callback in self._topic_to_before_last_committed_offset_callback.items():
            for partition, committed_offset in (await self.get_topic_partition_to_committed_offset(topic)).items():
                if committed_offset is not None:
                    await self._consumer.seek_to_beginning(partition)
                    tasks.append(
                        asyncio.create_task(
                            self._handle_records(
                                partition=partition, until_offset=committed_offset, callback=callback
                            )
                        )
                    )
        await asyncio.gather(*tasks)

    async def _handle_records(self, partition: TopicPartition, until_offset: int, callback: RecordCallback) -> None:
        """Pass every record of ``partition`` below ``until_offset`` to ``callback``, then seek to ``until_offset``.

        The loop is driven by the consumer position (offset of the next record
        to fetch). Reading "one more record" to detect the end would block when
        nothing remains below ``until_offset`` — e.g. when the committed offset
        equals the partition's end offset — until a new record is produced.
        """
        while await self._consumer.position(partition) < until_offset:
            consumer_record = await self._consumer.getone(partition)
            if consumer_record.offset >= until_offset:
                # Offsets are not necessarily contiguous, so the next record may
                # already be at or past the limit.
                break
            await acall(callback(consumer_record))
        self._consumer.seek(partition, until_offset)

    @staticmethod
    def _fetch_finished_callback(fetch_task: asyncio.Task) -> None:
        """Done-callback of the fetch task: log its completion and any exception it raised."""
        logger.warning(f"kafka fetch is finished: task: {fetch_task}")
        if fetch_task.cancelled():
            return
        # Retrieve the exception so a fetch failure is logged rather than lost.
        exception = fetch_task.exception()
        if exception is not None:
            logger.error("kafka fetch task failed", exc_info=exception)

    async def _create_consumer(self) -> AIOKafkaConsumer:
        """Build (without starting) an aiokafka consumer subscribed to all registered topics."""
        auto_offset_reset = "latest" if self.config.topic_auto_offset_is_latest else "earliest"
        # With a group, honour Config.enable_auto_commit (default False, so offsets
        # are committed explicitly via commit()); without a group keep True.
        enable_auto_commit = True if self.config.group_id is None else self.config.enable_auto_commit
        consumer = AIOKafkaConsumer(
            *self._topic_to_callback.keys(),
            bootstrap_servers=self.config.address,
            group_id=self.config.group_id,
            enable_auto_commit=enable_auto_commit,
            auto_offset_reset=auto_offset_reset,
        )
        return consumer

    async def _await_before_fetch(self) -> None:
        """Run each registered before-fetch hook once, awaiting its result if needed."""
        for work in self._topic_to_call_before_fetch.values():
            # Call the hook and hand its result to acall, the same way record
            # callbacks are dispatched; passing the bare callable would not invoke it.
            await acall(work())

    @retry
    async def __fetch_data(self) -> NoReturn:
        """Iterate the consumer forever, awaiting the registered callback for each record."""
        async for record in self._consumer:
            callback = self._topic_to_callback[record.topic]
            await acall(callback(record))

    async def commit(self, record: ConsumerRecord) -> None:
        """Commit ``record.offset + 1`` (next offset to consume) for the record's partition.

        Does nothing when no ``group_id`` is configured.
        """
        if self.config.group_id is None:
            return
        offsets = {TopicPartition(record.topic, record.partition): record.offset + 1}
        await self._consumer.commit(offsets)
        logger.info(f'KafkaConsumer: {offsets} committed')
