#  Copyright (C) 2024
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
import asyncio
from dataclasses import dataclass, field
from typing import Literal, Any

import mistune

from run_markdown.json_comparator import JsonComparator
from run_markdown.text_comparator import TextComparator


@dataclass
class RunResult:
    stdout: str = ''
    process: asyncio.subprocess.Process | None = None


async def _read_until_idle(stream: asyncio.StreamReader, idle_timeout: float = 0.1,
                           chunk_size: int = 64 * 1024) -> bytes:
    """Read *stream* until EOF or until no new data arrives for *idle_timeout* seconds.

    Reads in chunks rather than lines, so very long lines and a trailing partial
    line are collected in the same pass (``StreamReader.readline`` fails on lines
    longer than the stream's buffer limit).
    """
    chunks: list[bytes] = []
    while True:
        try:
            chunk = await asyncio.wait_for(stream.read(chunk_size), timeout=idle_timeout)
        except asyncio.TimeoutError:
            break  # Producer is still alive but quiet: keep what we have so far.
        if not chunk:
            break  # EOF: the process closed its stdout.
        chunks.append(chunk)
    return b''.join(chunks)


def _decode_valid_prefix(data: bytes) -> str:
    """Decode *data* as UTF-8, dropping everything from the first invalid byte onward.

    Output read from a still-running process may stop in the middle of a multi-byte
    character; the bytes before the first invalid sequence always decode cleanly.
    """
    try:
        return data.decode()
    except UnicodeDecodeError as error:
        return data[:error.start].decode()


