#  Copyright (C) 2022
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasiliev Ivan <i.vasiliev@technokert.ru>
import asyncio
import logging
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Awaitable, Optional, Any
from enum import StrEnum
from functools import wraps

from aiokafka import ConsumerRecord
from async_tools import acall
from async_tools.async_deinitable import AsyncDeinitable
from async_tools.async_initable import AsyncInitable

from .kafka_record import KafkaRecord, PreproceededRecord
from ..abstract_buffer import AbstractBuffer
from ..deserializer import RecordDeserializer
from ..kafka_client.consumer import Consumer

# Module-level logger named after the module's import path (not its file path),
# so it participates in the standard dotted logger hierarchy.
logger = logging.getLogger(__name__)


class PreProcessRecordError(ValueError):
    """Base error for failures while pre-processing a consumed record."""


class DeserializeRecordError(PreProcessRecordError):
    """Pre-processing error raised when a record cannot be deserialized."""


class ExceptionStrategy(StrEnum):
    EXIT = "exit"
    SKIP = "skip"


class AbstractRecordProcessor(ABC):
    """Base class for processors that consume records from one Kafka topic.

    On construction the processor wraps ``record_callback`` so that errors it
    raises are handled according to ``Config.exception_strategy``, and
    subscribes :meth:`process_record` to ``Config.topic`` on the consumer
    supplied in the context.
    """

    @dataclass
    class Config:
        # Topic whose records are routed to ``process_record``.
        topic: str
        # Reaction to an error raised by the record callback: exit the
        # process, or log it and skip the record(s).
        exception_strategy: ExceptionStrategy = ExceptionStrategy.SKIP

    @dataclass
    class Context:
        # Consumer used to subscribe to the topic (and, in subclasses, to
        # commit handled records).
        consumer: Consumer

    def __init__(self, config: Config, context: Context, record_callback: Callable) -> None:
        self.context = context
        self.config = config
        self._record_callback = self._wrap_record_callback(record_callback)
        # NOTE(review): this subscribes before subclass initializers have
        # finished; presumably the consumer does not dispatch records until it
        # is started — TODO confirm.
        self.context.consumer.subscribe(self.config.topic, self.process_record)

    def _handle_callback_error(self, exc: Exception) -> None:
        """Apply ``config.exception_strategy`` to an error raised by the record callback.

        Logs the error with its traceback. With ``ExceptionStrategy.EXIT`` the
        process is terminated (``SystemExit`` with a non-zero status);
        otherwise the error is swallowed and the skip is logged.
        """
        logger.warning(f"record callback caused error: {repr(exc)}", exc_info=exc)
        if self.config.exception_strategy == ExceptionStrategy.EXIT:
            # A bare ``sys.exit()`` reports success (status 0) even though we
            # stop because of an error; use a failure status instead.
            sys.exit(1)
        logger.warning(f"due error {self.__class__.__name__} skipped record/s")

    def _wrap_record_callback(self, record_callback: Callable) -> Callable[[Any], Any]:
        """Wrap ``record_callback`` with the configured exception strategy.

        Coroutine functions get an async wrapper and all other callables get a
        sync one, so the wrapper is called exactly like the original. When the
        callback raises, the wrapper returns ``None`` (SKIP) or raises
        ``SystemExit`` (EXIT).
        """
        if asyncio.iscoroutinefunction(record_callback):
            @wraps(record_callback)
            async def async_wrapper(*args, **kwargs) -> Any:
                try:
                    return await acall(record_callback(*args, **kwargs))
                except Exception as exc:
                    self._handle_callback_error(exc)
            return async_wrapper

        @wraps(record_callback)
        def sync_wrapper(*args, **kwargs) -> Any:
            # NOTE(review): if a non-coroutine callable returns an awaitable,
            # errors raised while awaiting it are not handled here.
            try:
                return record_callback(*args, **kwargs)
            except Exception as exc:
                self._handle_callback_error(exc)
        return sync_wrapper

    @abstractmethod
    async def process_record(self, record: ConsumerRecord) -> None:
        """Handle one record delivered by the consumer for ``config.topic``."""

    @abstractmethod
    async def _preprocess_record(self, msg: ConsumerRecord) -> PreproceededRecord:
        """Turn a raw consumer record into a pre-processed key/value pair."""

    @staticmethod
    def form_record(pre_proceeded_record: PreproceededRecord, consumer_record: ConsumerRecord) -> KafkaRecord:
        """Build a :class:`KafkaRecord` from pre-processed key/value plus the raw record's metadata."""
        return KafkaRecord(key=pre_proceeded_record.key,
                           value=pre_proceeded_record.value,
                           topic=consumer_record.topic,
                           partition=consumer_record.partition,
                           offset=consumer_record.offset,
                           timestamp=consumer_record.timestamp,
                           timestamp_type=consumer_record.timestamp_type,
                           checksum=consumer_record.checksum,
                           serialized_key_size=consumer_record.serialized_key_size,
                           serialized_value_size=consumer_record.serialized_value_size,
                           headers=consumer_record.headers
                           )


