#!/usr/bin/env python3

from dataclasses import dataclass
import glob
from itertools import combinations
import logging
import os
from pathlib import Path
import re
import sys
import yaml


# Two-letter language codes that a statement file name (problem.<code>.tex)
# is allowed to carry.
BOI23LANGUAGES = ["da", "de", "en", "et", "is", "lt", "lv", "no", "pl", "sv", "uk"]
INTre = r"\$(\d+)\$"  # matches '$54$' but not '54' or '$ 54$'
RANGEre = r"\$(\d+)\$--\$(\d+)\$"  # matches '$54$--$67$' but not '$54--67$'


def parseValidatorConstraints(infile):
    """Collect the ``constraint:<name>`` annotations of a validator source.

    Args:
        infile: Iterable of text lines, e.g. an open validator file.

    Returns:
        A dict mapping each constraint name to the stripped line it was
        found on. If a name occurs more than once a warning is logged and
        the last occurrence wins. Lines where ``constraint:`` is not
        directly followed by a name are ignored.
    """
    constraints = {}
    for line in infile:
        # Underscores are part of the name, exactly as StatementConstraints
        # parses them; otherwise "max_n" would be recorded as "max" and the
        # validator could never agree with the statements.
        m = re.search(r"constraint:([A-Za-z_]+)", line)
        if m is None:
            continue
        c = m.group(1)
        if c in constraints:
            logging.warning("Validator uses %s twice. Overwriting (fix me).", c)
        constraints[c] = line.strip()
    return constraints


class GeneratorGroups:
    """Test-group layout read from the lines of a generator script.

    Attributes filled in by the constructor:
        numgroups: number of ``group`` lines seen.
        score: group number -> score taken from its ``group`` line.
        limits: group -> text following ``limits``; limits seen before the
            first ``group`` line are stored under the key "sample".
        graph: one ``(included group, including group)`` pair for every
            ``include_group`` line.
    """

    def __init__(self, infile):
        group_pattern = r"group\s+g(?:roup)?(\d+)\s+(\d+)"
        limits_pattern = r"limits (.*)"
        include_pattern = r"include_group\s+g(?:roup)?(\d+)"

        self.numgroups = 0
        self.score = {}
        self.limits = {}
        self.graph = []

        current = "sample"
        for raw in infile:
            header = re.match(group_pattern, raw)
            if header:
                self.numgroups += 1
                current = int(header.group(1))
                # Groups have to be declared in order: 1, 2, 3, ...
                assert current == self.numgroups
                self.score[current] = int(header.group(2))

            # Only the limits line is matched with surrounding blanks removed.
            limit = re.match(limits_pattern, raw.strip())
            if limit:
                self.limits[current] = limit.group(1)

            include = re.match(include_pattern, raw)
            if include:
                self.graph.append((int(include.group(1)), current))

    def check_unique_scores(self):
        """Return True iff every valid set of solved groups has its own total.

        A set is valid when, for each ``(u, v)`` pair in ``graph``, it
        contains ``u`` whenever it contains ``v``. Every clash between two
        valid sets is reported through ``logging.warning``.
        """
        seen = {}
        all_distinct = True
        group_ids = list(range(1, len(self) + 1))
        for size in range(len(self) + 1):
            for chosen in combinations(group_ids, size):
                closed = all(
                    u in chosen or v not in chosen for u, v in self.graph
                )
                if not closed:
                    continue
                total = sum(self.score[g] for g in chosen)
                if total in seen:
                    logging.warning(
                        "Multiple ways of getting score %d: %s %s",
                        total,
                        chosen,
                        seen[total],
                    )
                    all_distinct = False
                seen[total] = chosen
        return all_distinct

    def __len__(self):
        """Number of groups declared in the generator."""
        return self.numgroups

class StatementConstraints:
    """Scores, limits and constraint annotations parsed from one statement.

    Not a dataclass: the class declares no fields and builds its state in a
    hand-written constructor, so ``@dataclass`` would only contribute an
    ``__eq__`` under which all instances compare equal.

    Attributes filled in by the constructor:
        score: group number -> score from the group's table row; for a row
            giving a range of scores, the second number of the range.
        limits: group number -> remaining text of the table row.
        constraints: constraint name -> stripped line carrying the
            ``constraint:<name>`` annotation.

    Args:
        infile: Iterable of text lines, e.g. an open statement file.

    Raises:
        AssertionError: if the table rows do not number the groups 1, 2, ...
            in order.
    """

    def __init__(self, infile):
        self.score = {}
        self.limits = {}
        self.constraints = {}

        # Row shapes: "$group$ & $score$ & limits" and
        # "$group$ & $low$--$high$ & limits".
        # NOTE(review): the literal blank after "\s*" demands at least one
        # space between '&' and the score; kept as is -- confirm it is
        # intended.
        int_row = re.compile(INTre + r"\s*&\s* " + INTre + r"\s*&(.*)")
        range_row = re.compile(INTre + r"\s*&\s* " + RANGEre + r"\s*&(.*)")

        numgroups = 0
        mode = "looking"
        for line in infile:
            # A line that mentions "constraint:" without a name is skipped
            # instead of crashing on a failed match.
            m = re.search(r"constraint:([A-Za-z_]+)", line)
            if m:
                self.constraints[m.group(1)] = line.strip()

            # The end of the table has to be tested first: "\end{tabular}"
            # contains "tabular" too, so testing in the other order never
            # leaves slurping mode and rows after the table get parsed.
            if "end{tabular" in line:
                mode = "looking"
                continue
            if "tabular" in line:
                mode = "slurping"
                continue
            if mode != "slurping":
                continue

            row = line.strip()
            m = int_row.match(row)
            if m:
                group, score, limits = m.group(1), m.group(2), m.group(3)
            else:
                m = range_row.match(row)
                if not m:
                    continue
                group, score, limits = m.group(1), m.group(3), m.group(4)

            numgroups += 1
            group = int(group)
            self.score[group] = int(score)
            self.limits[group] = limits
            assert group == numgroups, (self.score[group], numgroups)


