#  Copyright (C) 2023
#  ABM, Moscow
#
#  UNPUBLISHED PROPRIETARY MATERIAL.
#  ALL RIGHTS RESERVED.
#
#  Authors: Mike Orlov <m.orlov@abm-jsc.ru>
from typing import Callable

from ._token import Token, Getter
from .token_factory import OnlyPublicGetterTokenFactory, TokenFactory

# Opening bracket character -> the character that closes it (insertion order: [ ( { <).
opening_bracket_to_closing: dict[str, str] = dict(zip("[({<", "])}>"))

# Characters that delimit quoted strings; the same character opens and closes.
quotes = list("'\"")


def find_first(string: str, condition: Callable, start_position: int = 0, escape_char: str = "\\") -> int | None:
    i = start_position
    last_char = None
    while i < len(string):
        char = string[i]
        if last_char != escape_char and not condition(char):
            return i
        last_char = char
        i += 1
    return None


class QpelParser:
    """Parse qpel expression strings into :class:`Token` trees.

    Constructs recognised by :meth:`parse`:

    * ``'text'`` / ``"text"`` -- string literal. A backslash-escaped quote does
      not terminate it; the literal body is kept verbatim (escapes included).
    * ``12`` / ``-1.5`` -- int literal, or float literal if it contains ``.``.
    * ``name`` (at the start) / ``<expr>.name`` -- attribute getter.
    * ``<expr>.[a, b]`` -- list of attribute chains, each re-rooted on ``<expr>``.
    * ``[a, b]`` -- list literal.
    * ``<expr>(a, k=b)`` -- call with positional and keyword arguments.

    Token objects are created through ``token_factory``.
    """

    def __init__(self, token_factory: TokenFactory | None = None, allow_spaces: bool = True):
        """Create a parser.

        :param token_factory: factory used to build tokens. ``None`` (the
            default) creates a fresh ``OnlyPublicGetterTokenFactory`` for this
            parser instead of sharing one instance between all parsers.
        :param allow_spaces: when true, space characters between tokens are
            skipped; otherwise a space is an unexpected character.
        """
        self.token_factory = token_factory if token_factory is not None else OnlyPublicGetterTokenFactory()
        self.allow_spaces = allow_spaces

    def parse(self, qpel: str) -> Token:
        """Parse a qpel expression into a token tree.

        The text is scanned left to right. Literals and a leading name start a
        new operand; ``.name``, ``.[...]`` and ``(...)`` wrap the operand parsed
        so far.

        :param qpel: expression text.
        :return: root token of the expression.
        :raises ValueError: on empty or blank input, unbalanced brackets or
            quotes, malformed numbers, a missing attribute name, keyword
            arguments where they are not allowed, or an unexpected character
            (including an operand that would replace an already parsed one).
        :raises TypeError: if ``.[...]`` contains something other than attributes.
        """
        if not qpel:
            raise ValueError("empty qpel")

        result: Token | None = None
        i = 0
        while i < len(qpel):
            char = qpel[i]

            if char in quotes:
                self._reject_second_operand(result, char, i)
                result, i = self._parse_string_literal(qpel, i)
            elif char.isdigit() or char == '-':
                self._reject_second_operand(result, char, i)
                result, i = self._parse_number_literal(qpel, i)
            elif result is not None and qpel.startswith(".[", i):
                result, i = self._parse_getter_list(qpel, i, result)
            elif (result is None and char.isidentifier()) or (result is not None and char == "."):
                result, i = self._parse_getter(qpel, i, result)
            elif char == "[":
                self._reject_second_operand(result, char, i)
                result, i = self._parse_list_literal(qpel, i)
            elif char == "(":
                result, i = self._parse_call(qpel, i, result)
            elif self.allow_spaces and char == " ":
                i += 1
            else:
                raise ValueError(f"Unexpected char: {char!r} in position {i}")

        if result is None:
            # Only spaces were seen: there is no expression to return.
            raise ValueError("empty qpel")
        return result

    @staticmethod
    def _reject_second_operand(result: Token | None, char: str, position: int) -> None:
        """Raise if a new operand would silently replace the one parsed so far."""
        if result is not None:
            raise ValueError(f"Unexpected char: {char!r} in position {position}")

    def _parse_string_literal(self, qpel: str, start: int) -> tuple[Token, int]:
        """Parse the quoted literal opening at ``start``; return (token, next index)."""
        quote = qpel[start]
        end = find_first(qpel, lambda x: x != quote, start + 1)
        if end is None:
            # repr() yields "missing \"'\"" for a single quote, as before.
            raise ValueError(f"missing {quote!r}")
        return self.token_factory.make_literal(body=qpel[start + 1:end]), end + 1

    def _parse_number_literal(self, qpel: str, start: int) -> tuple[Token, int]:
        """Parse the int/float literal starting at ``start``; return (token, next index)."""
        end = find_first(qpel, lambda x: x.isdigit() or x in ".-", start + 1)
        if end is None:
            end = len(qpel)
        number_str = qpel[start:end]
        body_type = float if '.' in number_str else int
        try:
            body = body_type(number_str)
        except ValueError as err:
            raise ValueError(f"Invalid number {number_str!r} in position {start}") from err
        return self.token_factory.make_literal(body=body), end

    def _parse_getter(self, qpel: str, start: int, result: Token | None) -> tuple[Token, int]:
        """Parse ``name`` or ``.name`` at ``start`` into a getter on ``result``."""
        key_start = start + 1 if qpel[start] == "." else start
        key_end = find_first(qpel, lambda x: x.isdigit() or x.isidentifier(), start + 1)
        if key_end is None:
            key_end = len(qpel)
        key = qpel[key_start:key_end]
        if not key:
            raise ValueError(f"Missing attribute name after '.' in position {start}")
        return self.token_factory.make_getter(key=key, body=result), key_end

    def _parse_getter_list(self, qpel: str, start: int, result: Token) -> tuple[Token, int]:
        """Parse ``.[a, b, ...]`` at ``start`` into a list literal of getters.

        Each element must be an attribute chain whose root (``arg.get_root()``)
        is a :class:`Getter`; that root gets its ``body`` set to ``result``.
        """
        bracket = start + 1
        end = find_paired_bracket(qpel, bracket)  # TODO: replace with some code aware of all bracket types
        args, kwargs = self._parse_qpel_as_args(qpel[bracket + 1:end])
        if kwargs:
            raise ValueError(f"'.[' accepts only attributes, got keyword arguments: {list(kwargs)}")
        body = []
        for arg in args:
            root = arg.get_root()
            if not isinstance(root, Getter):
                raise TypeError(f"'.[' accepts only attributes, got: {arg}")
            root.body = result
            body.append(arg)
        return self.token_factory.make_literal(body=body), end + 1

    def _parse_list_literal(self, qpel: str, start: int) -> tuple[Token, int]:
        """Parse ``[a, b, ...]`` at ``start`` into a list literal."""
        end = find_paired_bracket(qpel, start)  # TODO: replace with some code aware of all bracket types
        args, kwargs = self._parse_qpel_as_args(qpel[start + 1:end])
        if kwargs:
            raise ValueError(f"List literal does not accept keyword arguments, got: {list(kwargs)}")
        return self.token_factory.make_literal(body=args), end + 1

    def _parse_call(self, qpel: str, start: int, result: Token | None) -> tuple[Token, int]:
        """Parse ``(args)`` at ``start`` into a call of ``result``."""
        end = find_paired_bracket(qpel, start)  # TODO: replace with some code aware of all bracket types
        args, kwargs = self._parse_qpel_as_args(qpel[start + 1:end])
        return self.token_factory.make_call(body=result, args=args, kwargs=kwargs), end + 1

    @staticmethod
    def _split_by_separator(qpel: str, separator: str, keep_trailing_empty: bool = False) -> list[str]:
        """Split ``qpel`` on ``separator`` occurrences outside brackets and quotes.

        Inside a quoted section a backslash escapes the next character, so an
        escaped quote does not close the string. Without ``keep_trailing_empty``
        a separator at the very end is dropped (this allows ``[1, 2,]``).

        :param qpel: text to split.
        :param separator: single separator character.
        :param keep_trailing_empty: if true, a trailing separator yields a final
            ``""`` part instead of being dropped.
        :return: the parts, without the separators.
        :raises ValueError: on unpaired brackets or an unclosed quote.
        """
        result: list[str] = []
        last_position = 0
        brackets: list[str] = []
        opened_quote = None
        escaped = False
        for i, char in enumerate(qpel):
            if opened_quote:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == opened_quote:
                    opened_quote = None
                continue

            if brackets:
                if char == opening_bracket_to_closing[brackets[-1]]:
                    brackets.pop()
            elif char == separator:
                result.append(qpel[last_position:i])
                last_position = i + 1
            if char in opening_bracket_to_closing:
                brackets.append(char)
            if char in quotes:
                opened_quote = char
        if brackets:
            raise ValueError(f"Got unpaired brackets in: {qpel!r}")
        if opened_quote:
            raise ValueError(f"Got unclosed quote in: {qpel!r}")
        if last_position < len(qpel) or (keep_trailing_empty and result):
            result.append(qpel[last_position:])
        return result

    def _parse_qpel_as_args(self, qpel: str) -> tuple[list[Token], dict[str, Token]]:
        """Parse a comma-separated argument list (the text between brackets).

        :param qpel: argument text, e.g. ``"a, 1, k='v'"``.
        :return: positional argument tokens and keyword argument tokens.
        :raises ValueError: on empty arguments, a missing keyword name or value,
            a duplicated keyword, or more than one ``=`` in an argument.
        """
        args: list[Token] = []
        kwargs: dict[str, Token] = {}
        if self.allow_spaces and not qpel.strip(" "):
            # "f( )" / "[ ]": nothing but spaces means no arguments.
            return args, kwargs
        for part in self._split_by_separator(qpel, separator=','):
            if not part:
                raise ValueError("empty string in args")
            # Keep a trailing empty part so "k=" is rejected rather than parsed as positional "k".
            sub_parts = self._split_by_separator(part, separator='=', keep_trailing_empty=True)
            if len(sub_parts) > 2:
                raise ValueError("Expected only one '=' in keyword arguments")
            val = self.parse(sub_parts[-1])
            if len(sub_parts) == 1:
                args.append(val)
                continue
            key = sub_parts[0].strip(" ") if self.allow_spaces else sub_parts[0]
            if not key:
                raise ValueError(f"Missing keyword argument name in {part!r}")
            if key in kwargs:
                raise ValueError(f"Duplicate keyword argument: {key!r}")
            kwargs[key] = val

        return args, kwargs