class AbstractDeserializedRecordProcessor(AbstractRecordProcessor, ABC):
    """Processor whose pre-processing step is deserialization.

    ``_preprocess_record`` delegates to ``context.deserializer`` and reports
    any failure as :class:`DeserializeRecordError`.
    """

    @dataclass
    class Context(AbstractRecordProcessor.Context):
        # Turns a raw ConsumerRecord into a PreproceededRecord.
        deserializer: RecordDeserializer

    def __init__(self,
                 config: AbstractRecordProcessor.Config,
                 context: Context,
                 record_callback: Callable) -> None:
        super().__init__(config, context, record_callback)
        # The base initializer chain already stores this same object; kept so
        # the attribute visibly holds the deserializer-carrying Context.
        self.context = context

    async def _preprocess_record(self, record: ConsumerRecord) -> PreproceededRecord:
        """Deserialize ``record``.

        Raises:
            DeserializeRecordError: wrapping whatever the deserializer raised.
        """
        try:
            return self.context.deserializer.deserialize_record(record)
        except Exception as e:
            # Chain explicitly so the original traceback is reported as the cause.
            raise DeserializeRecordError(e) from e


class AbstractSingleRecordProcessor(AbstractRecordProcessor, ABC):
    """Processor that handles each consumed record individually.

    Every record is pre-processed, formed into a :class:`KafkaRecord`, passed
    to the record callback and then committed on the consumer.
    """

    def __init__(self, config: AbstractRecordProcessor.Config, context: AbstractRecordProcessor.Context,
                 record_callback: Callable[[KafkaRecord], Awaitable[None]]) -> None:
        # The base initializer stores ``record_callback`` wrapped with
        # ``config.exception_strategy``. It must not be overwritten with the
        # raw callback here; doing so made the configured strategy a no-op
        # for single-record processors.
        super().__init__(config, context, record_callback)

    async def process_record(self, record: ConsumerRecord) -> None:
        """Pre-process, form, hand over and commit one consumed record.

        Raises:
            PreProcessRecordError: if pre-processing fails. The error is
                logged and the record is not committed.
        """
        try:
            pre_proceeded_record = await self._preprocess_record(record)
        except PreProcessRecordError as e:
            # TODO consider add option to skip preprocessing exception and commit anyway
            logger.warning(f"got error in record preprocessing; record {record}, error: {repr(e)}")
            raise
        formed_record = self.form_record(pre_proceeded_record, record)
        await acall(self._record_callback(formed_record))
        await self.context.consumer.commit(record)


