#!/usr/bin/env python3
# Copyright (c) The mldsa-native project authors
# Copyright (c) The mlkem-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

"""Convenience CLI script wrapping various `make` invocations for
building and running tests and benchmarks.

See the command line interface for more information."""

import platform
import argparse
import os
import re
import sys
import time
import logging
import subprocess
import json

from enum import Enum
from functools import reduce

#
# Some utility functions
#


def dict2str(dict):
    """Render a mapping as space-terminated ``key=value`` pairs.

    Used to prefix logged command lines with the environment overrides,
    e.g. ``{"CFLAGS": "-O2"}`` -> ``"CFLAGS=-O2 "``. Empty mapping -> ``""``.
    """
    return "".join(f"{key}={value} " for key, value in dict.items())


def github_log(msg):
    """Print `msg` only when running under GitHub Actions.

    Presence of the GITHUB_ENV environment variable is used as the signal;
    elsewhere this is a no-op. Used for ``::group::`` / ``::endgroup::``
    workflow commands.
    """
    if "GITHUB_ENV" in os.environ:
        print(msg)


def github_summary(title, test_label, results):
    """Add the results of one test type to the GitHub CI step summary.

    The summary file (``$GITHUB_STEP_SUMMARY``) is organised in sections
    headed by a `title` line, each followed by a markdown table with one
    column per SCHEME. This adds the row(s) for `test_label` to the table of
    section `title`, creating the section and/or table header if missing.

    Arguments:

    - title: Section heading line, e.g. ``## Native Opt Tests``
    - test_label: Text for the first column of the new row(s)
    - results: Mapping SCHEME -> result, where a result is either a bool
      (True = the run failed) or the captured output (str) of a run.
      If every result is a string, the outputs are laid out side by side,
      one table row per output line. Otherwise a single row of pass/fail
      icons is emitted, where a failure is a result that ``is True``.

    Does nothing when ``GITHUB_STEP_SUMMARY`` is not set (i.e. outside CI).
    """
    summary_file = os.environ.get("GITHUB_STEP_SUMMARY")
    if summary_file is None:
        return

    res = list(results.values())
    row_prefix = f"| {test_label} |"

    if res and all(isinstance(r, str) for r in res):
        # One row per output line; zip() stops at the shortest output.
        cells = [
            " " + " | ".join(line) + " |"
            for line in zip(*(r.splitlines() for r in res))
        ]
        if not cells:
            # Empty output: still emit one (empty) row for this test
            cells = [" |" * len(res)]
        rows = [row_prefix + cells[0]] + ["| |" + c for c in cells[1:]]
    else:
        # Mixed or boolean results: a str result means the run succeeded.
        rows = [
            row_prefix
            + "".join(
                " :x: |" if r is True else " :white_check_mark: |" for r in res
            )
        ]

    header = "| Tests |" + "".join(f" {s} |" for s in SCHEME)
    separator = "| ----- |" + " ----- |" * len(SCHEME)

    with open(summary_file, "r", encoding="utf-8") as f:
        lines = [x for x in f.read().splitlines() if x]

    if title not in lines:
        # New section: append heading, table header and the new row(s)
        with open(summary_file, "a", encoding="utf-8") as f:
            print("\n".join([title, header, separator] + rows), file=f)
        return

    t = lines.index(title)
    if t + 1 < len(lines) and lines[t + 1] == header:
        # Insert right after the last row of the table following the heading.
        # The index is absolute (counted from the start of the file).
        end = next(
            (
                i
                for i in range(t + 1, len(lines))
                if not (lines[i].startswith("|") and lines[i].endswith("|"))
            ),
            len(lines),
        )
        lines[end:end] = rows
    else:
        # Section heading exists but has no table directly below it yet
        lines[t + 1 : t + 1] = [header, separator] + rows

    # Rewrite the whole file; blank lines are not preserved.
    with open(summary_file, "w", encoding="utf-8") as f:
        print("\n".join(lines), file=f)


# Send log output to stdout; each record is prefixed with its level and the
# (column-aligned) logger name chosen by `logger()`.
logging.basicConfig(
    format="%(levelname)-5s > %(name)-40s %(message)s",
    stream=sys.stdout,
)


def config_logger(verbose):
    """Set the root logger level: DEBUG when `verbose` is truthy, INFO otherwise."""
    level = logging.DEBUG if verbose else logging.INFO
    logging.getLogger().setLevel(level)


def logger(test_type, scheme, cross_prefix, opt):
    """Return a logger whose name is a column-aligned description of a step.

    The name is built from three left-aligned columns:
    - the test description (40 wide for example test types, 18 otherwise),
    - the scheme (11 wide),
    - a ``(native|cross[ opt| no_opt]):`` tag (17 wide),
    so that lines logged by different steps line up.
    """
    mode = "cross" if cross_prefix else "native"
    opt_suffix = "" if opt is None else (" opt" if opt is True else " no_opt")

    is_example = isinstance(test_type, TEST_TYPES) and test_type.is_example()
    desc_width = 40 if is_example else 18

    tag = f"({mode}{opt_suffix}):"
    name = f"{str(test_type):<{desc_width}} {str(scheme):<11} {tag:<17}"
    return logging.getLogger(name)


#
# Core classes providing a wrapper around invocations to `make`
# for building and running tests and benchmarks
#