def find_paired_bracket(text: str, start: int) -> int:
    """Return the index of the bracket that closes the one at ``text[start]``.

    Only brackets of the same kind as ``text[start]`` are counted. Nested
    brackets of that kind are balanced, and other bracket kinds are ignored.
    Characters inside quoted strings (delimited by any of ``quotes``) are
    skipped. Within a quoted string a backslash escapes the next character,
    matching how ``find_first`` locates the end of a literal.

    :param text: text to scan.
    :param start: index of the opening bracket.
    :return: index of the matching closing bracket.
    :raises KeyError: if ``text[start]`` is not a key of ``opening_bracket_to_closing``.
    :raises ValueError: if the matching closing bracket is missing.
    """
    opening_bracket = text[start]
    closing_bracket = opening_bracket_to_closing[opening_bracket]
    depth = 0
    opened_quote = None
    escaped = False
    for i in range(start, len(text)):
        char = text[i]
        if opened_quote is not None:
            # Inside a string literal: brackets here are plain text.
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == opened_quote:
                opened_quote = None
        elif char in quotes:
            opened_quote = char
        elif char == opening_bracket:
            depth += 1
        elif char == closing_bracket:
            depth -= 1
            if depth == 0:
                return i
    raise ValueError(f"parse error: missing {closing_bracket!r} in {text[start:]!r}")