async def execute_code_block(code: str, language: str, timeout: float = 1) -> RunResult:
    """Run a code snippet in a subprocess and collect its standard output.

    ``python`` snippets run as ``python -c <code>``; ``shell``/``bash``/``sh``
    snippets run as ``script -q -c <code>``.

    Args:
        code: Source text of the snippet.
        language: Language tag of the snippet.
        timeout: Seconds to wait for the process to exit before detaching from it.

    Returns:
        A RunResult whose ``stdout`` is the UTF-8 text read from the process (cut
        at the first undecodable byte). ``process`` is set only when the process
        did not exit within ``timeout``; the caller must then terminate it.

    Raises:
        ValueError: If ``language`` is not supported.
    """
    if language == 'python':
        process_args = ['python', '-c', code]
    elif language in ('shell', 'bash', 'sh'):
        process_args = ['script', '-q', '-c', code]
    else:
        raise ValueError('Unknown language')

    process: asyncio.subprocess.Process = await asyncio.create_subprocess_exec(
        *process_args,
        stdout=asyncio.subprocess.PIPE,
        # stdin stays an open pipe that is never written to.
        # NOTE(review): presumably so `script` never sees EOF on stdin — confirm before changing.
        stdin=asyncio.subprocess.PIPE,
        # stderr is never inspected; an unread pipe would eventually block a
        # detached, long-running process once its buffer fills up.
        stderr=asyncio.subprocess.DEVNULL,
    )

    result = RunResult()
    try:
        await asyncio.wait_for(process.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        # Still running (e.g. a server started by the snippet): hand it back so
        # the caller can kill it later.
        result.process = process

    raw_output = await _read_until_idle(process.stdout)
    result.stdout = _decode_valid_prefix(raw_output)
    return result


async def terminate_running_processes(running_processes: list[asyncio.subprocess.Process], depth: int = 0,
                                      timeout: float = 1) -> None:
    """Stop each process in *running_processes*, escalating from terminate() to kill().

    Processes that have already exited are skipped rather than raising
    ``ProcessLookupError``, so one snippet that finished on its own cannot
    abort the cleanup of the others.

    Args:
        running_processes: Processes to stop.
        depth: Indentation (in spaces) of the progress messages.
        timeout: Seconds to wait for a process to exit after terminate(), and
            again after kill(), before giving up on it.
    """
    for process in running_processes:
        print(' ' * depth + f'terminate {process}')
        if process.returncode is not None:
            continue  # Already exited; nothing to signal.
        try:
            process.terminate()
        except ProcessLookupError:
            continue  # Exited between the check above and the signal.
        try:
            await asyncio.wait_for(process.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            try:
                process.kill()
            except ProcessLookupError:
                continue
            # Reap the killed process, but never hang cleanup on it indefinitely.
            try:
                await asyncio.wait_for(process.wait(), timeout=timeout)
            except asyncio.TimeoutError:
                pass


def preview(value: Any, max_len: int = 64, separator: str = '...') -> str:
    """Return ``repr(value)``, shortened for log output.

    A representation longer than *max_len* characters is cut so that, together
    with the appended *separator*, it is *max_len* characters long.
    """
    text = repr(value)
    if len(text) <= max_len:
        return text
    keep = max_len - len(separator)
    return f'{text[:keep]}{separator}'


@dataclass
class Check:
    """One runnable code snippet together with its optional expected output."""
    code: str  # source text handed to execute_code_block
    language: str  # snippet language tag passed to execute_code_block
    expected_format: Literal['text', 'json'] | None = None  # 'json' -> JsonComparator, else TextComparator
    expected: str | None = None  # expected stdout; a falsy value means the output is not checked

    async def evaluate(self, depth: int = 0) -> asyncio.subprocess.Process | None:
        """Run the snippet, print progress indented by *depth*, and check its output.

        Returns:
            The still-running process when the snippet did not finish within
            execute_code_block's timeout, otherwise ``None``. The caller owns the
            returned process and is responsible for terminating it.

        Raises:
            The exception carried by the comparison result when the output does not
            match (or whatever the comparator raises). Before re-raising, a
            still-running process is terminated, because the caller never receives
            it in that case and could not clean it up.
        """
        indent = ' ' * depth
        print(indent + f"Run {self.language} code: {preview(self.code)}")
        result: RunResult = await execute_code_block(self.code, language=self.language)
        if result.process is not None:
            print(indent + f"Timeout {self.language} code, detach from it and kill later")
        try:
            self._check_output(result.stdout, depth)
        except Exception:
            if result.process is not None:
                try:
                    await terminate_running_processes([result.process], depth)
                except ProcessLookupError:
                    # Best effort: the process may already be gone; the comparison
                    # error below is what the caller needs to see.
                    pass
            raise
        return result.process

    def _check_output(self, actual: str, depth: int) -> None:
        """Compare *actual* stdout with ``self.expected``, printing the verdict; raise on mismatch."""
        print(' ' * depth + f'Got {preview(actual)}... ', end='')
        if not self.expected:
            print('IGNORE')
            return
        if self.expected_format == 'json':
            comparator = JsonComparator(self.expected)
        else:
            comparator = TextComparator(self.expected)
        comparison_result = comparator.compare(actual)
        if comparison_result.exception is None:
            print('OK')
            return
        print('ERROR')
        print('ERROR DETAILS:')
        comparison_result.print()
        print('RAW DATA:')
        print(repr(actual))
        raise comparison_result.exception


@dataclass
class LogicBlock:
    """A named, ordered group of checks and nested groups (one per Markdown heading, plus a root)."""
    name: str  # heading text, or the file path for the root block
    content: list['Check | LogicBlock'] = field(default_factory=list)  # children, evaluated in order

    def add(self, block: 'Check | LogicBlock') -> None:
        """Append *block* as the last child of this group."""
        self.content.append(block)

    async def evaluate(self, depth: int = 0) -> None:
        """Evaluate every child in order, printing progress indented by *depth*.

        Processes that children leave running are collected and terminated when
        the group finishes, whether it completed normally or raised.
        """
        print(f"{' ' * depth}Starting logic block {self.name!r}")
        leftovers: list[asyncio.subprocess.Process] = []
        try:
            for child in self.content:
                still_running = await child.evaluate(depth + 1)
                if still_running is not None:
                    leftovers.append(still_running)
        finally:
            await terminate_running_processes(leftovers, depth)


def get_logic_block_from_md_file(file_path: str) -> LogicBlock:
    """Parse the UTF-8 Markdown file at *file_path* into a LogicBlock tree.

    The root block is named after *file_path*; the tree itself is built by
    ``parse_tokens_as_logic`` from the tokens mistune leaves in its parse state.
    """
    with open(file_path, 'r', encoding='utf-8') as md_file:
        source_text = md_file.read()

    # parse() returns (rendered output, state); only the state's tokens are needed.
    _rendered, block_state = mistune.create_markdown().parse(source_text)
    return parse_tokens_as_logic(block_state.tokens, file_path)


def _inline_text(node: dict) -> str:
    """Concatenate the ``raw`` text of an inline token and all of its descendants."""
    if 'raw' in node:
        return node['raw']
    return ''.join(_inline_text(child) for child in node.get('children', ()))


def parse_tokens_as_logic(tokens: list[dict], root_name: str) -> LogicBlock:
    """Build a LogicBlock tree from top-level mistune block tokens.

    * A ``block_code`` token whose info string is ``python``, ``shell``, ``bash``
      or ``sh`` becomes a Check appended to the current block.
    * The next other ``block_code`` token after such a check (with no heading or
      runnable block in between) becomes that check's expected output; its info
      string is stored as ``expected_format``. Further such blocks are ignored.
    * A ``heading`` of level N starts a new block nested N-1 levels below the
      root, under the most recent block at each level. A level with no block to
      descend into (empty, or whose last item is a Check) gets a placeholder
      block named after the heading.

    Args:
        tokens: Top-level tokens whose headings carry inline ``children``, as
            produced by ``get_logic_block_from_md_file``.
        root_name: Name given to the returned root block.
    """
    root = current_block = LogicBlock(name=root_name)
    last_check = None
    for token in tokens:
        token_type = token['type']
        if token_type == 'block_code':
            language: str | None = token.get('attrs', {}).get('info', None)
            if language in ('python', 'shell', 'bash', 'sh'):
                last_check = Check(code=token['raw'], language=language)
                current_block.add(last_check)
            elif last_check is not None:
                last_check.expected_format = language
                last_check.expected = token['raw']
                last_check = None
        elif token_type == 'heading':
            last_check = None
            # Flatten inline markup (codespan, emphasis, ...) instead of asserting a
            # single text child; asserts are stripped under `python -O`.
            block_name = _inline_text({'children': token.get('children', [])})
            parent = root
            for _ in range(token.get('attrs', {}).get('level', 0) - 1):
                last_item = parent.content[-1] if parent.content else None
                if not isinstance(last_item, LogicBlock):
                    # Never descend into a Check: open a placeholder level instead.
                    last_item = LogicBlock(block_name)
                    parent.add(last_item)
                parent = last_item
            current_block = LogicBlock(block_name)
            parent.add(current_block)

    return root