class SCHEME(Enum):
    """ML-DSA parameter sets exercised by the test and benchmark targets."""

    MLDSA44 = 1
    MLDSA65 = 2
    MLDSA87 = 3

    def __str__(self):
        """Human-readable name, e.g. ``ML-DSA-44``."""
        return f"ML-DSA-{self.suffix()}"

    def suffix(self):
        """Parameter-set number as a string: ``"44"``, ``"65"`` or ``"87"``.

        Appended to `make` run targets (e.g. ``run_func_44``).
        """
        # Member names have the form MLDSA<level>
        return self.name[len("MLDSA") :]

    @staticmethod
    def from_mode(mode):
        """Map a parameter-set number (int or numeric string) to its SCHEME.

        Returns None if `mode` does not name a known parameter set.
        Raises ValueError if `mode` is a string that is not an integer.
        """
        if isinstance(mode, str):
            mode = int(mode)
        for scheme in SCHEME:
            if int(scheme.suffix()) == mode:
                return scheme
        return None


class TEST_TYPES(Enum):
    """Kinds of test, benchmark and example targets driven through `make`."""

    FUNC = 1
    BENCH = 2
    KAT = 3
    BENCH_COMPONENTS = 4
    ACVP = 5
    BRING_YOUR_OWN_FIPS202 = 6
    BRING_YOUR_OWN_FIPS202_STATIC = 7
    CUSTOM_BACKEND = 8
    BASIC = 9
    MONOLITHIC_BUILD = 10
    MONOLITHIC_BUILD_MULTILEVEL = 11
    MULTILEVEL_BUILD = 12
    MULTILEVEL_BUILD_NATIVE = 13
    MONOLITHIC_BUILD_MULTILEVEL_NATIVE = 14
    MONOLITHIC_BUILD_NATIVE = 15
    STACK = 16
    SIZE = 17
    BASIC_DETERMINISTIC = 18
    UNIT = 19
    ALLOC = 20
    BASIC_LOWRAM = 21
    RNG_FAIL = 22

    def is_benchmark(self):
        """True for the benchmark targets (BENCH, BENCH_COMPONENTS)."""
        return self in (TEST_TYPES.BENCH, TEST_TYPES.BENCH_COMPONENTS)

    def is_example(self):
        """True for the example projects listed by examples()."""
        return self in TEST_TYPES.examples()

    @staticmethod
    def examples():
        """All example test types, in the order used when none are selected."""
        return [
            TEST_TYPES.BRING_YOUR_OWN_FIPS202,
            TEST_TYPES.BRING_YOUR_OWN_FIPS202_STATIC,
            TEST_TYPES.CUSTOM_BACKEND,
            TEST_TYPES.BASIC,
            TEST_TYPES.MONOLITHIC_BUILD,
            TEST_TYPES.MONOLITHIC_BUILD_NATIVE,
            TEST_TYPES.MONOLITHIC_BUILD_MULTILEVEL,
            TEST_TYPES.MONOLITHIC_BUILD_MULTILEVEL_NATIVE,
            TEST_TYPES.MULTILEVEL_BUILD,
            TEST_TYPES.MULTILEVEL_BUILD_NATIVE,
            TEST_TYPES.BASIC_DETERMINISTIC,
            TEST_TYPES.BASIC_LOWRAM,
        ]

    @staticmethod
    def from_string(s):
        """Look up an example by case-insensitive member name, e.g. ``basic_lowram``.

        Raises ValueError (a subclass of Exception) if `s` names no example.
        """
        wanted = str.lower(s)
        candidates = TEST_TYPES.examples()
        names = [str.lower(e.name) for e in candidates]
        for e, name in zip(candidates, names):
            if name == wanted:
                return e
        raise ValueError(f"Could not find example {s}. Examples: {names}")

    def __str__(self):
        return self.desc()

    def desc(self):
        """Human-readable description used in logs and the CI summary."""
        return {
            TEST_TYPES.FUNC: "Functional Test",
            TEST_TYPES.BENCH: "Benchmark",
            TEST_TYPES.BENCH_COMPONENTS: "Benchmark Components",
            TEST_TYPES.KAT: "Kat Test",
            TEST_TYPES.ACVP: "ACVP Test",
            TEST_TYPES.STACK: "Stack Usage Test",
            TEST_TYPES.BRING_YOUR_OWN_FIPS202: "Example (Bring-Your-Own-FIPS202)",
            TEST_TYPES.BRING_YOUR_OWN_FIPS202_STATIC: "Example (Bring-Your-Own-FIPS202, static)",
            TEST_TYPES.CUSTOM_BACKEND: "Example (Custom Backend)",
            TEST_TYPES.BASIC: "Example (mldsa-native as code package)",
            TEST_TYPES.BASIC_DETERMINISTIC: "Example (mldsa-native as code package without randombytes() implementation)",
            TEST_TYPES.BASIC_LOWRAM: "Example (mldsa-native with reduced RAM usage)",
            TEST_TYPES.MONOLITHIC_BUILD: "Example (monobuild)",
            TEST_TYPES.MONOLITHIC_BUILD_NATIVE: "Example (monobuild, native)",
            TEST_TYPES.MONOLITHIC_BUILD_MULTILEVEL: "Example (monobuild, multilevel)",
            TEST_TYPES.MONOLITHIC_BUILD_MULTILEVEL_NATIVE: "Example (monobuild, multilevel, native)",
            TEST_TYPES.MULTILEVEL_BUILD: "Example (multilevel build)",
            TEST_TYPES.MULTILEVEL_BUILD_NATIVE: "Example (multilevel build, native)",
            TEST_TYPES.SIZE: "Measurement Code Size",
            TEST_TYPES.UNIT: "Unit Test",
            TEST_TYPES.ALLOC: "Alloc Test",
            TEST_TYPES.RNG_FAIL: "RNG Failure Test",
        }.get(self)

    def make_dir(self):
        """Directory passed to ``make -C``; "" means no ``-C`` (current directory).

        Each example lives in ``examples/<lower-cased member name>``.
        """
        if self.is_example():
            return f"examples/{self.name.lower()}"
        return ""

    def make_target(self):
        """`make` build target.

        For tests and benchmarks this is the lower-cased member name
        (e.g. ``bench_components``). For examples it is "" (the default target).
        """
        if self.is_example():
            return ""
        return self.name.lower()

    def make_run_target(self, scheme):
        """`make` target that runs this test: ``run[_<target>][_<scheme suffix>]``."""
        t = self.make_target()
        if t == "":
            run_t = "run"
        else:
            run_t = f"run_{t}"
        if scheme is not None:
            return f"{run_t}_{scheme.suffix()}"
        else:
            return run_t