def check(problem_folder):
    """Print a consistency report for one scoring problem.

    Reads ``problem.yaml``, ``data/generate.sh``, every file matched by
    ``*validators/**`` and every ``problem_statement/problem.*.tex``, then
    prints, with ANSI colours:

    * each validator constraint next to the matching line of every statement,
    * per group, the score and limits of the generator and of every statement,
    * whether every valid combination of groups yields a distinct total score.

    Args:
        problem_folder: ``pathlib.Path`` of the problem directory.

    Returns:
        None. If the generator, a validator or a statement cannot be read,
        the error is printed and the problem is skipped.

    Raises:
        AssertionError: if ``problem.yaml`` is not of type "scoring", a
            validator is neither ``.py`` nor ``.cpp``, two validator files
            declare the same constraint, or a statement file carries a
            language code outside ``BOI23LANGUAGES``.
    """
    with open(problem_folder / "problem.yaml", "r") as infile:
        assert yaml.safe_load(infile)["type"] == "scoring"

    try:
        with open(problem_folder / "data" / "generate.sh", "r") as infile:
            ggroups = GeneratorGroups(infile)

        validatorconstraints = {}
        for validator in glob.iglob(os.path.join(problem_folder, "*validators", "**")):
            assert validator.split(".")[-1] in ["py", "cpp"], "Use C++ or Python for validator"
            with open(validator, "r") as infile:
                tmp = parseValidatorConstraints(infile)
                # The same constraint must not be declared by two validators.
                assert not set(tmp.keys()) & set(validatorconstraints.keys())
                validatorconstraints.update(tmp)

        statement = {}
        for s in glob.iglob(os.path.join(problem_folder, "problem_statement", "problem.*.tex")):
            with open(s, "r") as infile:
                lang = s[-6:-4]  # the two letters in front of ".tex"
                assert lang in BOI23LANGUAGES
                statement[lang] = StatementConstraints(infile)
    except IOError as err:
        print(err)
        print(
            "\033[31mExpected generator and statement(s) for scoring problem.",
            "Skipping...\033[0m",
        )
        return

    for lang in statement:
        if set(validatorconstraints) != set(statement[lang].constraints):
            logging.warning("Constraints for language %s don't match validator", lang)

    for k in validatorconstraints:
        print(f"\033[1m{k}\033[0m")
        print("\033[36m  Validator:\033[0m", validatorconstraints[k])
        for lang in sorted(statement):
            # A statement may lack this constraint (that mismatch was warned
            # about above); show a placeholder instead of raising KeyError.
            print(
                f"  \033[36m{lang}\033[0m",
                statement[lang].constraints.get(k, "(found none)"),
            )

    for i in range(1, 1 + len(ggroups)):
        print(f"\033[1mGroup {i}\033[0m")
        scores = [f"\033[1m{ggroups.score[i]}\033[0m (generator)"]
        ok = True
        for lang in statement:
            thisscore = statement[lang].score.get(i)
            if thisscore != ggroups.score[i]:
                ok = False
            scores.append(f"\033[1m{thisscore}\033[0m ({lang})")
        print("  \033[36mScore\033[0m", ", ".join(scores), end=" ")
        print("\033[32mOK\033[0m" if ok else "\033[031mWarning!\033[0m")

        print(
            "  \033[36mGenerator limits:\033[0m",
            ggroups.limits.get(i, "(none)"),
        )
        for lang in statement:
            print(
                f"  \033[36mStatement ({lang}):\033[0m ",
                statement[lang].limits.get(i, "(found none)"),
            )
    print()
    if ggroups.check_unique_scores():
        print(
            "\033[32mOK\033[0m", "Unique scores for each valid test group combination"
        )
    else:
        print("\033[031mWarning!\033[0m", "Nonunique scores (see above)")


# Check the problem directory named by the first command-line argument, or,
# when none is given, every directory directly below the current one that
# contains a problem.yaml (in sorted order).
problem_paths = (
    [Path(sys.argv[1])]
    if len(sys.argv) > 1
    else [p.parent for p in sorted(Path(".").glob("*/problem.yaml"))]
)

for position, problem_path in enumerate(problem_paths):
    # A blank line separates consecutive reports.
    if position:
        print()
    print(f"\033[1;4m{problem_path}\033[0m")
    check(problem_path)
