#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Vasya Svintsov <v.svintsov@techokert.ru>

import logging
from dataclasses import dataclass, asdict
from typing import Type, Iterable

from http_tools import AbmServiceConnector
from http_tools.long_requests_middleware import first
from init_helpers import conditional_set
from init_helpers.dict_to_dataclass import dict_to_dataclass

from .entity_dataclass import Entity, prepare_entity_column_names, is_entity_dataclass
from .to_list import to_list
from .entity_elements import entity_condition_as_dict, entity_order_as_dict, EntityOrderProtocol, \
    EntityConditionProtocol


# Module-level logger named after this module; not referenced by the code in this file.
logger = logging.getLogger(__name__)


class EntityServerConnector:
    """Client for an entity server's ``/entity/<entity_name>/...`` HTTP endpoints.

    Wraps an ``AbmServiceConnector`` and issues POST requests to read, add and
    delete entities. Every request carries a ``server_name`` header taken from
    ``Config.server_name``.
    """

    @dataclass(kw_only=True)
    class Config(AbmServiceConnector.Config):
        # Sent as the ``server_name`` request header on every call.
        server_name: str

    Context = AbmServiceConnector.Context

    def __init__(self, config: Config, context: Context) -> None:
        self.config = config
        # Config extends AbmServiceConnector.Config, so the same object is
        # handed straight to the underlying connector.
        self._connector = AbmServiceConnector(config, context)

    async def get_one(self,
                      entity_type: Type[Entity],
                      filter_by: EntityConditionProtocol | list[EntityConditionProtocol] | None = None,
                      search_by: EntityConditionProtocol | list[EntityConditionProtocol] | None = None,
                      order_by: EntityOrderProtocol | list[EntityOrderProtocol] | None = None) -> Entity:
        """Fetch the first entity of ``entity_type`` matching the given conditions.

        Delegates to ``get_list`` with ``limit=1``.

        :param entity_type: entity dataclass type to query.
        :param filter_by: condition(s) forwarded as ``filter_by``.
        :param search_by: condition(s) forwarded as ``search_by``.
        :param order_by: ordering(s) forwarded as ``order_by``.
        :returns: the first entity returned by the server.
        :raises TypeError: if ``entity_type`` is not an entity dataclass type.
        :raises LookupError: if the server returned no entity.
        """
        entities = await self.get_list(entity_type, filter_by, search_by, order_by, limit=1)
        if not entities:
            raise LookupError(f'Undefined {entity_type.__entity_name__}: {filter_by}')
        return first(entities)

    async def get_list(self,
                       entity_type: Type[Entity],
                       filter_by: EntityConditionProtocol | list[EntityConditionProtocol] | None = None,
                       search_by: EntityConditionProtocol | list[EntityConditionProtocol] | None = None,
                       order_by: EntityOrderProtocol | list[EntityOrderProtocol] | None = None,
                       offset: int = 0,
                       limit: int = 10 ** 6) -> list[Entity]:
        """Fetch entities of ``entity_type`` via ``POST /entity/<name>/get``.

        The request payload always contains ``columns`` (derived from
        ``entity_type``), ``offset`` and ``limit``. The ``filter_by``,
        ``search_by`` and ``order_by`` keys are only supplied when the
        corresponding argument is truthy.

        :param entity_type: entity dataclass type to query.
        :param filter_by: condition or list of conditions.
        :param search_by: condition or list of conditions.
        :param order_by: ordering or list of orderings.
        :param offset: number of rows to skip.
        :param limit: maximum number of rows to return (default ``10 ** 6``).
        :returns: one ``entity_type`` instance per row returned by the server.
        :raises TypeError: if ``entity_type`` is not an entity dataclass type.
        """
        if not is_entity_dataclass(entity_type):
            raise TypeError("expected entity_dataclass type")

        payload = dict(columns=prepare_entity_column_names(entity_type), offset=offset, limit=limit)
        # Falsy arguments map to None; conditional_set presumably leaves the key
        # out in that case (TODO confirm), so the server applies no clause.
        conditional_set(
            payload, filter_by=[entity_condition_as_dict(f) for f in to_list(filter_by)] if filter_by else None
        )
        conditional_set(
            payload, search_by=[entity_condition_as_dict(s) for s in to_list(search_by)] if search_by else None
        )
        conditional_set(
            payload, order_by=[entity_order_as_dict(o) for o in to_list(order_by)] if order_by else None
        )
        entities = await self._connector.post(
            f'/entity/{entity_type.__entity_name__}/get', payload=payload, headers=self._construct_headers()
        )
        # Assumes the response is an iterable of per-row dicts (TODO confirm).
        return [dict_to_dataclass(entity_values, entity_type) for entity_values in entities]

    async def add_one(self, entity: Entity) -> int:
        """Add a single entity and return the first id from ``add_list``."""
        entity_ids = await self.add_list(entity)
        return first(entity_ids)

    async def add_list(self, entity: Entity | Iterable[Entity], *entities: Entity) -> list[int]:
        """Add one or more entities of one exact type via ``POST /entity/<name>/add``.

        ``entity`` may be a single instance or an iterable of instances. Any
        extra positional ``entities`` are appended, and everything is sent in
        a single request.

        :returns: the server's response to the add request, presumably the ids
            of the created rows (TODO confirm).
        :raises ValueError: if no entities were given.
        :raises TypeError: if an item is not an entity dataclass instance, or if
            the items are not all of exactly the same class.
        """
        items = to_list(entity) + list(entities)
        if not items:
            raise ValueError('empty entities list')

        # All rows are posted to one endpoint named after a single type, so the
        # type is fixed from the first item. Subclass instances are rejected
        # rather than silently posted to the wrong endpoint (isinstance would
        # accept them).
        entity_type = type(items[0])
        values_to_add = []
        for item in items:
            if not is_entity_dataclass(item):
                raise TypeError("expected entity_dataclass instance")
            if type(item) is not entity_type:
                raise TypeError('entities have different types')
            values_to_add.append(asdict(item))

        payload = dict(values=values_to_add)
        entity_ids = await self._connector.post(
            f'/entity/{entity_type.__entity_name__}/add', payload=payload, headers=self._construct_headers()
        )
        return entity_ids

    async def delete(self,
                     entity_type: Type[Entity],
                     filter_by: EntityConditionProtocol | list[EntityConditionProtocol]) -> bool:
        """Delete entities of ``entity_type`` matching ``filter_by``.

        Sends ``POST /entity/<name>/delete`` and returns the server's response.

        NOTE(review): an empty ``filter_by`` list is sent as ``filter_by=[]``.
        Whether the server treats that as "match everything" is not visible
        here; confirm before relying on it.

        :raises TypeError: if ``entity_type`` is not an entity dataclass type.
        """
        if not is_entity_dataclass(entity_type):
            raise TypeError("expected entity_dataclass type")

        payload = dict(filter_by=[entity_condition_as_dict(f) for f in to_list(filter_by)])
        deletion_status = await self._connector.post(
            f'/entity/{entity_type.__entity_name__}/delete', payload=payload, headers=self._construct_headers()
        )
        return deletion_status

    def _construct_headers(self) -> dict[str, str]:
        """Return the per-request headers (the configured ``server_name``)."""
        return {'server_name': self.config.server_name}