class Tests:
    """Wrapper around `make` for building and running tests, benchmarks,
    examples and proofs.

    Failures are collected via fail() and reported by check_fail(), which
    terminates the process with exit status 1 if any failure was recorded
    and 0 otherwise.
    """

    def __init__(self, args):
        config_logger(args.verbose)
        self.args = args
        self.failed = []

    def fail(self, info):
        """Record a failure; `info` is printed by check_fail()."""
        self.failed.append(info)

    def check_fail(self):
        """Print the recorded failures (if any) and exit. Never returns."""
        num_failed = len(self.failed)
        if num_failed > 0:
            print(f"{num_failed} tests FAILED")
            for info in self.failed:
                print(f"* {info}")
            sys.exit(1)
        print("All good!")
        sys.exit(0)

    def cmd_prefix(self):
        """Command prefix for running binaries: sudo, exec wrapper, macOS taskpolicy."""
        res = []
        if self.args.run_as_root is True:
            res += ["sudo"]
        if self.args.exec_wrapper:
            # split() without an argument ignores repeated/leading/trailing
            # spaces, which split(" ") would turn into empty arguments.
            res += self.args.exec_wrapper.split()
        if self.args.mac_taskpolicy is not None:
            res += ["taskpolicy", "-c", f"{self.args.mac_taskpolicy}"]

        return res

    def make_j(self):
        """`-jN` argument for make, or nothing when -j is unset or 1."""
        if self.args.j is None or int(self.args.j) == 1:
            return []
        return [f"-j{self.args.j}"]

    def do_opt_all(self):
        return self.args.opt.lower() == "all"

    def do_opt(self):
        return self.args.opt.lower() in ["all", "opt"]

    def do_no_opt(self):
        return self.args.opt.lower() in ["all", "no_opt"]

    def compile_mode(self):
        return "Cross" if self.args.cross_prefix != "" else "Native"

    @staticmethod
    def _opt_label(opt):
        """Build-mode suffix for messages: "" (opt is None), " opt" or " no_opt"."""
        if opt is None:
            return ""
        return " opt" if opt is True else " no_opt"

    def _for_each_opt(self, fn):
        """Call fn(opt) for each build mode selected by --opt (no_opt first, then opt)."""
        if self.do_no_opt():
            fn(False)
        if self.do_opt():
            fn(True)

    def _build_and_run(self, test_type, opt, suppress_output=True):
        """Compile `test_type` and, unless --no-run, run it for every scheme."""
        self._compile_schemes(test_type, opt)
        if self.args.run:
            self._run_schemes(test_type, opt, suppress_output=suppress_output)

    def _check_namespace(self, opt):
        """Run scripts/check-namespace and record a failure if it exits non-zero."""
        p = subprocess.run(
            ["python3", "check-namespace"],
            stdout=subprocess.DEVNULL if not self.args.verbose else None,
            cwd="scripts",
        )
        if p.returncode != 0:
            self.fail(f"Namespacing failed for opt={opt}")

    def _compile_schemes(self, test_type, opt):
        """Compile (or cross compile) `test_type` with a single `make` invocation.

        - opt: True/False selects OPT=1/0. None is used for examples, which are
          not given the OPT/AUTO make variables.

        CFLAGS, LDFLAGS and CROSS_PREFIX are passed through the environment.
        A non-zero exit status of `make` is recorded via fail().
        """
        opt_label = self._opt_label(opt)

        github_log(
            f"::group::compile {self.compile_mode()}{opt_label} {test_type.desc()}"
        )

        log = logger(test_type, "Compile", self.args.cross_prefix, opt)

        extra_make_args = []
        # Those options are not used in the examples
        if not test_type.is_example():
            extra_make_args += [f"OPT={int(opt)}", f"AUTO={int(self.args.auto)}"]
        if test_type.is_benchmark():
            extra_make_args += [f"CYCLES={self.args.cycles}"]
        if test_type.make_dir() != "":
            extra_make_args += ["-C", test_type.make_dir()]
        extra_make_args += self.make_j()

        target = test_type.make_target()
        args = ["make"] + ([target] if target != "" else []) + extra_make_args

        cflags = self.args.cflags or ""
        ldflags = self.args.ldflags or ""

        # Force static compilation for cross builds of the examples
        if test_type.is_example() and self.args.cross_prefix != "":
            cflags += " -static"

        # Add FIPS202 backend selection if specified and this is an OPT build
        if self.args.fips202_aarch64_backend != "auto" and opt is True:
            # Make sure we're forcing AArch64 architecture
            if " -DMLD_FORCE_AARCH64" not in cflags:
                cflags += " -DMLD_FORCE_AARCH64"

            # Enable native backend for FIPS202
            cflags += " -DMLD_CONFIG_USE_NATIVE_BACKEND_FIPS202"

            # Specify the backend file (the quotes are backslash-escaped,
            # presumably so they survive the shell invoked by make)
            cflags += f' -DMLD_CONFIG_FIPS202_BACKEND_FILE=\\"fips202/native/aarch64/{self.args.fips202_aarch64_backend}.h\\"'

        env_update = {}
        if cflags != "":
            env_update["CFLAGS"] = cflags
        if ldflags != "":
            env_update["LDFLAGS"] = ldflags
        if self.args.cross_prefix != "":
            env_update["CROSS_PREFIX"] = self.args.cross_prefix

        env = os.environ.copy()
        env.update(env_update)

        log.info(dict2str(env_update) + " ".join(args))

        p = subprocess.run(
            args,
            stdout=subprocess.DEVNULL if not self.args.verbose else None,
            env=env,
        )

        if p.returncode != 0:
            log.error(f"make failed: {p.returncode}")
            self.fail(f"Compilation for ({test_type}{opt_label})")

        github_log("::endgroup::")

    def _run_scheme(
        self,
        test_type,
        opt,
        scheme,
        suppress_output=True,
    ):
        """Run `test_type` through its `make` run target.

        Arguments:

        - opt: Build mode the binaries were compiled with (used for labels)
        - scheme: Scheme to test, or None for a single run covering all schemes
        - suppress_output: Indicate whether to suppress or print-and-return the output

        Returns True if the run failed (also recorded via fail()), False on
        success when suppress_output is True, and the captured stdout (str)
        on success otherwise.
        """
        opt_label = self._opt_label(opt)
        scheme_str = "All" if scheme is None else str(scheme)

        log = logger(test_type, scheme_str, self.args.cross_prefix, opt)

        args = ["make", test_type.make_run_target(scheme)]
        # Benchmarks and examples are not given -j
        if not test_type.is_benchmark() and not test_type.is_example():
            args += self.make_j()
        if test_type.make_dir() != "":
            args += ["-C", test_type.make_dir()]

        env_update = {}
        prefix = self.cmd_prefix()
        if prefix:
            env_update["EXEC_WRAPPER"] = " ".join(prefix)

        # Add stack analysis flags for stack tests
        if test_type == TEST_TYPES.STACK:
            stack_flags = []
            if getattr(self.args, "peak_only", False):
                stack_flags.append("--peak-only")
            if getattr(self.args, "dump_massif", False):
                stack_flags.append("--dump-massif")
            if stack_flags:
                env_update["STACK_ANALYSIS_FLAGS"] = " ".join(stack_flags)

        # Add ACVP version for ACVP tests
        if test_type == TEST_TYPES.ACVP and hasattr(self.args, "version"):
            env_update["ACVP_VERSION"] = self.args.version

        env = os.environ.copy()
        env.update(env_update)

        cmd_str = dict2str(env_update) + " ".join(args)
        log.info(cmd_str)

        p = subprocess.run(args, capture_output=True, universal_newlines=False, env=env)

        if p.returncode != 0:
            log.error(f"'{cmd_str}' failed with {p.returncode}")
            log.error(p.stderr.decode())
            self.fail(f"{test_type.desc()} ({scheme_str}{opt_label})")
            return True  # Failure
        if suppress_output is True:
            if self.args.verbose is True:
                log.info(p.stdout.decode())
            return False  # No failure
        result = p.stdout.decode()
        log.info(result)
        return result

    def _run_schemes(self, test_type, opt, suppress_output=True):
        """Run `test_type` once per scheme and add a row to the GitHub summary.

        - opt: Whether native backends should be enabled
        - suppress_output: Indicate whether to suppress or print-and-return the output

        Returns True if any scheme failed (False otherwise) when
        suppress_output is True. Otherwise returns {"opt"|"no_opt": {scheme:
        result}} with the per-scheme return values of _run_scheme().
        """
        k = "opt" if opt else "no_opt"

        github_log(f"::group::run {self.compile_mode()} {k} {test_type.desc()}")

        results = {
            k: {
                scheme: self._run_scheme(test_type, opt, scheme, suppress_output)
                for scheme in SCHEME
            }
        }

        title = f"## {self.compile_mode()} {k.capitalize()} Tests"
        github_summary(title, test_type.desc(), results[k])

        github_log("::endgroup::")

        if suppress_output is True:
            # In this case, we only gather success/failure booleans
            return any(r for rs in results.values() for r in rs.values())
        return results

    def func(self):
        """Functional tests; optionally followed by the namespace check."""

        def _func(opt):
            self._compile_schemes(TEST_TYPES.FUNC, opt)
            if self.args.check_namespace is True:
                self._check_namespace(opt)
            if self.args.run:
                self._run_schemes(TEST_TYPES.FUNC, opt)

        self._for_each_opt(_func)
        self.check_fail()

    def kat(self):
        """Known-answer tests."""
        self._for_each_opt(lambda opt: self._build_and_run(TEST_TYPES.KAT, opt))
        self.check_fail()

    def unit(self):
        """Unit tests."""
        self._for_each_opt(lambda opt: self._build_and_run(TEST_TYPES.UNIT, opt))
        self.check_fail()

    def alloc(self):
        """Allocation tests."""
        self._for_each_opt(lambda opt: self._build_and_run(TEST_TYPES.ALLOC, opt))
        self.check_fail()

    def rng_fail(self):
        """RNG failure tests."""
        self._for_each_opt(
            lambda opt: self._build_and_run(TEST_TYPES.RNG_FAIL, opt)
        )
        self.check_fail()

    def acvp(self):
        """ACVP tests; run once (scheme None) rather than per scheme."""

        def _acvp(opt):
            self._compile_schemes(TEST_TYPES.ACVP, opt)
            if self.args.run:
                self._run_scheme(TEST_TYPES.ACVP, opt, None)

        self._for_each_opt(_acvp)
        self.check_fail()

    def examples(self):
        """Build and (unless --no-run) run the selected examples.

        Failures are recorded via fail(); this method does not call check_fail().
        """
        if self.args.l is None:
            selected = TEST_TYPES.examples()
        else:
            selected = [TEST_TYPES.from_string(name) for name in self.args.l]

        # Filter out excluded examples
        excluded_names = getattr(self.args, "exclude_example", None)
        if excluded_names:
            excluded = [TEST_TYPES.from_string(ex) for ex in excluded_names]
            selected = [e for e in selected if e not in excluded]

        for e in selected:
            self._compile_schemes(e, None)
            # Respect --no-run like every other test type
            if self.args.run:
                self._run_scheme(e, None, None)

    @staticmethod
    def _parse_bench_output(output):
        """Parse the ``key=value`` lines of benchmark output into {key: int(value)}.

        Lines without '=' are ignored; keys are stripped of surrounding whitespace.
        """
        parsed = {}
        for line in output.splitlines():
            if "=" not in line:
                continue
            key, _, value = line.partition("=")
            parsed[key.strip()] = int(value)
        return parsed

    def bench(self):
        """Benchmarks; optionally write the results as JSON to --output."""
        output = self.args.output
        components = self.args.components

        if components is False:
            test_type = TEST_TYPES.BENCH
        else:
            test_type = TEST_TYPES.BENCH_COMPONENTS
            output = False

        # NOTE: We haven't yet decided how to output both opt/no-opt benchmark results
        resultss = None
        if self.do_opt_all():
            self._compile_schemes(test_type, False)
            if self.args.run:
                self._run_schemes(test_type, False, suppress_output=False)
            self._compile_schemes(test_type, True)
            if self.args.run:
                resultss = self._run_schemes(test_type, True, suppress_output=False)
        else:
            self._compile_schemes(test_type, self.do_opt())
            if self.args.run:
                resultss = self._run_schemes(
                    test_type, self.do_opt(), suppress_output=False
                )

        if resultss is None:
            self.check_fail()

        # NOTE: There will only be one item in resultss, as we haven't yet
        # decided how to write both opt/no-opt benchmark results
        for results in resultss.values():
            if output is None or components is not False:
                continue
            # A failed run yields True instead of its output. The failure is
            # already recorded, so do not write a partial/broken JSON file.
            if not all(isinstance(r, str) for r in results.values()):
                continue

            entries = []
            for scheme, r in results.items():
                scheme_str = str(scheme)
                # The output is expected to contain lines of the form
                # "<primitive> cycles (avg)=X" for keypair, sign and verify
                d = self._parse_bench_output(r)
                for primitive in ["keypair", "sign", "verify"]:
                    entries.append(
                        {
                            "name": f"{scheme_str} {primitive}",
                            "unit": "cycles",
                            "value": d[f"{primitive} cycles (avg)"],
                        }
                    )

            with open(output, "w") as f:
                f.write(json.dumps(entries))

        self.check_fail()

    def stack(self):
        """Stack usage analysis"""
        self._for_each_opt(
            lambda opt: self._build_and_run(
                TEST_TYPES.STACK, opt, suppress_output=False
            )
        )
        self.check_fail()

    def size(self):
        """Code size measurement."""
        self._for_each_opt(
            lambda opt: self._build_and_run(
                TEST_TYPES.SIZE, opt, suppress_output=False
            )
        )
        # Always report, so that compile/run failures set the exit status
        self.check_fail()

    def all(self):
        """Build and run every test type enabled on the command line, then examples."""
        examples = self.args.examples

        # Enabled test types, in the order they are built and run
        enabled = [
            test_type
            for flag, test_type in [
                (self.args.func, TEST_TYPES.FUNC),
                (self.args.kat, TEST_TYPES.KAT),
                (self.args.acvp, TEST_TYPES.ACVP),
                (self.args.stack, TEST_TYPES.STACK),
                (self.args.unit, TEST_TYPES.UNIT),
                (self.args.alloc, TEST_TYPES.ALLOC),
                (self.args.rng_fail, TEST_TYPES.RNG_FAIL),
            ]
            if flag is True
        ]

        def _all(opt):
            for test_type in enabled:
                self._compile_schemes(test_type, opt)

            if self.args.check_namespace is True:
                self._check_namespace(opt)

            if self.args.run is False:
                return

            for test_type in enabled:
                if test_type == TEST_TYPES.ACVP:
                    # ACVP is a single run with scheme None ("All")
                    self._run_scheme(test_type, opt, None)
                elif test_type == TEST_TYPES.STACK:
                    self._run_schemes(test_type, opt, suppress_output=False)
                else:
                    self._run_schemes(test_type, opt)

        self._for_each_opt(_all)

        if examples is True:
            self.examples()

        self.check_fail()

    def cbmc(self):
        """Run the CBMC proofs for one or all ML-DSA parameter sets."""

        def list_proofs():
            cmd_str = ["./proofs/cbmc/list_proofs.sh"]
            p = subprocess.run(cmd_str, capture_output=True, universal_newlines=False)
            return [s for s in p.stdout.decode().split("\n") if s.strip() != ""]

        if self.args.list_functions:
            for p in list_proofs():
                print(p)
            sys.exit(0)

        def abort_if_requested(log):
            """Exit immediately if -f/--fail-upon-error was given."""
            if self.args.fail_upon_error:
                log.error("Aborting proofs, as requested by -f/--fail-upon-error")
                sys.exit(1)

        def run_cbmc_single_step(mldsa_parameter_set, proofs):
            """Run each proof in its own run-cbmc-proofs.py invocation, with a timeout."""
            envvars = {"MLD_CONFIG_PARAMETER_SET": mldsa_parameter_set}
            scheme = SCHEME.from_mode(mldsa_parameter_set)
            num_proofs = len(proofs)
            for i, func in enumerate(proofs):
                log = logger(f"CBMC ({i+1}/{num_proofs})", scheme, None, None)
                log.info(f"Starting CBMC proof for {func}")
                start = time.time()
                try:
                    p = subprocess.run(
                        [
                            "python3",
                            "run-cbmc-proofs.py",
                            "--summarize",
                            "--no-coverage",
                            "--per-proof-timeout",
                            str(self.args.per_proof_timeout),
                            "-p",
                            func,
                        ]
                        + self.make_j(),
                        cwd="proofs/cbmc",
                        env=os.environ.copy() | envvars,
                        timeout=self.args.timeout,
                        capture_output=(self.args.verbose is False),
                    )
                except subprocess.TimeoutExpired as e:
                    log.error(f"   TIMEOUT (after {self.args.timeout}s)")
                    # stderr is None when output was not captured (verbose mode)
                    if e.stderr:
                        log.error(e.stderr.decode())
                    self.fail(f"CBMC proof for {func}")
                    abort_if_requested(log)
                    continue

                dur = int(time.time() - start)
                if p.returncode != 0:
                    log.error(f"   FAILED (after {dur}s)")
                    if p.stderr is not None:
                        log.error(p.stderr.decode())
                    self.fail(f"CBMC proof for {func}")
                    abort_if_requested(log)
                else:
                    log.info(f"   SUCCESS (after {dur}s)")

        def run_cbmc(mldsa_parameter_set):
            """Select the proofs to run (--start-with / --proof) and run them."""
            log = logger("CBMC", mldsa_parameter_set, None, None)
            all_proofs = list_proofs()
            proofs = all_proofs
            if self.args.start_with is not None:
                try:
                    idx = proofs.index(self.args.start_with)
                    proofs = proofs[idx:]
                except ValueError:
                    log.error(
                        f"Could not find function {self.args.start_with}. Running all proofs"
                    )
            if self.args.proof is not None:
                proofs = []
                for pat in self.args.proof:
                    # Replace wildcards by regexp wildcards
                    regex = pat.replace("*", ".*")
                    proofs += [x for x in all_proofs if re.match(regex, x)]
                proofs = sorted(set(proofs))

            if self.args.single_step:
                run_cbmc_single_step(mldsa_parameter_set, proofs)
                return
            envvars = {"MLD_CONFIG_PARAMETER_SET": mldsa_parameter_set}
            cmd = (
                [
                    "python3",
                    "run-cbmc-proofs.py",
                    "--summarize",
                    "--no-coverage",
                    "--per-proof-timeout",
                    str(self.args.per_proof_timeout),
                    "-p",
                ]
                + proofs
                + self.make_j()
            )
            if self.args.output_result_json:
                cmd.extend(["--output-result-json", self.args.output_result_json])
            p = subprocess.run(
                cmd,
                cwd="proofs/cbmc",
                env=os.environ.copy() | envvars,
            )

            if p.returncode != 0:
                self.fail(f"CBMC proofs for parameter set={mldsa_parameter_set}")

        mldsa_parameter_set = self.args.mldsa_parameter_set
        if mldsa_parameter_set == "ALL":
            for parameter_set in ["44", "65", "87"]:
                run_cbmc(parameter_set)
        else:
            run_cbmc(mldsa_parameter_set)

        self.check_fail()

    def hol_light(self):
        """Run the HOL-Light proofs for the host architecture (x86_64 only for now)."""

        if platform.machine().lower() in ["arm64", "aarch64"]:
            # TODO: Skip HOL-Light proofs on arm64/aarch64 until the first proof is added
            # arch = "aarch64"
            return
        elif platform.machine().lower() in ["x86_64"]:
            arch = "x86_64"
        else:
            return

        proof_dir = "proofs/hol_light/" + arch

        def list_proofs():
            cmd_str = ["./" + proof_dir + "/list_proofs.sh"]
            p = subprocess.run(cmd_str, capture_output=True, universal_newlines=False)
            return [s for s in p.stdout.decode().split("\n") if s.strip() != ""]

        if self.args.list_functions:
            for p in list_proofs():
                print(p)
            sys.exit(0)

        def run_hol_light_single_step(proofs):
            num_proofs = len(proofs)
            for i, func in enumerate(proofs):
                log = logger(f"HOL_LIGHT ({i+1}/{num_proofs})", None, None, None)
                log.info(f"Starting HOL-Light proof for {func}")
                start = time.time()
                proof_bin = f"mldsa/{func}.native"
                proof_target = f"mldsa/{func}.correct"
                # Remove intermediate proof files to force-rerun. Each file is
                # removed independently so a missing .native does not leave a
                # stale .correct behind.
                for stale in (proof_bin, proof_target):
                    try:
                        os.remove(os.path.join(proof_dir, stale))
                    except FileNotFoundError:
                        pass
                p = subprocess.run(
                    ["make", proof_target] + self.make_j(),
                    cwd=proof_dir,
                    env=os.environ.copy(),
                    capture_output=(self.args.verbose is False),
                )

                dur = int(time.time() - start)
                if p.returncode != 0:
                    log.error(f"   FAILED (after {dur}s)")
                    if p.stderr is not None:
                        log.error(p.stderr.decode())
                    self.fail(f"HOL-Light proof for {func}")
                else:
                    log.info(f"   SUCCESS (after {dur}s)")

        proofs = list_proofs()
        if self.args.proof is not None:
            proofs = self.args.proof

        run_hol_light_single_step(proofs)
        self.check_fail()