class AbstractBufferProcessor(AbstractRecordProcessor, AsyncInitable, AsyncDeinitable, ABC):
    """Processor that collects records in an :class:`AbstractBuffer` and handles them in batches.

    ``process_record`` only adds records to the buffer. The buffer loop,
    started in ``_async_init``, calls ``_handle_buffer`` with batches. That
    method pre-processes each record, passes the formed records to the record
    callback and commits the batch's last record.
    """

    @dataclass
    class Config(AbstractBuffer.Config, AbstractRecordProcessor.Config):
        # Number of pre-processing failures within one batch at which the
        # failure is re-raised instead of the record being dropped. It is
        # annotated so that it is a real (configurable) dataclass field.
        exception_amount_threshold: int = 5

    @dataclass
    class Context(AbstractRecordProcessor.Context):
        pass

    def __init__(self,
                 config: Config,
                 context: Context,
                 record_callback: Callable[[list[KafkaRecord]], Awaitable[None]]
                 ) -> None:
        AsyncInitable.__init__(self)
        AsyncDeinitable.__init__(self)
        AbstractRecordProcessor.__init__(self, config, context, record_callback)
        self.config = config
        self._buffer = AbstractBuffer(self._handle_buffer, config)
        self._buffer_task: Optional[asyncio.Task] = None

    async def _async_init(self) -> None:
        """Start the buffer loop as a background task."""
        self._buffer_task = asyncio.create_task(self._buffer.run())
        self._buffer_task.add_done_callback(self._analyze__task_finished_callback)

    def _analyze__task_finished_callback(self, task: asyncio.Task) -> None:
        """Log that the buffer loop ended, including its error if it failed."""
        logger.warning(f"{self.config.topic=} buffer loop finished")

        # ``Task.exception()`` raises CancelledError for a cancelled task
        # (e.g. after ``_async_deinit``), so check for cancellation first.
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.warning(f"{self.config.topic=} buffer loop finished with error: {exc}", exc_info=exc)

    async def _async_deinit(self) -> None:
        """Cancel the buffer loop, if it was ever started."""
        if self._buffer_task is not None:
            self._buffer_task.cancel()

    async def process_record(self, record: ConsumerRecord) -> None:
        """Add ``record`` to the buffer. Handling happens later in ``_handle_buffer``."""
        await self._buffer.add_to_buffer(record)

    async def _handle_buffer(self, buffer: list[ConsumerRecord]) -> None:
        """Process one batch of buffered records and commit it.

        Records that fail pre-processing are logged and dropped from the
        batch. Once ``config.exception_amount_threshold`` failures have
        occurred in this batch, the failure is re-raised and nothing is
        committed. An empty batch is ignored.
        """
        logger.debug("handle buffer")
        if not buffer:
            return

        last_record = buffer[-1]
        formed_records: list[KafkaRecord] = []
        exception_amount = 0
        for record in buffer:
            try:
                pre_proceeded_record = await self._preprocess_record(record)
                formed_records.append(self.form_record(pre_proceeded_record, record))
            except PreProcessRecordError as e:
                # TODO consider add option to skip preprocessing exception and commit anyway
                logger.warning(f"got preprocess error on record {record}, error: {repr(e)}")
                exception_amount += 1
                if exception_amount >= self.config.exception_amount_threshold:
                    # TODO fix; not working properly (original note — cause not identified)
                    raise
        await acall(self._record_callback(formed_records))
        # Committing the last record of the batch also covers the records
        # dropped above (the skip behaviour).
        await self.context.consumer.commit(last_record)


class RecordProcessor(AbstractDeserializedRecordProcessor, AbstractSingleRecordProcessor):
    """Concrete processor that deserializes and handles records one at a time."""

    @dataclass
    class Config(AbstractSingleRecordProcessor.Config):
        """Topic and exception-strategy settings."""

    @dataclass
    class Context(AbstractDeserializedRecordProcessor.Context):
        """Consumer plus deserializer."""


class BufferRecordProcessor(AbstractDeserializedRecordProcessor, AbstractBufferProcessor):
    """Concrete buffered processor that deserializes records and passes each batch to one callback."""

    @dataclass
    class Config(AbstractBufferProcessor.Config, AbstractRecordProcessor.Config):
        """Buffer settings combined with topic and exception-strategy settings."""

    @dataclass
    class Context(AbstractDeserializedRecordProcessor.Context):
        """Consumer plus deserializer."""

    def __init__(self,
                 config: Config,
                 context: Context,
                 record_callback: Callable[[list[KafkaRecord]], Awaitable[None]]
                 ) -> None:
        # Cooperative initialisation through the deserializing and buffering bases.
        super().__init__(config=config, context=context, record_callback=record_callback)


class BufferRecordMultiProcessor(AbstractDeserializedRecordProcessor, AbstractBufferProcessor):
    """Buffered, deserializing processor that fans each batch out to several callbacks.

    Callbacks are added with :meth:`register_callback`. Each one receives the
    same list of formed records, in registration order.
    """

    @dataclass
    class Config(AbstractBufferProcessor.Config, AbstractRecordProcessor.Config):
        pass

    @dataclass
    class Context(AbstractDeserializedRecordProcessor.Context):
        pass

    def __init__(self, config: Config, context: Context) -> None:
        # Created before the base initializer subscribes to the topic, so the
        # attribute already exists whenever a batch can be dispatched.
        self._callbacks: list[Callable[[list[KafkaRecord]], Awaitable[None]]] = []
        super().__init__(config, context, self._call_callbacks)

    async def _call_callbacks(self, records: list[KafkaRecord]) -> None:
        """Invoke every registered callback with ``records``, sequentially.

        Raises:
            RuntimeError: if no callback has been registered.
        """
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        if not self._callbacks:
            raise RuntimeError("BufferRecordMultiProcessor started without registered callbacks")
        for callback in self._callbacks:
            # ``acall`` is used elsewhere in this file on both sync and async
            # callback results; here it also tolerates synchronous callbacks.
            await acall(callback(records))

    def register_callback(self, callback: Callable[[list[KafkaRecord]], Awaitable[None]]) -> None:
        """Add ``callback`` to the list invoked for every handled batch."""
        self._callbacks.append(callback)
