#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Revva Konstantin <k.revva@abm-jsc.ru>
import asyncio
from dataclasses import dataclass
from typing import Callable, Awaitable, Any, Type, Coroutine, Optional
import json
import logging

from aiokafka import ConsumerRecord
from async_tools import AsyncInitable, AsyncDeinitable
from kafka_tools.abstract_buffer import AbstractBuffer
from kafka_tools.deserializer import RecordDeserializer, JsonDeserializer, AbstractDeserializer

from kafka_tools.kafka_client.consumer import Consumer
from kafka_tools.kafka_client.producer import Producer
from init_helpers.dict_to_dataclass import dict_to_dataclass
from kafka_tools.record_processor.kafka_record import KafkaRecord
from kafka_tools.record_processor.record_processor import BufferRecordProcessor

logger = logging.getLogger(name=__name__)  # module-scoped logger, named after this module


@dataclass
class DetectorAnalyzeResult:
    rtsp_url: str
    snapshot_file_id: str
    snapshot_done_at: float
    reference_image_file_ids: list[str]
    detector_requests: list[dict[str, str | float | bool | int]]
    detector_results: list[dict[str, str | float | bool | int | list]]
    comparison_settings: dict[str, str | float | list | dict]
    comparison_results: list[float]
    result: float
    analyzed_at: float


@dataclass()
class DetectorErrorResult:
    """Error reported by the detector for a stream, read back from the errors topic."""

    rtsp_url: str  # stream the error refers to
    occurred_at: float  # numeric timestamp of the error (unit not established here)
    exception: str  # error text as sent by the detector


# Subscriber callbacks: called with a whole batch of deserialized records and must return a
# coroutine, which the connector awaits. Coroutine's parameters are [YieldType, SendType,
# ReturnType]; nothing is ever sent into these coroutines, so SendType is Any (the batch type
# belongs only in the Callable's argument list).
AnalyzeResultCallback = Callable[[list[DetectorAnalyzeResult]], Coroutine[Any, Any, Any]]
ErrorResultCallback = Callable[[list[DetectorErrorResult]], Coroutine[Any, Any, Any]]