#
# Command line interface
#


def _add_toggle(parser, name, help_on, help_off, default=True):
    """Add a mutually exclusive ``--<name>`` / ``--no-<name>`` flag pair.

    Both flags write to the destination ``name`` with dashes replaced by
    underscores. ``--<name>`` stores True, ``--no-<name>`` stores False, and
    ``default`` applies when neither flag is given.
    """
    dest = name.replace("-", "_")
    group = parser.add_mutually_exclusive_group()
    # The store_true action is registered first so its `default` is the one
    # argparse applies to the shared destination.
    group.add_argument(
        f"--{name}", action="store_true", dest=dest, help=help_on, default=default
    )
    group.add_argument(f"--no-{name}", action="store_false", dest=dest, help=help_off)


def cli():
    """Parse the command line and dispatch to the matching ``Tests`` method.

    All sub-commands share the options defined on ``common_parser``. Before
    the selected command runs, the working directory is changed to the parent
    of this script's directory. If no sub-command is given, the help text is
    printed and the process exits with status 1.
    """
    # Example names accepted by both `all --exclude-example` and `examples -l`.
    example_choices = [
        "bring_your_own_fips202",
        "bring_your_own_fips202_static",
        "custom_backend",
        "basic",
        "basic_deterministic",
        "basic_lowram",
        "monolithic_build",
        "monolithic_build_native",
        "monolithic_build_multilevel",
        "monolithic_build_multilevel_native",
        "multilevel_build",
        "multilevel_build_native",
    ]

    common_parser = argparse.ArgumentParser(add_help=False)

    # Common arguments for all sub-commands
    common_parser.add_argument(
        "-v", "--verbose", help="Show verbose output or not", action="store_true"
    )
    common_parser.add_argument(
        "-cp", "--cross-prefix", help="Cross prefix for compilation", default=""
    )
    common_parser.add_argument(
        "--cflags", help="Extra cflags to be passed in (e.g. '-mcpu=cortex-a72')"
    )
    common_parser.add_argument(
        "--ldflags", help="Extra ldflags to be passed in (e.g. '-static')"
    )
    common_parser.add_argument(
        "-j",
        help="Number of jobs to be used for `make` invocations",
        default=os.cpu_count(),
    )

    # --auto / --no-auto
    _add_toggle(
        common_parser,
        "auto",
        "Allow makefile to auto configure system specific preprocessor",
        "Disallow makefile to auto configure system specific preprocessor",
    )

    common_parser.add_argument(
        "--opt",
        help="Determine whether to compile/run the opt/no_opt binary or both",
        choices=["ALL", "OPT", "NO_OPT"],
        type=str.upper,
        default="ALL",
    )

    common_parser.add_argument(
        "--fips202-aarch64-backend",
        help="Select FIPS202 AArch64 backend",
        choices=[
            "auto",
            "x1_scalar",
            "x1_v84a",
            "x2_v84a",
            "x4_v8a_scalar",
            "x4_v8a_v84a_scalar",
        ],
        default="auto",
        type=str,
    )

    # --run / --no-run
    _add_toggle(common_parser, "run", "Run the binaries", "Do not run the binaries")

    common_parser.add_argument(
        "-w", "--exec-wrapper", help="Run the binary with the user-customized wrapper"
    )
    common_parser.add_argument(
        "-r",
        "--run-as-root",
        default=False,
        action="store_true",
        help="Run the binary as root",
    )

    main_parser = argparse.ArgumentParser()

    cmd_subparsers = main_parser.add_subparsers(title="Commands", dest="cmd")

    # all arguments
    all_parser = cmd_subparsers.add_parser(
        "all", help="Run all tests (except benchmark for now)", parents=[common_parser]
    )

    all_parser.add_argument(
        "--check-namespace",
        help="Check namespacing of binaries",
        action="store_true",
        default=False,
    )

    _add_toggle(all_parser, "func", "Run func test", "Do not run func test")
    _add_toggle(all_parser, "kat", "Run kat test", "Do not run kat test")
    _add_toggle(all_parser, "acvp", "Run acvp test", "Do not run acvp test")
    _add_toggle(all_parser, "unit", "Run unit test", "Do not run unit test")
    _add_toggle(all_parser, "examples", "Run examples", "Do not run examples")

    all_parser.add_argument(
        "--exclude-example",
        help="Exclude specific examples from running (can be used multiple times)",
        choices=example_choices,
        action="append",
        default=[],
    )

    _add_toggle(
        all_parser,
        "stack",
        "Run stack analysis",
        "Do not run stack analysis",
        default=False,
    )
    _add_toggle(all_parser, "alloc", "Run alloc tests", "Do not run alloc tests")
    _add_toggle(
        all_parser,
        "rng-fail",
        "Run RNG failure tests",
        "Do not run RNG failure tests",
    )

    # acvp arguments
    acvp_parser = cmd_subparsers.add_parser(
        "acvp", help="Run ACVP client", parents=[common_parser]
    )
    acvp_parser.add_argument(
        "--version",
        default="v1.1.0.41",
        help="ACVP test vector version (default: v1.1.0.41)",
    )

    # examples arguments
    examples_parser = cmd_subparsers.add_parser(
        "examples", help="Run examples", parents=[common_parser]
    )

    examples_parser.add_argument(
        "-l",
        help="Explicitly list the examples to run; can be called multiple times",
        choices=example_choices,
        action="append",
    )

    # bench arguments
    bench_parser = cmd_subparsers.add_parser(
        "bench",
        help="Run the benchmarks for all parameter sets",
        parents=[common_parser],
    )

    bench_parser.add_argument(
        "-c",
        "--cycles",
        help="Method for counting clock cycles. PMU requires (user-space) access to the Arm Performance Monitor Unit (PMU). PERF requires a kernel with perf support. MAC works on some Apple platforms, at least Apple M1.",
        choices=["NO", "PMU", "PERF", "MAC"],
        type=str.upper,
        required=True,
    )
    bench_parser.add_argument(
        "-o", "--output", help="Path to output file in json format"
    )
    if platform.system() == "Darwin":
        bench_parser.add_argument(
            "-t",
            "--mac-taskpolicy",
            help="Run the program using the specified QoS clamp. Applies to MacOS only. Setting this flag to 'background' guarantees running on E-cores. This is an abbreviation of --exec-wrapper 'taskpolicy -c {mac_taskpolicy}'.",
            choices=["utility", "background", "maintenance"],
            type=str.lower,
        )
    bench_parser.add_argument(
        "--components",
        help="Benchmark low-level components",
        action="store_true",
        default=False,
    )

    # size arguments
    cmd_subparsers.add_parser(
        "size",
        help="Run the code size measurement for all object files",
        parents=[common_parser],
    )

    # cbmc arguments
    cbmc_parser = cmd_subparsers.add_parser(
        "cbmc",
        help="Run the CBMC proofs for all parameter sets",
        parents=[common_parser],
    )

    cbmc_parser.add_argument(
        "-kl",
        "--mldsa-parameter-set",
        help="MLDSA parameter set (MLD_CONFIG_PARAMETER_SET)",
        choices=["44", "65", "87", "ALL"],
        type=str.upper,
        default="ALL",
    )

    cbmc_parser.add_argument(
        "--single-step",
        help="Run one proof at a time. This is useful for debugging",
        action="store_true",
        default=False,
    )

    cbmc_parser.add_argument(
        "--start-with",
        help="When --single-step is set, start with given proof and proceed in alphabetical order",
        default=None,
    )

    cbmc_parser.add_argument(
        "-p",
        "--proof",
        nargs="+",
        help='Space separated list of functions for which to run the CBMC proofs. Wildcard patterns "*" are allowed.',
        default=None,
    )

    cbmc_parser.add_argument(
        "--timeout",
        help="Timeout for individual CBMC proofs, in seconds",
        type=int,
        default=3600,
    )

    cbmc_parser.add_argument(
        "--per-proof-timeout",
        help="Timeout for each individual CBMC proof passed to run-cbmc-proofs.py, in seconds (default: 1800)",
        type=int,
        default=1800,
    )

    cbmc_parser.add_argument(
        "-f",
        "--fail-upon-error",
        help="Stop upon first CBMC proof failure",
        action="store_true",
        default=False,
    )

    cbmc_parser.add_argument(
        "-l",
        "--list-functions",
        help="Don't run any proofs, but list all functions for which CBMC proofs are available",
        action="store_true",
        default=False,
    )

    cbmc_parser.add_argument(
        "--output-result-json",
        help="Path to export result JSON",
        default=None,
    )

    # hol_light arguments
    hol_light_parser = cmd_subparsers.add_parser(
        "hol_light",
        help="Run the HOL_LIGHT proofs for all parameter sets",
        parents=[common_parser],
    )

    hol_light_parser.add_argument(
        "-p",
        "--proof",
        nargs="+",
        help="Space separated list of functions for which to run the HOL_LIGHT proofs.",
        default=None,
    )

    hol_light_parser.add_argument(
        "-l",
        "--list-functions",
        help="Don't run any proofs, but list all functions for which HOL_LIGHT proofs are available",
        action="store_true",
        default=False,
    )

    # func arguments
    func_parser = cmd_subparsers.add_parser(
        "func",
        help="Run the functional tests for all parameter sets",
        parents=[common_parser],
    )
    func_parser.add_argument(
        "--check-namespace",
        help="Check namespacing of binaries",
        action="store_true",
        default=False,
    )

    # kat arguments
    cmd_subparsers.add_parser(
        "kat", help="Run the kat tests for all parameter sets", parents=[common_parser]
    )

    # unit arguments
    cmd_subparsers.add_parser(
        "unit",
        help="Run the unit tests for all parameter sets",
        parents=[common_parser],
    )

    # stack arguments
    stack_parser = cmd_subparsers.add_parser(
        "stack",
        help="Analyze stack usage for all parameter sets",
        parents=[common_parser],
    )
    stack_parser.add_argument(
        "--peak-only",
        action="store_true",
        help="Show only runtime peak stack usage (skip per-function analysis)",
        default=False,
    )
    stack_parser.add_argument(
        "--dump-massif",
        action="store_true",
        help="Dump full massif log for debugging",
        default=False,
    )

    # alloc arguments
    cmd_subparsers.add_parser(
        "alloc",
        help="Run the alloc tests for all parameter sets",
        parents=[common_parser],
    )

    # rng_fail arguments
    cmd_subparsers.add_parser(
        "rng_fail",
        help="Run the RNG failure tests for all parameter sets",
        parents=[common_parser],
    )

    args = main_parser.parse_args()

    if args.cmd is None:
        # No sub-command given: show usage instead of silently doing nothing.
        main_parser.print_help()
        sys.exit(1)

    # These options are only registered for some sub-commands/platforms;
    # give them a default so they can be read unconditionally.
    if not hasattr(args, "mac_taskpolicy"):
        args.mac_taskpolicy = None
    if not hasattr(args, "l"):
        args.l = None

    # Run from the parent directory of this script's directory.
    os.chdir(os.path.join(os.path.dirname(__file__), ".."))

    # Every sub-command name is identical to the `Tests` method implementing it,
    # and argparse guarantees `args.cmd` is one of the registered names.
    getattr(Tests(args), args.cmd)()


if __name__ == "__main__":
    cli()
