#  Copyright (C) 2021
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
#
import asyncio
import dataclasses
import logging
from typing import Awaitable, Callable, Iterable, Optional

from aiohttp import web
from aiohttp.web_urldispatcher import UrlDispatcher

from . import HttpStatusCode
from .answer import JsonableAnswer
from .charset import Charset

# Module-level logger.
logger = logging.getLogger(__name__)

# Signature of the aiohttp request handler wrapped by the middleware below.
Handler = Callable[[web.Request], Awaitable[web.Response]]

# How long (seconds, used as the `timeout` of asyncio.wait()) the middleware
# waits for a handler before answering with AskLaterAnswer instead.
REDIRECT_PERIOD = 1
# Process-local, in-memory bookkeeping for deferred requests:
#   'id'    - next task id handed out by get_task_id()
#   'tasks' - task id -> asyncio.Task that did not finish within REDIRECT_PERIOD.
#             An entry is removed only when a poll finds the task finished, so
#             tasks that are never polled again stay here indefinitely.
buffer = {"id": 1, 'tasks': {}}


def get_task_id() -> int:
    """Reserve and return the next sequential task identifier.

    The counter is stored in ``buffer['id']`` and is post-incremented, so every
    call hands out a value no previous call has returned.
    """
    reserved = buffer['id']
    buffer['id'] = reserved + 1
    return reserved


def first(iterable: Iterable):
    """Return the first item yielded by *iterable*.

    Raises:
        StopIteration: if *iterable* yields nothing.
    """
    iterator = iter(iterable)
    return next(iterator)


class AskLaterAnswer(JsonableAnswer):
    """JSON answer telling the client its request has not finished yet.

    Returned in place of the real handler response while the deferred task is
    still running. The middleware in this module looks the task up again when a
    request carries the ``X-TASK-ID`` header with this ``task_id``.
    """

    @dataclasses.dataclass
    class AskLaterPayload:
        # Identifier of the pending task (key in ``buffer['tasks']``).
        task_id: int
        # AskLaterAnswer never sets this, so it is always False in its payload.
        done: bool = False

    def __init__(self, task_id: int, headers: Optional[dict[str, str]] = None, charset: Charset = Charset.UTF8) -> None:
        body = self.AskLaterPayload(task_id=task_id)
        super().__init__(payload=body, headers=headers, charset=charset)

    @classmethod
    def get_class_status_code(cls) -> HttpStatusCode:
        # A "not ready yet" reply is reported as a plain success.
        return HttpStatusCode.OK

    @classmethod
    def get_class_payload_type(cls) -> type:
        return cls.AskLaterPayload


def return_result_by_task_id(task_id: int):
    """Return the outcome of a previously deferred handler task.

    Looks ``task_id`` up in ``buffer['tasks']``. If the task has finished, it is
    removed from the buffer and its result is returned; an exception raised by
    the handler is re-raised here. While the task is still running, a fresh
    :class:`AskLaterAnswer` carrying the same id is returned instead.

    Raises:
        web.HTTPNotFound: if no task is buffered under ``task_id`` (unknown id,
            or the result was already collected by an earlier poll).
    """
    try:
        task: asyncio.Task = buffer['tasks'][task_id]
    except KeyError:
        # A bare KeyError would surface as a 500; an unknown or already-consumed
        # id is a client error.
        raise web.HTTPNotFound(text=f'Unknown task id: {task_id}') from None

    if not task.done():
        logger.debug('task %s is not done', task_id)
        return AskLaterAnswer(task_id=task_id)

    logger.debug('task %s is done', task_id)
    # Drop the entry before reading the result so that a failed task does not
    # stay in the buffer either.
    del buffer['tasks'][task_id]
    return task.result()


@web.middleware
async def error_middleware(
    request: web.Request, handler: Handler, _: Optional[UrlDispatcher] = None
) -> web.Response:
    """Defer slow handlers and let the client poll for their result.

    * If the request carries an ``X-TASK-ID`` header, the buffered task with
      that id is looked up instead of running the handler (see
      ``return_result_by_task_id``).
    * Otherwise the handler is started and given ``REDIRECT_PERIOD`` seconds.
      If it finishes, its response (or exception) is returned as usual. If not,
      the still-running task is stored in ``buffer['tasks']`` and an
      :class:`AskLaterAnswer` with its new id is returned.

    The third parameter is unused. It defaults to ``None`` so that aiohttp,
    which calls new-style middlewares as ``m(request, handler=handler)``, can
    invoke this function directly.

    Raises:
        web.HTTPBadRequest: if ``X-TASK-ID`` is present but not an integer.
    """
    raw_task_id = request.headers.get('X-TASK-ID')
    if raw_task_id:
        logger.debug('delayed task detected: %s', raw_task_id)
        try:
            task_id = int(raw_task_id)
        except ValueError:
            raise web.HTTPBadRequest(text=f'Invalid X-TASK-ID header: {raw_task_id!r}') from None
        return return_result_by_task_id(task_id)

    # asyncio.wait() rejects bare coroutines (deprecated in 3.8, TypeError since
    # 3.11), so schedule the handler explicitly and keep a handle on the task.
    task = asyncio.ensure_future(handler(request))
    # wait() does not cancel on timeout: the task keeps running either way.
    await asyncio.wait([task], timeout=REDIRECT_PERIOD)
    if task.done():
        logger.debug('done in first cycle')
        return task.result()

    # Not finished in time: park the running task and give the client an id to poll with.
    task_id = get_task_id()
    buffer['tasks'][task_id] = task
    logger.debug('saved task %s to buffer (%d pending)', task_id, len(buffer['tasks']))
    return AskLaterAnswer(task_id=task_id)
