#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasiliev Ivan <i.vasiliev@technokert.ru>


import logging
from dataclasses import dataclass
from typing import Optional
from http_tools.http_server_connector import HttpServerConnector


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class SnapshotConnector:
    """Async client for the snapshot service HTTP endpoints.

    Thin wrapper over ``HttpServerConnector``: every public coroutine issues a
    GET request to one endpoint and attaches a ``server_name`` header taken
    from ``Context.project_name``.
    """

    @dataclass
    class Context(HttpServerConnector.Context):
        """``HttpServerConnector`` context extended with the caller's project name."""

        # Sent to the server as the ``server_name`` request header.
        project_name: str

    # Configuration is exactly that of the underlying HTTP connector.
    Config = HttpServerConnector.Config

    def __init__(self, config: Config, context: Context):
        """Create the connector.

        :param config: configuration passed unchanged to ``HttpServerConnector``.
        :param context: context passed to ``HttpServerConnector``; its
            ``project_name`` is also used for the ``server_name`` header.
        """
        self.config = config
        self.context = context
        self._connector = HttpServerConnector(config, context)

    async def get_snapshot(self, stream_url: str, snapshot_timeout_s: int | None = None) -> bytes:
        """Request a snapshot of ``stream_url`` from ``/snapshot/get``.

        :param stream_url: URL of the stream, forwarded as ``stream_url``.
        :param snapshot_timeout_s: timeout forwarded to the server (the ``_s``
            suffix suggests seconds); omitted from the request when ``None``.
        :return: the result of ``HttpServerConnector.get`` (annotated as bytes).
        """
        args = self._build_args(stream_url, snapshot_timeout_s=snapshot_timeout_s)
        return await self._connector.get("/snapshot/get", args, headers=self._get_auth_header())

    async def get_avg_snapshot(self, stream_url: str,
                               duration: int | None = None,
                               quality: int | None = None,
                               snapshot_timeout_s: int | None = None,
                               debug: bool | None = None) -> bytes:
        """Request a snapshot of ``stream_url`` from ``/avg_snapshot/get``.

        Judging by the endpoint name this is an averaged snapshot; the exact
        meaning of ``duration`` and ``quality`` is defined by the server.
        Every optional argument left as ``None`` is omitted from the request.

        :param stream_url: URL of the stream, forwarded as ``stream_url``.
        :param duration: forwarded as ``duration`` when not ``None``.
        :param quality: forwarded as ``quality`` when not ``None``.
        :param snapshot_timeout_s: forwarded as ``snapshot_timeout_s`` when not ``None``.
        :param debug: forwarded as ``debug`` when not ``None``.
        :return: the result of ``HttpServerConnector.get`` (annotated as bytes).
        """
        # Keyword order fixes the parameter order in the request.
        args = self._build_args(
            stream_url,
            snapshot_timeout_s=snapshot_timeout_s,
            duration=duration,
            quality=quality,
            debug=debug,
        )
        # NOTE(review): ``debug`` is forwarded as a Python bool; whether that
        # encodes correctly as a query value depends on HttpServerConnector.get
        # (yarl, for example, rejects bool query values) — confirm.
        return await self._connector.get("/avg_snapshot/get", args, headers=self._get_auth_header())

    async def get_netris_snapshot(self, camera_id: int) -> bytes:
        """Request a snapshot for ``camera_id`` from ``/netris_snapshot/get``.

        :param camera_id: camera identifier, forwarded as ``camera_id``.
        :return: the result of ``HttpServerConnector.get`` (annotated as bytes).
        """
        return await self._connector.get(
            "/netris_snapshot/get",
            {"camera_id": camera_id},
            headers=self._get_auth_header()
        )

    @staticmethod
    def _build_args(stream_url: str, **optional: int | bool | None) -> dict[str, str | int | bool]:
        """Build request arguments: ``stream_url`` first, then every non-``None`` keyword.

        Keyword arguments keep their call order (PEP 468), so the resulting
        dict lists keys in the order the caller passed them.
        """
        # Annotated explicitly: inferring from the first item alone would give
        # dict[str, str] and reject the int/bool values added below.
        args: dict[str, str | int | bool] = {"stream_url": stream_url}
        args.update({key: value for key, value in optional.items() if value is not None})
        return args

    def _get_auth_header(self) -> dict[str, str]:
        """Return the headers identifying this client: ``server_name`` = ``context.project_name``."""
        return {"server_name": self.context.project_name}