class DetectorKafkaConnector(AsyncInitable, AsyncDeinitable):
    """Kafka bridge to the detector service.

    Produces analyze tasks to ``detector_task_topic`` and fans batches of deserialized records
    consumed from ``detector_result_topic`` (and, when configured, ``detector_errors_topic``)
    out to the subscribed coroutine callbacks.
    """

    @dataclass(kw_only=True)
    class Config:
        detector_task_topic: str  # topic analyze tasks are produced to
        detector_result_topic: str  # topic analyze results are consumed from
        detector_errors_topic: Optional[str] = None  # falsy -> detector errors are not consumed
        MAX_VALUES_IN_BUFFER: int = 10000  # forwarded to BufferRecordProcessor.Config
        MAX_DELAY_BETWEEN_PROCESS: int = 5  # forwarded to BufferRecordProcessor.Config

    @dataclass
    class Context:
        kafka_producer: Producer
        kafka_consumer: Consumer

    class DetectorAnalyzeDeserializer(JsonDeserializer):
        """Decode a JSON record value into ``result_type`` via ``dict_to_dataclass``."""
        # TODO rework kafka_tools to work with functions
        # Overridden per instance by the connector so that its ``detector_analyze_result_type``
        # (e.g. a DetectorAnalyzeResult subclass with extra fields) is actually used.
        result_type: Type[DetectorAnalyzeResult] = DetectorAnalyzeResult

        def deserialize(self, payload: bytes) -> DetectorAnalyzeResult:
            data = super().deserialize(payload)
            return dict_to_dataclass(data, self.result_type)

    class DetectorErrorDeserializer(JsonDeserializer):
        """Decode a JSON record value into a DetectorErrorResult."""
        # TODO rework kafka_tools to work with functions
        def deserialize(self, payload: bytes) -> DetectorErrorResult:
            data = super().deserialize(payload)
            # Backward compatibility: older messages carry the URL under "stream_url".
            # Only fall back when that key exists, so a record with neither key is reported by
            # dict_to_dataclass instead of failing with an unrelated KeyError('stream_url').
            if not data.get("rtsp_url") and "stream_url" in data:
                data["rtsp_url"] = data["stream_url"]
            return dict_to_dataclass(data, DetectorErrorResult)

    class DummyDeserializer(AbstractDeserializer):
        """Deserializer that ignores its input and always returns None.

        Used as the first deserializer of RecordDeserializer.Context (presumably the key slot —
        confirm against kafka_tools).
        """
        def deserialize(self, s: bytes) -> None:
            return

    def __init__(self,
                 config: Config,
                 context: Context,
                 detector_analyze_result_type: Type[DetectorAnalyzeResult]) -> None:
        """
        :param config: topics and buffering settings.
        :param context: Kafka producer and consumer to use.
        :param detector_analyze_result_type: dataclass (DetectorAnalyzeResult or a subclass) that
            records from the result topic are deserialized into.
        """
        self.config = config
        self.context = context
        self.detector_analyze_result_type = detector_analyze_result_type
        self._analyze_result_callbacks: set[AnalyzeResultCallback] = set()
        self._error_analyze_result_callbacks: set[ErrorResultCallback] = set()

        analyze_deserializer = self.DetectorAnalyzeDeserializer()
        analyze_deserializer.result_type = detector_analyze_result_type
        self.buffer_record_processor = self._make_buffer_record_processor(
            topic=self.config.detector_result_topic,
            value_deserializer=analyze_deserializer,
            record_callback=self._process_task_result_kafka_records)

        self.buffer_error_record_processor: BufferRecordProcessor | None = None
        if self.config.detector_errors_topic:
            self.buffer_error_record_processor = self._make_buffer_record_processor(
                topic=self.config.detector_errors_topic,
                value_deserializer=self.DetectorErrorDeserializer(),
                record_callback=self._process_task_error_kafka_records)
        AsyncInitable.__init__(self)
        AsyncDeinitable.__init__(self)

    def _make_buffer_record_processor(
            self,
            topic: str,
            value_deserializer: AbstractDeserializer,
            record_callback: Callable[[list[KafkaRecord]], Awaitable[None]]) -> BufferRecordProcessor:
        """Build a BufferRecordProcessor on the shared consumer with this connector's buffer limits."""
        return BufferRecordProcessor(
            config=BufferRecordProcessor.Config(
                topic=topic,
                MAX_VALUES_IN_BUFFER=self.config.MAX_VALUES_IN_BUFFER,
                MAX_DELAY_BETWEEN_PROCESS=self.config.MAX_DELAY_BETWEEN_PROCESS),
            context=BufferRecordProcessor.Context(
                consumer=self.context.kafka_consumer,
                deserializer=RecordDeserializer(RecordDeserializer.Context(self.DummyDeserializer(),
                                                                           value_deserializer))
            ),
            record_callback=record_callback)

    async def create_analyze_task(self, stream_url: str, snapshot_id: str) -> None:
        """Produce an analyze task for ``snapshot_id`` of ``stream_url`` to ``detector_task_topic``.

        The record value is the JSON object ``{"stream_url": ..., "snapshot_id": ...}`` encoded to bytes.
        """
        payload = {
            "stream_url": stream_url,
            "snapshot_id": snapshot_id,
        }
        await self.context.kafka_producer.produce(self.config.detector_task_topic, json.dumps(payload).encode())
        logger.debug("sent snapshot_id %s for %s to %s", snapshot_id, stream_url, self.config.detector_task_topic)

    def subscribe_on_analyze_result(self, callback: AnalyzeResultCallback) -> None:
        """Register ``callback`` to receive each batch of analyze results."""
        self._analyze_result_callbacks.add(callback)

    def subscribe_on_error_result(self, callback: ErrorResultCallback) -> None:
        """Register ``callback`` to receive each batch of DetectorErrorResult.

        :raises AssertionError: if no errors topic is configured (the callback could never fire).
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``; uses the same truthiness
        # test as __init__ (an empty topic creates no error processor). AssertionError is kept so
        # existing callers see the same exception type.
        if not self.config.detector_errors_topic:
            raise AssertionError("error topic is not defined in config")
        self._error_analyze_result_callbacks.add(callback)

    @staticmethod
    async def _notify_subscribers(callbacks: set[Callable[[list], Awaitable[Any]]], values: list) -> None:
        """Run every callback concurrently on the same batch and wait for all of them.

        As with asyncio.gather's default, the first exception raised by a callback propagates.
        """
        await asyncio.gather(*(callback(values) for callback in callbacks))

    async def _process_task_result_kafka_records(self, records: list[KafkaRecord]) -> None:
        """Forward the deserialized values of a result batch to all analyze-result subscribers."""
        await self._notify_subscribers(self._analyze_result_callbacks, [record.value for record in records])

    async def _process_task_error_kafka_records(self, records: list[KafkaRecord]) -> None:
        """Forward the deserialized values of an error batch to all error subscribers."""
        await self._notify_subscribers(self._error_analyze_result_callbacks, [record.value for record in records])

    def _deserialize_analyze_result(self, record: ConsumerRecord) -> DetectorAnalyzeResult:
        """Decode a raw consumer record's JSON value into ``detector_analyze_result_type``.

        Not referenced within this class; kept for compatibility with any external users.
        """
        payload = json.loads(record.value)
        return dict_to_dataclass(payload, self.detector_analyze_result_type)
