#!/usr/bin/env python3
# Copyright (c) The mlkem-native project authors
# Copyright (c) The mldsa-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

import subprocess
import tempfile
import platform
import argparse
import shutil
import pathlib
import re
import sys
import threading
import pyparsing as pp
import os
import yaml
import time

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from rich.console import Console
from rich.progress import (
    Progress,
    BarColumn,
    TextColumn,
    TaskProgressColumn,
    TimeElapsedColumn,
)

console = Console()

# Global progress bar - initialized in _main()
_progress = None
_main_task = None
_current_task = ""

modulus = 8380417
root_of_unity = 1753
montgomery_factor = pow(2, 32, modulus)
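# Background: modulus is the ML-DSA prime q = 8380417 = 2^23 - 2^13 + 1,
# root_of_unity = 1753 is a primitive 512th root of unity modulo q (FIPS 204),
# and montgomery_factor is R = 2^32 mod q as used in Montgomery multiplication.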

# Compiled regex patterns
_RE_DEFINED = re.compile(r"defined\(([^)]+)\)")
_RE_MARKDOWN_CITE = re.compile(r"\[\^(?P<id>\w+)\]")
_RE_C_CITE = re.compile(r"@\[(?P<id>\w+)")
_RE_BYTECODE_START = re.compile(
    r"=== bytecode start: (?:aarch64|x86_64)/mldsa/([^/\s]+?)\.o"
)
_RE_FUNC_SYMBOL = re.compile(r"MLD_ASM_FN_SYMBOL\((.*)\)")
_RE_MACRO_CHECK = re.compile(r"[^_]((?:MLD_|MLDSA_)\w+)(.*)$", re.M)
_RE_MLKEM_MACRO_CHECK = re.compile(r"[^_]((?:MLK_|MLKEM_)\w+)(.*)$", re.M)
_RE_DEFINE = re.compile(r"^\s*#define\s+(\w+)")
_RE_ARGS_COMMENT = re.compile(r"(.*?)(\s*//.*)?$")
_RE_MACRO_DEF = re.compile(r"^\s*\.macro\s+(\w+)")
_RE_MACRO_DEF_ARGS = re.compile(r"^(\s*\.macro\s+\w+)(\s+.*)$")
_RE_LEADING_SPACE = re.compile(r"^(\s*)")

# File cache: {filename: {"content": str, "original": str, "force_format": bool}}
# Caches file contents while they are being prepared/modified, avoiding
# repeated reads/writes to the file system.
_file_cache = {}
_file_cache_lock = threading.Lock()

_errors = []
_errors_lock = threading.Lock()

_progress_lock = threading.Lock()


def read_file(filename, original=False):
    """Read file content, using cache if available"""
    with _file_cache_lock:
        if filename in _file_cache:
            key = "content" if original is False else "original"
            return _file_cache[filename][key]

        with open(filename, "r") as f:
            content = f.read()
        _file_cache[filename] = {
            "content": content,
            "original": content,
            "force_format": False,
        }
        return content


def update_file(filename, content, force_format=False):
    """Write file content to cache"""
    with _file_cache_lock:
        if filename not in _file_cache:
            try:
                with open(filename, "r") as f:
                    original = f.read()
                _file_cache[filename] = {"original": original}
            except FileNotFoundError:
                _file_cache[filename] = {"original": None}

        e = _file_cache[filename]
        e["content"] = content
        e["force_format"] = e.get("force_format", False) or force_format


def finalize_format_batch(batch):
    """Format a batch of files by passing to clang-format with -i flag"""
    if not batch:
        return

    # Create temp files for each filename in batch
    temp_files = []
    try:
        for filename in batch:
            content = read_file(filename)
            # Skip files scheduled for deletion
            if content is None:
                continue
            with tempfile.NamedTemporaryFile(mode="w", suffix=".c", delete=False) as f:
                f.write(content)
                temp_files.append((f.name, filename))

        # Call clang-format with -i to update files in-place
        clang_format_file = os.path.join(
            os.path.dirname(__file__), "..", ".clang-format"
        )
        p = subprocess.run(
            ["clang-format", "-i", f"-style=file:{clang_format_file}"]
            + [t[0] for t in temp_files],
            capture_output=True,
            text=True,
        )
        if p.returncode != 0:
            print(p.stderr)
            print(
                f"Failed to auto-format autogenerated code (clang-format return code {p.returncode}). Are you running in a nix shell? See CONTRIBUTING.md."
            )
            exit(1)

        # Read formatted files back and update cache
        for temp_path, filename in temp_files:
            with open(temp_path, "r") as f:
                update_file(filename, f.read())
    finally:
        for temp_path, _ in temp_files:
            os.unlink(temp_path)


def finalize_file(item, dry_run):
    """Write a single file or delete it if content is None"""
    filename, data = item

    content_old = data["original"]
    content_new = data["content"]

    if content_old == content_new:
        return

    # Handle deletion (content_new is None)
    if content_new is None:
        if dry_run is False:
            file_updated(filename, removed=True)
            os.remove(filename)
        else:
            error(filename, None)
        return

    if dry_run is False:
        file_updated(filename)
        with open(filename, "w") as f:
            f.write(content_new)
    else:
        filename_new = f"{filename}.new"
        with open(filename_new, "w") as f:
            f.write(content_new)
        error(filename, filename_new)


def format_files(dry_run):
    """Apply formatting to files"""
    to_format = [
        filename
        for filename, data in _file_cache.items()
        if data["force_format"] or filename.endswith((".c", ".h", ".i"))
    ]

    # Group into batches of max 20
    batch_size = 20
    batches = [
        to_format[i : i + batch_size] for i in range(0, len(to_format), batch_size)
    ]

    run_parallel(batches, finalize_format_batch)


def finalize(dry_run):
    """Write dirty files to filesystem"""
    run_parallel(_file_cache.items(), partial(finalize_file, dry_run=dry_run))


# This file re-generates auto-generated source files in mldsa-native.
#
# It currently covers:
# - zeta values for the reference NTT and invNTT
# - lookup tables used for fast rejection sampling
# - source files for monolithic single-CU build
# - simplified assembly sources
# - header guards
# - #undef's for CU-local macros

_step_start_time = time.time()


def high_level_task(msg):
    """Set the current high-level task description"""
    global _current_task
    _current_task = msg
    if _progress:
        _progress.update(_main_task, description=f"[cyan]{msg}[/]")


def high_level_status(msg, skipped=False):
    """Complete a high-level step and print status"""
    global _step_start_time
    elapsed = time.time() - _step_start_time
    if skipped:
        symbol = "[dim]–[/dim]"
    else:
        symbol = "[green]✓[/green]"
    if _progress:
        _progress.print(f"{symbol} {msg} ({elapsed:.1f}s)", highlight=False)
        _progress.advance(_main_task)
    else:
        console.print(f"{symbol} {msg} ({elapsed:.1f}s)", highlight=False)
    _step_start_time = time.time()


def run_parallel(files, func):
    """Run func over files in parallel with progress tracking"""
    if not files:
        return []

    files = list(files)
    total = len(files)
    state = {"completed": 0, "last_file": ""}

    def update_progress():
        if _progress and total > 0:
            suffix = (
                f" {os.path.basename(state['last_file'])}" if state["last_file"] else ""
            )
            _progress.update(
                _main_task,
                description=f"[cyan]{_current_task}[/] [dim][{state['completed']}/{total}]{suffix}[/]",
            )

    def wrapped(f):
        result = func(f)
        with _progress_lock:
            state["completed"] += 1
            state["last_file"] = str(f[0]) if isinstance(f, tuple) else str(f)
            update_progress()
        return result

    with ThreadPoolExecutor() as executor:
        return list(executor.map(wrapped, files))


def error(filename, filename_new):
    with _errors_lock:
        _errors.append((filename, filename_new))


def print_check_errors():
    for filename, filename_new in _errors:
        console.print(f"[red]error[/] {filename}")
        if filename_new is not None:
            console.print(
                f"Autogenerated file {filename} needs updating. Have you called scripts/autogen? Wrote new version to {filename_new}."
            )
            if os.path.exists(filename):
                subprocess.run(["diff", filename, filename_new])
        else:
            console.print(
                f"Autogenerated file {filename} needs removing. Have you called scripts/autogen?"
            )

    return len(_errors) == 0


def file_updated(filename, removed=False):
    if removed is False:
        console.print(f"[bold]updated {filename}[/]")
    else:
        console.print(f"[bold]removed {filename}[/]")


def gen_autogen_warning():
    yield ""
    yield "/*"
    yield " * WARNING: This file is auto-generated from scripts/autogen"
    yield " *          in the mldsa-native repository."
    yield " *          Do not modify it directly."
    yield " */"


def gen_header():
    yield "/*"
    yield " * Copyright (c) The mldsa-native project authors"
    yield " * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT"
    yield " */"
    yield from gen_autogen_warning()
    yield ""


def gen_hol_light_header():
    yield "(*"
    yield " * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved."
    yield " * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0"
    yield " *)"
    yield ""
    yield "(*"
    yield " * WARNING: This file is auto-generated from scripts/autogen"
    yield " *          in the mldsa-native repository."
    yield " *          Do not modify it directly."
    yield " *)"
    yield ""


def gen_yaml_header():
    yield "# Copyright (c) The mldsa-native project authors"
    yield "# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT"
    yield ""


def format_content(content):
    clang_format_file = os.path.join(os.path.dirname(__file__), "..", ".clang-format")
    p = subprocess.run(
        ["clang-format", f"-style=file:{clang_format_file}"],
        capture_output=True,
        input=content,
        text=True,
    )
    if p.returncode != 0:
        print(p.stderr)
        print(
            f"Failed to auto-format autogenerated code (clang-format return code {p.returncode}). Are you running in a nix shell? See CONTRIBUTING.md."
        )
        exit(1)
    return p.stdout


class CondParser:
    """Rudimentary parser for expressions in `#if .. #else ..` directives"""

    def __init__(self):
        c_identifier = pp.common.identifier()
        c_integer_suffix = pp.one_of("U L LU UL LL ULL LLU", caseless=True)
        c_dec_integer = pp.Combine(
            pp.Optional(pp.one_of("+ -"))
            + pp.Word(pp.nums)
            + pp.Optional(c_integer_suffix)
        )
        c_hex_integer = pp.Combine(
            pp.Literal("0x") + pp.Word(pp.hexnums) + pp.Optional(c_integer_suffix)
        )

        self.parser = pp.infix_notation(
            c_identifier | c_hex_integer | c_dec_integer,
            [
                (pp.one_of("!"), 1, pp.opAssoc.RIGHT),
                (pp.one_of("!= == <= >= > <"), 2, pp.opAssoc.LEFT),
                (pp.one_of("&&"), 2, pp.opAssoc.LEFT),
                (pp.one_of("||"), 2, pp.opAssoc.LEFT),
            ],
        )
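
        # Illustrative example: parse_condition("A && B || !C") returns the
        # nested list [["A", "&&", "B"], "||", ["!", "C"]], which the static
        # helpers below traverse and simplify.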

    @staticmethod
    def connective(res):
        """Extract the top-level connective for the expression"""
        if not isinstance(res, list):
            return None
        elif len(res) == 2:
            # Unary operator (will be "!" in our case)
            return res[0]
        else:
            # Binary operator
            return res[1]

    @staticmethod
    def map_top(f, res):
        """Apply function to arguments of top-level connective"""
        if not isinstance(res, list):
            return res
        else:
            # We expect `f` to do nothing on strings, so it is safe
            # to apply it everywhere, including the connectives.
            return list(map(f, res))

    @staticmethod
    def args(res):
        """Assuming the argument is a binary operation, return all arguments"""
        return res[::2]

    @staticmethod
    def simplify_double_negation(res):
        """Cancel double negations"""
        if CondParser.connective(res) == "!" and CondParser.connective(res[1]) == "!":
            res = res[1][1]
        res = CondParser.map_top(CondParser.simplify_double_negation, res)
        return res

    @staticmethod
    def simplify_not_eq(res):
        """Replace !(x == y) by x != y, and !(x != y) by x == y"""
        if CondParser.connective(res) == "!" and CondParser.connective(res[1]) == "==":
            res = res[1]
            res[1] = "!="
        if CondParser.connective(res) == "!" and CondParser.connective(res[1]) == "!=":
            res = res[1]
            res[1] = "=="
        res = CondParser.map_top(CondParser.simplify_not_eq, res)
        return res

    @staticmethod
    def simplify_neq_chain(res):
        """Check for &&-chains of inequalities followed by an equality
        that implies those inequalities. This catches patterns like
        ```
         #if MLKEM_K == 2
         ...
         #elif MLKEM_K == 3
         ...
         #elif MLKEM_K == 4
         ...
         #endif
        ```
        where the #endif condition `MLKEM_K != 2 && MLKEM_K != 3 &&
        MLKEM_K == 4` simplifies to just `MLKEM_K == 4`.
        """
        if (
            CondParser.connective(res) == "&&"
            and CondParser.connective(res[-1]) == "=="
        ):
            lhs = res[-1][0]
            rhs = res[-1][2]
            args = []
            for a in CondParser.args(res[:-1]):
                if CondParser.connective(a) == "!=" and a[0] == lhs:
                    args.append(a[2])
                else:
                    args = None
                    break
            if args is None:
                return res
            # Check if all args are numerical and different
            if rhs.isdigit() and all(
                map(lambda a: a.isdigit() and int(a) != int(rhs), args)
            ):
                # Success -- just drop all but the final condition
                return res[-1]
        res = CondParser.map_top(CondParser.simplify_neq_chain, res)
        return res

    @staticmethod
    def print_exp(exp, inner=False):
        conn = CondParser.connective(exp)
        if conn is None:
            return exp
        elif conn == "!":
            res = f"!{CondParser.print_exp(exp[1], inner=True)}"
        else:
            padded_conn = f" {conn} "
            res = padded_conn.join(
                map(lambda e: CondParser.print_exp(e, inner=True), CondParser.args(exp))
            )
        if inner is True and conn in ["&&", "||"]:
            res = f"({res})"
        return res

    @staticmethod
    def simplify_assoc(exp):
        """Check for unnecessary bracketing and remove it"""
        conn = CondParser.connective(exp)
        if conn in ["&&", "||"]:
            args = CondParser.args(exp)
            new_args = []
            for a in args:
                if CondParser.connective(a) == conn:
                    new_args += CondParser.args(a)
                else:
                    new_args.append(a)
            exp = [x for y in map(lambda x: [x, conn], new_args) for x in y][:-1]
        exp = CondParser.map_top(CondParser.simplify_assoc, exp)
        return exp

    @staticmethod
    def simplify_all(exp):
        exp = CondParser.simplify_double_negation(exp)
        exp = CondParser.simplify_not_eq(exp)
        exp = CondParser.simplify_neq_chain(exp)
        exp = CondParser.simplify_assoc(exp)
        return exp

    def parse_condition(self, exp, simplify=True):
        try:
            exp = self.parser.parseString(exp, parseAll=True).as_list()[0]
        except pp.ParseException:
            print(f"WARNING: Ignoring condition '{exp}' I cannot parse")
            return exp
        if simplify is True:
            exp = CondParser.simplify_all(exp)
        return exp

    def normalize_condition(self, exp):
        return CondParser.print_exp(self.parse_condition(exp))


def adjust_preprocessor_comments_for_filename(
    content, source_file, parser, show_status=False
):
    """Automatically add comments to large `#if ... #else ... #endif`
    blocks indicating the guarding conditions.

    For example, a block

    ```c
      #if FOO
      ...
      #else
      ...
      #endif
    ```

    will be transformed into


    ```c
      #if FOO
      ...
      #else /* FOO */
      ...
      #endif /* !FOO */
    ```

    except when the distance between the preprocessor directives is
    very short, and the annotations would be more harmful than useful.
    """

    content = content.split("\n")
    new_content = []

    # Stack of `#if` statements. Every entry is a tuple
    # `(conds, line_no, if_or_else, force_print)`, where
    # - `conds` is the list of conditions being tested.
    #   In a plain `#if ... #else ...` block, this is a singleton list
    #   containing the condition being tested. In a chain of
    #   `#if .. #elif ..` it contains all conditions encountered so far.
    # - `line_no` is the line where it started
    # - `if_or_else` indicates whether we are in the `#if`
    #   or the `#else` branch (if present)
    # - `force_print` forces the closing comment to be printed even for
    #   short blocks
    if_stack = []

    def merge_escaped_lines(l, i):
        while l.endswith("\\"):
            l = l.removesuffix("\\").rstrip() + content[i + 1].lstrip()
            i = i + 1
        return (l, i)

    def merge_commented_lines(l, i):
        # Not very robust, but good enough
        if "/*" not in l or "*/" in l:
            return (l, i)
        i += 1
        while "*/" not in content[i]:
            l += content[i]
            i += 1

        l += content[i]
        return (l, i)

    def should_print(cur_line_no, conds, line_no, force_print):
        line_threshold = 5
        if force_print is True:
            return True

        if cur_line_no - line_no >= line_threshold:
            return True
        return False

    def format_condition(cond):
        cond = _RE_DEFINED.sub(r"\1", cond)
        return parser.normalize_condition(cond)

    def format_conditions(conds, branch):
        prev_conds = list(map(lambda s: f"!({s})", conds[:-1]))
        final_cond = conds[-1]
        if branch is False:
            final_cond = f"!({final_cond})"
        full_cond = "&&".join(prev_conds + [final_cond])
        return format_condition(full_cond)
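
    # Example: for `#if A ... #elif B ... #else ... #endif`, the #else line
    # gets format_conditions(["A", "B"], True), i.e. the normalized form of
    # "!(A) && B", and the #endif (reached via the else branch) gets
    # format_conditions(["A", "B"], False), i.e. "!(A) && !(B)".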

    def wrap_long_directive(directive, condition, max_len=80):
        """Manually wrap long preprocessor comment lines without subprocess overhead"""
        single_line = directive + " " + condition
        if len(single_line) <= max_len:
            return single_line

        # Wrap condition across multiple lines with backslash continuation
        words = condition.split()
        lines = []
        current = f"{directive} "
        indent = (len(directive) + 4) * " "
        indent_final = (len(directive) + 2) * " "

        for word in words:
            if len(current) + len(word) + 1 <= max_len:
                current += word + " "
            else:
                lines.append(current.rstrip() + " \\")
                if word == "*/":
                    current = indent_final + word
                else:
                    current = indent + word + " "

        lines.append(current.rstrip())

        return "\n".join(lines)
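
    # For example (illustrative), a long `#endif /* COND_A && COND_B ... */`
    # in a .S file is emitted across several lines, each but the last ending
    # in a backslash continuation, with the closing `*/` typically on its
    # own line.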

    def adhoc_format(directive, content):
        # .c and .h files are formatted as a whole
        if not source_file.endswith(".S"):
            return directive + " /* " + content + " */"
        # For .S files, manually wrap long lines
        return wrap_long_directive(directive, "/* " + content + " */")

    i = 0
    while i < len(content):
        l = content[i].strip()
        # Replace #ifdef by #if defined(...)
        if l.startswith("#ifdef "):
            l = "#if defined(" + l.removeprefix("#ifdef").strip() + ")"
        if l.startswith("#ifndef "):
            l = "#if !defined(" + l.removeprefix("#ifndef").strip() + ")"
        if l.startswith("#if"):
            l, _ = merge_escaped_lines(l, i)
            cond = l.removeprefix("#if")
            if_stack.append(([cond], i, True, False))
            new_content.append(content[i])
        elif l.startswith("#elif"):
            conds, _, _, force_print = if_stack.pop()
            l, _ = merge_escaped_lines(l, i)
            conds.append(l.removeprefix("#elif"))
            if_stack.append((conds, i, True, force_print))
            new_content.append(content[i])
        elif l.startswith("#else"):
            l, i = merge_escaped_lines(l, i)
            _, i = merge_commented_lines(l, i)
            conds, j, branch, force_print = if_stack.pop()
            assert branch is True
            print_else = should_print(i, conds, j, force_print)
            if_stack.append((conds, i, False, print_else))
            if print_else is True:
                cond = format_conditions(conds, True)
                new_content.append(adhoc_format("#else", cond))
            else:
                new_content.append("#else")
        elif l.startswith("#endif"):
            l, i = merge_escaped_lines(l, i)
            _, i = merge_commented_lines(l, i)
            conds, j, branch, force_print = if_stack.pop()
            print_endif = should_print(i, conds, j, force_print)
            if print_endif is False:
                new_content.append("#endif")
            else:
                cond = format_conditions(conds, branch)
                new_content.append(adhoc_format("#endif", cond))
        else:
            # Skip over multiline comments -- we don't want to
            # handle `#if ...` inside documentation as this would
            # lead to nested `/* ... */`.
            i_old = i
            _, i = merge_commented_lines(l, i_old)
            new_content += content[i_old : i + 1]
        i += 1

    return "\n".join(new_content)


def gen_preprocessor_comments_for(parser, source_file):
    content = read_file(source_file)
    new_content = adjust_preprocessor_comments_for_filename(
        content, source_file, parser, show_status=True
    )
    update_file(source_file, new_content)


def gen_preprocessor_comments():
    files = get_c_source_files() + get_asm_source_files() + get_header_files()
    parser = CondParser()
    run_parallel(files, partial(gen_preprocessor_comments_for, parser))


def bitreverse(i, n):
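    """Reverse the n least-significant bits of i, e.g. bitreverse(1, 8) == 128"""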
    r = 0
    for _ in range(n):
        r = 2 * r + (i & 1)
        i >>= 1
    return r


def signed_reduce(a):
    """Return the signed canonical representative of a mod q (the ML-DSA modulus)"""
    c = a % modulus
    if c >= modulus / 2:
        c -= modulus
    return c


def gen_c_zetas():
    """Generate source and header file for zeta values used in
    the reference NTT and invNTT"""

    # The zeta values are the powers of the chosen root of unity (1753),
    # converted to Montgomery form.

    zeta = [0]  # First entry is unused and set to 0
    for i in range(1, 256):
        zeta.append(signed_reduce(pow(root_of_unity, i, modulus) * montgomery_factor))

    # The source code stores the zeta table in bit reversed form
    yield from (zeta[bitreverse(i, 8)] for i in range(256))


def gen_c_zeta_file():
    def gen():
        yield from gen_header()
        yield "#include <stdint.h>"
        yield ""
        yield "/*"
        yield " * Table of zeta values used in the reference NTT and inverse NTT."
        yield " * See autogen for details."
        yield " */"
        yield "static const int32_t mld_zetas[MLDSA_N] = {"
        yield from map(lambda t: str(t) + ",", gen_c_zetas())
        yield "};"
        yield ""

    update_file("mldsa/src/zetas.inc", "\n".join(gen()), force_format=True)


def prepare_root_for_barrett(root):
    """Takes a constant that the code needs to Barrett-multiply with,
    and returns the pair of (a) its signed canonical form, (b) the
    twisted constant used in the high-mul part of the Barrett multiplication."""

    # Signed canonical reduction
    root = signed_reduce(root)

    def round_to_even(t):
        rt = round(t)
        if rt % 2 == 0:
            return rt
        # Make sure to pick a rounding target
        # that's <= 1 away from t in absolute value.
        if rt <= t:
            return rt + 1
        return rt - 1
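
    # Example: round_to_even(3.2) == 4 and round_to_even(2.8) == 2; the
    # result is always an even integer at distance <= 1 from the input.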

    root_twisted = round_to_even((root * 2**32) / modulus) // 2
    return root, root_twisted


def gen_aarch64_root_of_unity_for_block(layer, block, inv=False, scale=False):
    # We are computing a negacyclic NTT; the twiddles needed here are
    # the second half of the twiddles for a cyclic NTT of twice the size.
    # For ease of calculating the roots, layers are numbered 0 through 7
    # in this function.
    log = bitreverse(pow(2, layer) + block, 8)
    if inv is True:
        log = -log
    root = pow(root_of_unity, log, modulus)

    if scale is True:
        # Integrate scaling by 2**(-8) and Montgomery factor 2**32 into twiddle
        root = root * pow(2, 32 - 8, modulus)

    root, root_twisted = prepare_root_for_barrett(root)
    return root, root_twisted


def gen_aarch64_fwd_ntt_zetas_layer123456():
    # Layers 1,2,3 are merged
    yield from gen_aarch64_root_of_unity_for_block(0, 0)
    yield from gen_aarch64_root_of_unity_for_block(1, 0)
    yield from gen_aarch64_root_of_unity_for_block(1, 1)
    yield from gen_aarch64_root_of_unity_for_block(2, 0)
    yield from gen_aarch64_root_of_unity_for_block(2, 1)
    yield from gen_aarch64_root_of_unity_for_block(2, 2)
    yield from gen_aarch64_root_of_unity_for_block(2, 3)
    yield from (0, 0)  # Padding

    # Layers 4,5,6 are merged
    for block in range(8):  # There are 8 blocks in Layer 4
        yield from gen_aarch64_root_of_unity_for_block(3, block)
        yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 0)
        yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 1)
        yield from gen_aarch64_root_of_unity_for_block(5, 4 * block + 0)
        yield from gen_aarch64_root_of_unity_for_block(5, 4 * block + 1)
        yield from gen_aarch64_root_of_unity_for_block(5, 4 * block + 2)
        yield from gen_aarch64_root_of_unity_for_block(5, 4 * block + 3)
        yield from (0, 0)  # Padding


def gen_aarch64_fwd_ntt_zetas_layer78():
    # Layers 4,5,6,7,8 are merged, but we emit the roots for layers 4,5,6
    # in separate arrays from those for layers 7,8
    for block in range(8):

        # Ordering of blocks is adjusted to suit the transposed internal
        # presentation of the data

        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(6, 8 * block + 0)[i]
            yield gen_aarch64_root_of_unity_for_block(6, 8 * block + 1)[i]
            yield gen_aarch64_root_of_unity_for_block(6, 8 * block + 2)[i]
            yield gen_aarch64_root_of_unity_for_block(6, 8 * block + 3)[i]

        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 0)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 2)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 4)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 6)[i]

        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 1)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 3)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 5)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 7)[i]

        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(6, 8 * block + 0 + 4)[i]
            yield gen_aarch64_root_of_unity_for_block(6, 8 * block + 1 + 4)[i]
            yield gen_aarch64_root_of_unity_for_block(6, 8 * block + 2 + 4)[i]
            yield gen_aarch64_root_of_unity_for_block(6, 8 * block + 3 + 4)[i]

        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 0 + 8)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 2 + 8)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 4 + 8)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 6 + 8)[i]

        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 1 + 8)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 3 + 8)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 5 + 8)[i]
            yield gen_aarch64_root_of_unity_for_block(7, 16 * block + 7 + 8)[i]


def gen_aarch64_intt_zetas_layer78():
    for block in range(16):
        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(6, block * 4 + 0, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(6, block * 4 + 1, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(6, block * 4 + 2, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(6, block * 4 + 3, inv=True)[i]

        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(7, block * 8 + 0, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(7, block * 8 + 2, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(7, block * 8 + 4, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(7, block * 8 + 6, inv=True)[i]

        for i in range(2):
            yield gen_aarch64_root_of_unity_for_block(7, block * 8 + 1, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(7, block * 8 + 3, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(7, block * 8 + 5, inv=True)[i]
            yield gen_aarch64_root_of_unity_for_block(7, block * 8 + 7, inv=True)[i]


def gen_aarch64_intt_zetas_layer123456():
    for i in range(16):
        yield from gen_aarch64_root_of_unity_for_block(4, i, inv=True)
        yield from (0, 0)  # Padding
        yield from gen_aarch64_root_of_unity_for_block(5, i * 2, inv=True)
        yield from gen_aarch64_root_of_unity_for_block(5, i * 2 + 1, inv=True)

    # The last layer has the scaling by 1/256 integrated in the twiddle
    yield from gen_aarch64_root_of_unity_for_block(0, 0, inv=True, scale=True)

    yield from gen_aarch64_root_of_unity_for_block(1, 0, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(1, 1, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(2, 0, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(2, 1, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(2, 2, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(2, 3, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(3, 0, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(3, 1, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(3, 2, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(3, 3, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(3, 4, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(3, 5, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(3, 6, inv=True)
    yield from gen_aarch64_root_of_unity_for_block(3, 7, inv=True)
    yield from (0, 0)  # Padding


def print_hol_light_array(g, as_int=True, entries_per_line=8, pad=0):

    # Format of integer list entries, including `;` separator:
    # - Positive numbers: &42;
    # - Negative numbers: -- &42;
    # If as_int is false, we omit `&` and emit constants as numerals.
    def format_hol_light_int(n):
        prefix = ""
        if n < 0:
            prefix = "-- "
            n = -n
        c = "&" if as_int is True else ""
        return f"{prefix}{c}{n:>{pad}};"

    l = list(map(format_hol_light_int, g))
    # Remove `;` from end of last entry
    l[-1] = l[-1][:-1]

    for i in range(0, len(l), entries_per_line):
        yield "  " + " ".join(l[i : i + entries_per_line])


def gen_aarch64_zeta_file():
    def gen():
        yield from gen_header()
        yield '#include "../../../common.h"'
        yield ""
        yield "#if defined(MLD_ARITH_BACKEND_AARCH64) && \\"
        yield "    !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)"
        yield ""
        yield "#include <stdint.h>"
        yield '#include "arith_native_aarch64.h"'
        yield ""
        yield "/*"
        yield " * Table of zeta values used in the AArch64 forward NTT"
        yield " * See autogen for details."
        yield " */"
        yield "MLD_ALIGN const int32_t mld_aarch64_ntt_zetas_layer123456[] = {"
        yield from map(lambda t: str(t) + ",", gen_aarch64_fwd_ntt_zetas_layer123456())
        yield "};"
        yield ""
        yield "MLD_ALIGN const int32_t mld_aarch64_ntt_zetas_layer78[] = {"
        yield from map(lambda t: str(t) + ",", gen_aarch64_fwd_ntt_zetas_layer78())
        yield "};"
        yield ""
        yield "MLD_ALIGN const int32_t mld_aarch64_intt_zetas_layer78[] = {"
        yield from map(lambda t: str(t) + ",", gen_aarch64_intt_zetas_layer78())
        yield "};"
        yield ""
        yield "MLD_ALIGN const int32_t mld_aarch64_intt_zetas_layer123456[] = {"
        yield from map(lambda t: str(t) + ",", gen_aarch64_intt_zetas_layer123456())
        yield "};"
        yield ""
        yield "#else"
        yield ""
        yield "MLD_EMPTY_CU(aarch64_zetas)"
        yield ""
        yield "#endif"
        yield ""

    update_file("dev/aarch64_opt/src/aarch64_zetas.c", "\n".join(gen()))

    update_file("dev/aarch64_clean/src/aarch64_zetas.c", "\n".join(gen()))


def gen_aarch64_rej_uniform_eta_table_rows():
    # The index into the lookup table is an 8-bit bitmap, i.e. a number 0..255.
    # Conceptually, the table entry at index i is a vector of 8 16-bit values, of
    # which only the first popcount(i) are set; those are the indices of the set-bits
    # in i. Concretely, we store each 16-bit index as two consecutive 8-bit indices.
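    # Example: i = 0b00000101 has bits 0 and 2 set, so get_set_bits_idxs(i)
    # is [0, 2], and the row becomes [0, 1, 4, 5] padded with 255 up to 16
    # entries.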
    def get_set_bits_idxs(i):
        bits = list(map(int, format(i, "08b")))
        bits.reverse()
        return [bit_idx for bit_idx in range(8) if bits[bit_idx] == 1]

    for i in range(256):
        idxs = get_set_bits_idxs(i)
        # Replace each index by two consecutive indices
        idxs = [j for i in idxs for j in [2 * i, 2 * i + 1]]
        # Pad by 255 (invalid index)
        idxs = idxs + [255] * (16 - len(idxs))
        yield idxs


def gen_aarch64_rej_uniform_eta_table():
    def gen():
        yield from gen_header()
        yield '#include "../../../common.h"'
        yield ""
        yield "#if defined(MLD_ARITH_BACKEND_AARCH64) && \\"
        yield "    !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)"
        yield ""
        yield "#include <stdint.h>"
        yield '#include "arith_native_aarch64.h"'
        yield ""
        yield "/*"
        yield " * Lookup table used by 16-bit rejection sampling (rej_eta)."
        yield " * Adapted from ML-KEM for ML-DSA eta rejection sampling."
        yield " * See autogen for details."
        yield " */"
        yield "MLD_ALIGN const uint8_t mld_rej_uniform_eta_table[] = {"
        for i, idxs in enumerate(gen_aarch64_rej_uniform_eta_table_rows()):
            yield ",".join(map(str, idxs)) + f" /* {i} */,"
        yield "};"
        yield ""
        yield "#else"
        yield ""
        yield "MLD_EMPTY_CU(aarch64_rej_uniform_eta_table)"
        yield ""
        yield "#endif"
        yield ""

    update_file("dev/aarch64_opt/src/rej_uniform_eta_table.c", "\n".join(gen()))

    update_file("dev/aarch64_clean/src/rej_uniform_eta_table.c", "\n".join(gen()))


def gen_aarch64_rej_uniform_table_rows():
    # The index into the lookup table is a 4-bit bitmap, i.e. a number 0..15.
    # Conceptually, the table entry at index i is a vector of 4 32-bit values,
    # of which only the first popcount(i) are set; those are the indices of the
    # set-bits in i. Concretely, we store each 32-bit index as four consecutive
    # 8-bit indices.
    def get_set_bits_idxs(i):
        bits = list(map(int, format(i, "08b")))
        bits.reverse()
        return [bit_idx for bit_idx in range(8) if bits[bit_idx] == 1]

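    # Example: i = 0b0011 has bits 0 and 1 set, so the row is
    # [0, 1, 2, 3, 4, 5, 6, 7] followed by eight 255 padding entries.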
    for i in range(16):
        idxs = get_set_bits_idxs(i)
        # Replace each index by four consecutive indices
        idxs = [j for i in idxs for j in [4 * i + k for k in range(4)]]
        # Pad by 255 (invalid index)
        idxs = idxs + [255] * (16 - len(idxs))
        yield idxs


def gen_aarch64_rej_uniform_table():
    def gen():
        yield from gen_header()
        yield '#include "../../../common.h"'
        yield ""
        yield "#if defined(MLD_ARITH_BACKEND_AARCH64) && \\"
        yield "    !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)"
        yield ""
        yield "#include <stdint.h>"
        yield '#include "arith_native_aarch64.h"'
        yield ""
        yield "/*"
        yield " * Lookup table used by rejection sampling of the public matrix."
        yield " * See autogen for details."
        yield " */"
        yield "MLD_ALIGN const uint8_t mld_rej_uniform_table[] = {"
        for i, idxs in enumerate(gen_aarch64_rej_uniform_table_rows()):
            yield ",".join(map(str, idxs)) + f" /* {i} */,"
        yield "};"
        yield ""
        yield "#else"
        yield ""
        yield "MLD_EMPTY_CU(aarch64_rej_uniform_table)"
        yield ""
        yield "#endif"
        yield ""

    update_file("dev/aarch64_opt/src/rej_uniform_table.c", "\n".join(gen()))

    update_file("dev/aarch64_clean/src/rej_uniform_table.c", "\n".join(gen()))


def gen_avx2_rej_uniform_table_rows():
    # The index into the lookup table is an 8-bit bitmap, i.e. a number 0..255.
    # Conceptually, the table entry at index i is a vector of 8 lane indices, of
    # which only the first popcount(i) are set; those are the indices of the set-bits
    # in i.
    def get_set_bits_idxs(i):
        bits = list(map(int, format(i, "08b")))
        bits.reverse()
        return [bit_idx for bit_idx in range(8) if bits[bit_idx] == 1]
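
    # Example: i = 0b00000101 gives idxs [0, 2], padded with zeros to
    # {0,2,0,0,0,0,0,0}.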

    for i in range(256):
        idxs = get_set_bits_idxs(i)
        # Pad by 0
        idxs = idxs + [0] * (8 - len(idxs))
        yield "{" + ",".join(map(str, idxs)) + "}"


def gen_avx2_rej_uniform_table():
    def gen():
        yield from gen_header()
        yield '#include "../../../common.h"'
        yield ""
        yield "#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \\"
        yield "    !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)"
        yield ""
        yield "#include <stdint.h>"
        yield '#include "arith_native_x86_64.h"'
        yield ""
        yield "/*"
        yield " * Lookup table used by rejection sampling."
        yield " * See autogen for details."
        yield " */"
        yield "MLD_ALIGN const uint8_t mld_rej_uniform_table[256][8] = {"
        yield from map(lambda t: str(t) + ",", gen_avx2_rej_uniform_table_rows())
        yield "};"
        yield ""
        yield "#else"
        yield ""
        yield "MLD_EMPTY_CU(avx2_rej_uniform_table)"
        yield ""
        yield "#endif"
        yield ""

    update_file("dev/x86_64/src/rej_uniform_table.c", "\n".join(gen()))


def signed_reduce_u32(a):
    """Return the signed representative of a mod 2^32, in [-2^31, 2^31)"""
    c = a % 2**32
    if c >= 2**31:
        c -= 2**32
    return c


def prepare_root_for_montmul(root, mult):
    """Takes a constant that the code needs to Montgomery-multiply with,
    and returns either (a) the signed canonical representative of its
    Montgomery form (mult=False), or (b) the twisted constant used in the
    low-mul part of the Montgomery multiplication (mult=True)."""

    # Convert to Montgomery form and pick canonical signed representative
    root = signed_reduce(root * montgomery_factor)
    if mult:
        root = signed_reduce_u32(root * pow(modulus, -1, 2**32))
    return root


def gen_avx2_root_of_unity_for_block(layer, block, mult=False):
    # We are computing a negacyclic NTT; the twiddles needed here are
    # the second half of the twiddles for a cyclic NTT of twice the size.
    log = bitreverse(pow(2, layer) + block, 8)
    root = pow(root_of_unity, log, modulus)
    return prepare_root_for_montmul(root, mult)


def gen_avx2_fwd_ntt_zetas(mult=False):

    def gen_twiddles(layer, block, repeat, mult):
        root = gen_avx2_root_of_unity_for_block(layer, block, mult)
        return [root] * repeat

    def gen_twiddles_many(layer, block_base, block_offsets, repeat, mult):
        roots = list(
            map(
                lambda x: gen_twiddles(layer, block_base + x, repeat, mult),
                block_offsets,
            )
        )
        yield from (r for l in roots for r in l)

    # Embed the scaling by 1/256 and the correction of the Montgomery factor
    # from the basemul into the last twiddle of the inverse NTT:
    #   -root^-128 * 2^64/256
    # In the forward NTT this twiddle is unused
    f = signed_reduce(-pow(root_of_unity, -128, modulus) * 2**56)
    if mult:
        f = signed_reduce_u32(f * pow(modulus, -1, 2**32))

    yield f

    # Layer 1 twiddle
    # In the inverse NTT this twiddle is unused
    yield from gen_twiddles_many(0, 0, range(1), 1, mult)

    # Layer 2-8 twiddles
    yield from gen_twiddles_many(1, 0, range(2), 1, mult)
    yield from gen_twiddles_many(2, 0, range(4), 1, mult)
    yield from gen_twiddles_many(3, 0, range(8), 4, mult)
    yield from gen_twiddles_many(4, 0, range(16), 2, mult)
    yield from gen_twiddles_many(5, 0, range(32), 1, mult)
    for i in range(32):
        yield from gen_twiddles_many(6, i * 2, range(1), 1, mult)
    for i in range(32):
        yield from gen_twiddles_many(6, i * 2 + 1, range(1), 1, mult)

    for k in range(4):
        for i in range(32):
            yield from gen_twiddles_many(7, i * 4 + k, range(1), 1, mult)


def gen_avx2_zetas_qdata():
    def cmod(a, mod):
        """Return the signed canonical representative of a modulo mod"""
        c = a % mod
        if c >= mod / 2:
            c -= mod
        return c

    # MLD_AVX2_Q
    q_dup = [modulus] * 8
    # MLD_AVX2_QINV
    qinv_dup = [pow(modulus, -1, 2**32)] * 8
    # MLD_AVX2_DIV_QINV
    div_qinv_dup = [cmod(pow(modulus, -1, 2**32) * pow(2, 64 - 8, modulus), 2**32)] * 8
    # MLD_AVX2_DIV
    div_dup = [pow(2, 64 - 8, modulus)] * 8

    zetas_qinv = list(gen_avx2_fwd_ntt_zetas(mult=True))
    zetas = list(gen_avx2_fwd_ntt_zetas(mult=False))

    q_idx = 0
    qinv_idx = q_idx + len(q_dup)
    div_qinv_idx = qinv_idx + len(qinv_dup)
    div_idx = div_qinv_idx + len(div_qinv_dup)
    zetas_qinv_idx = div_idx + len(div_dup)
    zetas_idx = zetas_qinv_idx + len(zetas_qinv)

    constants = q_dup + qinv_dup + div_qinv_dup + div_dup + zetas_qinv + zetas
    offsets = {
        "q": q_idx,
        "qinv": qinv_idx,
        "div_qinv": div_qinv_idx,
        "div": div_idx,
        "zetas_qinv": zetas_qinv_idx,
        "zetas": zetas_idx,
    }
    return constants, offsets


def gen_avx2_hol_light_zeta_file():
    def gen():
        yield from gen_hol_light_header()
        yield "(*"
        yield " * Table of zeta values used in the AVX2 NTTs"
        yield " * See autogen for details."
        yield " *)"
        yield ""
        yield "let mldsa_complete_qdata = define `mldsa_complete_qdata:int list = ["
        constants, _ = gen_avx2_zetas_qdata()
        yield from print_hol_light_array(constants)
        yield "]`;;"
        yield ""

    update_file("proofs/hol_light/x86_64/proofs/mldsa_zetas.ml", "\n".join(gen()))


def gen_avx2_zeta_file():
    constants, offsets = gen_avx2_zetas_qdata()

    def gen_c():
        yield from gen_header()
        yield '#include "../../../common.h"'
        yield ""
        yield "#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \\"
        yield "    !defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)"
        yield ""
        yield '#include "consts.h"'
        yield ""
        yield "/*"
        yield " * Table of zeta values used in the AVX2 forward and inverse NTT"
        yield " * See autogen for details."
        yield " */"
        yield f"MLD_ALIGN const int32_t mld_qdata[{len(constants)}] = {{"
        yield from map(lambda t: str(t) + ",", constants)
        yield "};"
        yield ""
        yield "#else"
        yield ""
        yield "MLD_EMPTY_CU(avx2_consts)"
        yield ""
        yield "#endif"
        yield ""

    def gen_h():
        yield from gen_header()
        yield "#ifndef MLD_NATIVE_X86_64_SRC_CONSTS_H"
        yield "#define MLD_NATIVE_X86_64_SRC_CONSTS_H"
        yield '#include "../../../common.h"'
        yield f"#define MLD_AVX2_BACKEND_DATA_OFFSET_8XQ {offsets['q']}"
        yield f"#define MLD_AVX2_BACKEND_DATA_OFFSET_8XQINV {offsets['qinv']}"
        yield f"#define MLD_AVX2_BACKEND_DATA_OFFSET_8XDIV_QINV {offsets['div_qinv']}"
        yield f"#define MLD_AVX2_BACKEND_DATA_OFFSET_8XDIV {offsets['div']}"
        yield f"#define MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS_QINV {offsets['zetas_qinv']}"
        yield f"#define MLD_AVX2_BACKEND_DATA_OFFSET_ZETAS {offsets['zetas']}"
        yield ""
        yield "#ifndef __ASSEMBLER__"
        yield "#define mld_qdata MLD_NAMESPACE(qdata)"
        yield f"extern const int32_t mld_qdata[{len(constants)}];"
        yield "#endif"
        yield ""
        yield "#endif"
        yield ""

    update_file("dev/x86_64/src/consts.c", "\n".join(gen_c()))
    update_file("dev/x86_64/src/consts.h", "\n".join(gen_h()))


def get_c_source_files(main_only=False, core_only=False, strip_mldsa=False):
    if main_only is True:
        return get_files("mldsa/src/**/*.c", strip_mldsa=strip_mldsa)
    elif core_only is True:
        return get_files("mldsa/src/**/*.c", strip_mldsa=strip_mldsa) + get_files(
            "dev/**/*.c"
        )
    else:
        return get_files("**/*.c", strip_mldsa=strip_mldsa)


def get_asm_source_files(main_only=False, core_only=False, strip_mldsa=False):
    if main_only is True:
        return get_files("mldsa/src/**/*.S", strip_mldsa=strip_mldsa)
    elif core_only is True:
        return get_files("mldsa/src/**/*.S", strip_mldsa=strip_mldsa) + get_files(
            "dev/**/*.S", strip_mldsa=strip_mldsa
        )
    else:
        return get_files("**/*.S", strip_mldsa=strip_mldsa)


def get_header_files(main_only=False, core_only=False):
    if main_only is True:
        return get_files("mldsa/*.h") + get_files("mldsa/src/**/*.h")
    elif core_only is True:
        return (
            get_files("mldsa/*.h")
            + get_files("mldsa/src/**/*.h")
            + get_files("dev/**/*.h")
            + get_files("integration/**/*.h")
        )
    else:
        return get_files("**/*.h")


def get_markdown_files(main_only=False):
    return get_files("**/*.md")


def get_files(pattern, strip_mldsa=False):
    def normalize(f):
        return f.removeprefix("mldsa/") if strip_mldsa else f

    fs_files = {f for f in get_all_files() if pathlib.Path(f).is_file()}
    cache_files = set(_file_cache.keys())

    # Convert glob pattern to compiled regex
    pattern = pattern.replace(".", r"\.")
    pattern = pattern.replace("**/", "<<<DOUBLESTAR>>>")
    pattern = pattern.replace("*", "[^/]*")
    pattern = pattern.replace("<<<DOUBLESTAR>>>", "(?:.*/)?")
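    # Example: "mldsa/src/**/*.c" becomes ^mldsa/src/(?:.*/)?[^/]*\.c$,
    # matching .c files at any depth below mldsa/src/.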
    regexp = re.compile(f"^{pattern}$")

    # Remove files which are scheduled to be deleted
    res = list(
        map(
            normalize,
            filter(
                lambda f: regexp.match(f) and (read_file(f) is not None),
                fs_files | cache_files,
            ),
        )
    )

    return res


def get_all_files():
    # All git-tracked files, including symlinks
    r = subprocess.run(["git", "ls-files"], capture_output=True, text=True)
    assert r.returncode == 0
    files = r.stdout.split("\n")
    files = filter(lambda s: s != "" and pathlib.Path(s).is_symlink() is False, files)
    return files


def get_defines_from_file(c):
    for l in read_file(c).split("\n"):
        m = _RE_DEFINE.match(l)
        if m:
            yield (c, m.group(1))


def get_defines(all=False):
    if all is False:
        files = get_header_files(main_only=True)
    else:
        files = get_header_files() + get_c_source_files() + get_asm_source_files()

    for results in run_parallel(files, get_defines_from_file):
        yield from results


def get_checked_defines():
    allow_list = [("__contract__", "cbmc.h"), ("__loop__", "cbmc.h")]

    def is_allowed(d, c):
        for d0, c0 in allow_list:
            if c.endswith(c0) is True and d0 == d:
                return True
        return False

    for c, d in get_defines():
        if d.startswith("_") and is_allowed(d, c) is False:
            raise Exception(
                f"{d} from {c}: starts with an underscore, which is not allowed for mldsa-native macros. "
                f"If this is an mldsa-native specific macro, please pick a different name. "
                f"If this is an external macro, it likely needs removing from `gen_monolithic_undef_all_core()` in `scripts/autogen` -- check this!"
            )
        yield (c, d)


def gen_monolithic_undef_all_core(filt=None, desc=""):

    if filt is None:
        filt = lambda c: True

    if desc != "":
        yield "/*"
        yield f" * Undefine macros from {desc}"
        yield " */"

    defines = list(set(get_checked_defines()))
    defines.sort()

    last_filename = None
    for filename, d in defines:
        if filt(filename) is False:
            continue
        if last_filename != filename:
            yield f"/* {filename} */"
            last_filename = filename
        yield f"#undef {d}"


def native(c):
    return "/native/" in c


def fips202(c):
    return "/fips202/" in c


def aarch64(c):
    return "/aarch64/" in c


def x86_64(c):
    return "/x86_64/" in c


def riscv64(c):
    return "/riscv64/" in c


def native_fips202(c):
    return native(c) and fips202(c)


def native_arith(c):
    return native(c) and not fips202(c)


def native_fips202_aarch64(c):
    return native_fips202(c) and aarch64(c)


def native_fips202_x86_64(c):
    return native_fips202(c) and x86_64(c)


def native_fips202_core(c):
    return (
        native_fips202(c)
        and not native_fips202_x86_64(c)
        and not native_fips202_aarch64(c)
    )


def native_arith_aarch64(c):
    return native_arith(c) and aarch64(c)


def native_arith_x86_64(c):
    return native_arith(c) and x86_64(c)


def native_arith_riscv64(c):
    return native_arith(c) and riscv64(c)


def native_arith_core(c):
    return (
        native_arith(c)
        and not native_arith_x86_64(c)
        and not native_arith_aarch64(c)
        and not native_arith_riscv64(c)
    )


# List of level-specific source files.
# All other files only need to be included and built once
# in a multi-level build.
def k_specific(c):
    k_specific_sources = [
        "mldsa_native.h",
        "params.h",
        "common.h",
        "packing.c",
        "packing.h",
        "poly_kl.c",
        "poly_kl.h",
        "polyvec.c",
        "polyvec.h",
        "rounding.h",
        "sign.c",
        "sign.h",
    ]
    for f in k_specific_sources:
        if c.endswith(f):
            return True
    return False


def k_generic(c):
    return not k_specific(c) and c != "mldsa/mldsa_native_config.h"


def gen_macro_undefs(extra_notes=None):
    if extra_notes is None:
        extra_notes = []

    yield "/* Macro #undef's"
    yield " *"
    yield " * The following undefines macros from headers"
    yield " * included by the source files imported above."
    yield " *"
    yield " * This is to allow building and linking multiple builds"
    yield " * of mldsa-native for varying parameter sets through concatenation"
    yield " * of this file, as if the files had been compiled separately."
    yield " * If this is not relevant to you, you may remove the following."
    for e in extra_notes:
        yield f" * {e}"
    yield " */"
    yield ""
    yield from gen_monolithic_undef_all_core(
        filt=k_specific, desc="MLD_CONFIG_PARAMETER_SET-specific files"
    )
    yield ""
    yield "#if !defined(MLD_CONFIG_MONOBUILD_KEEP_SHARED_HEADERS)"
    yield from gen_monolithic_undef_all_core(
        filt=lambda c: not native(c)
        and k_generic(c)
        and not fips202(c)
        and "cbmc.h" not in c,
        desc="MLD_CONFIG_PARAMETER_SET-generic files",
    )
    # Handle cbmc.h manually -- most #define's therein are only defined when CBMC is set
    # and need not be #undef'ed. In fact, #undef'ing them is risky since their names may
    # well already be occupied.
    yield "/* mldsa/src/cbmc.h */"
    yield "#undef MLD_CBMC_H"
    yield "#undef __contract__"
    yield "#undef __loop__"
    yield ""
    yield "#if !defined(MLD_CONFIG_FIPS202_CUSTOM_HEADER)"
    yield from gen_monolithic_undef_all_core(
        filt=lambda c: not native(c) and k_generic(c) and fips202(c),
        desc="FIPS-202 files",
    )
    yield "#endif"
    yield ""
    yield "#if defined(MLD_CONFIG_USE_NATIVE_BACKEND_FIPS202)"
    yield from gen_monolithic_undef_all_core(filt=native_fips202_core, desc="")
    yield "#if defined(MLD_SYS_AARCH64)"
    yield from gen_monolithic_undef_all_core(
        filt=native_fips202_aarch64, desc="native code (FIPS202, AArch64)"
    )
    yield "#endif"
    yield "#if defined(MLD_SYS_X86_64)"
    yield from gen_monolithic_undef_all_core(
        filt=native_fips202_x86_64, desc="native code (FIPS202, x86_64)"
    )
    yield "#endif"
    yield "#endif"
    yield "#if defined(MLD_CONFIG_USE_NATIVE_BACKEND_ARITH)"
    yield from gen_monolithic_undef_all_core(filt=native_arith_core, desc="")
    yield "#if defined(MLD_SYS_AARCH64)"
    yield from gen_monolithic_undef_all_core(
        filt=native_arith_aarch64, desc="native code (Arith, AArch64)"
    )
    yield "#endif"
    yield "#if defined(MLD_SYS_X86_64)"
    yield from gen_monolithic_undef_all_core(
        filt=native_arith_x86_64, desc="native code (Arith, X86_64)"
    )
    yield "#endif"
    yield "#endif"
    yield "#endif"
    yield ""


def gen_monolithic_source_file():

    def gen():
        c_sources = get_c_source_files(main_only=True, strip_mldsa=True)
        yield from gen_header()

        yield "/******************************************************************************"
        yield " *"
        yield " * Single compilation unit (SCU) for fixed-level build of mldsa-native"
        yield " *"
        yield " * This compilation unit bundles together all source files for a build"
        yield " * of mldsa-native for a fixed security level (MLDSA-44/65/87)."
        yield " *"
        yield " * # API"
        yield " *"
        yield " * The API exposed by this file is described in mldsa_native.h."
        yield " *"
        yield " * # Multi-level build"
        yield " *"
        yield " * If you want an SCU build of mldsa-native with support for multiple security"
        yield " * levels, you need to include this file multiple times, and set"
        yield " * MLD_CONFIG_MULTILEVEL_WITH_SHARED and MLD_CONFIG_MULTILEVEL_NO_SHARED"
        yield " * appropriately. This is exemplified in examples/monolithic_build_multilevel"
        yield " * and examples/monolithic_build_multilevel_native."
        yield " *"
        yield " * # Configuration"
        yield " *"
        yield " * The following options from the mldsa-native configuration are relevant:"
        yield " *"
        yield " * - MLD_CONFIG_FIPS202_CUSTOM_HEADER"
        yield " *   Set this option if you use a custom FIPS202 implementation."
        yield " *"
        yield " * - MLD_CONFIG_USE_NATIVE_BACKEND_ARITH"
        yield " *   Set this option if you want to include the native arithmetic backends"
        yield " *   in your build."
        yield " *"
        yield " * - MLD_CONFIG_USE_NATIVE_BACKEND_FIPS202"
        yield " *   Set this option if you want to include the native FIPS202 backends"
        yield " *   in your build."
        yield " *"
        yield " * - MLD_CONFIG_MONOBUILD_KEEP_SHARED_HEADERS"
        yield " *   Set this option if you want to keep the directives defined in"
        yield " *   level-independent headers. This is needed for a multi-level build."
        yield " */"
        yield ""
        yield "/* If parts of the mldsa-native source tree are not used,"
        yield " * consider reducing this header via `unifdef`."
        yield " *"
        yield " * Example:"
        yield " * ```bash"
        yield " * unifdef -UMLD_CONFIG_USE_NATIVE_BACKEND_ARITH mldsa_native.c"
        yield " * ```"
        yield " */"
        yield ""
        yield '#include "src/common.h"'
        yield ""
        for c in filter(lambda c: not native(c) and not fips202(c), c_sources):
            yield f'#include "{c}"'
        yield ""
        yield "#if !defined(MLD_CONFIG_FIPS202_CUSTOM_HEADER)"
        for c in filter(lambda c: not native(c) and fips202(c), c_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield ""
        yield "#if defined(MLD_CONFIG_USE_NATIVE_BACKEND_ARITH)"
        yield "#if defined(MLD_SYS_AARCH64)"
        for c in filter(native_arith_aarch64, c_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield "#if defined(MLD_SYS_X86_64)"
        for c in filter(native_arith_x86_64, c_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield "#endif"
        yield ""
        yield "#if defined(MLD_CONFIG_USE_NATIVE_BACKEND_FIPS202)"
        yield "#if defined(MLD_SYS_AARCH64)"
        for c in filter(native_fips202_aarch64, c_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield "#if defined(MLD_SYS_X86_64)"
        for c in filter(native_fips202_x86_64, c_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield "#endif"
        yield ""
        yield from gen_macro_undefs()

    update_file("mldsa/mldsa_native.c", "\n".join(gen()))


def gen_monolithic_asm_file():

    def gen():
        asm_sources = get_asm_source_files(main_only=True, strip_mldsa=True)
        yield from gen_header()
        yield "/******************************************************************************"
        yield " *"
        yield " * Single assembly unit for fixed-level build of mldsa-native"
        yield " *"
        yield " * This assembly unit bundles together all assembly files for a build"
        yield " * of mldsa-native for a fixed security level (MLDSA-44/65/87)."
        yield " *"
        yield " * # Multi-level build"
        yield " *"
        yield " * If you want an SCU build of mldsa-native with support for multiple security"
        yield " * levels, you should include this file once with MLD_CONFIG_MULTILEVEL_WITH_SHARED set."
        yield " *"
        yield " * (You could also follow the same pattern as for mldsa_native_monobuild.c"
        yield " *  and include it for every level, setting MLD_CONFIG_MULTILEVEL_NO_SHARED"
        yield " *  for all but one. For builds with MLD_CONFIG_MULTILEVEL_NO_SHARED, this"
        yield " *  file will then be ignored.)"
        yield " *"
        yield " * # Configuration"
        yield " *"
        yield " * The following options from the mldsa-native configuration are relevant:"
        yield " *"
        yield " * - MLD_CONFIG_FIPS202_CUSTOM_HEADER"
        yield " *   Set this option if you use a custom FIPS202 implementation."
        yield " *"
        yield " * - MLD_CONFIG_USE_NATIVE_BACKEND_ARITH"
        yield " *   Set this option if you want to include the native arithmetic backends"
        yield " *   in your build."
        yield " *"
        yield " * - MLD_CONFIG_USE_NATIVE_BACKEND_FIPS202"
        yield " *   Set this option if you want to include the native FIPS202 backends"
        yield " *   in your build."
        yield " *"
        yield " * - MLD_CONFIG_MONOBUILD_KEEP_SHARED_HEADERS"
        yield " *   Set this option if you want to keep the directives defined in"
        yield " *   level-independent headers. This is needed for a multi-level build."
        yield " */"
        yield ""
        yield "/* If parts of the mldsa-native source tree are not used,"
        yield " * consider reducing this header via `unifdef`."
        yield " *"
        yield " * Example:"
        yield " * ```bash"
        yield " * unifdef -UMLD_CONFIG_USE_NATIVE_BACKEND_ARITH mldsa_native.S"
        yield " * ```"
        yield " */"
        yield ""
        yield '#include "src/common.h"'
        yield ""
        yield "#if defined(MLD_CONFIG_USE_NATIVE_BACKEND_ARITH)"
        yield "#if defined(MLD_SYS_AARCH64)"
        for c in filter(native_arith_aarch64, asm_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield "#if defined(MLD_SYS_X86_64)"
        for c in filter(native_arith_x86_64, asm_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield "#endif"
        yield ""
        yield "#if defined(MLD_CONFIG_USE_NATIVE_BACKEND_FIPS202)"
        yield "#if defined(MLD_SYS_AARCH64)"
        for c in filter(native_fips202_aarch64, asm_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield "#if defined(MLD_SYS_X86_64)"
        for c in filter(native_fips202_x86_64, asm_sources):
            yield f'#include "{c}"'
        yield "#endif"
        yield "#endif"
        yield ""
        # We generate #undef's for all headers, even though most are not
        # included by the assembly files. This does no harm and avoids having
        # to trace which headers are pulled in via common.h from the assembly
        # files.
        yield ""
        extra = [
            "",
            "NOTE: This is not needed for the assembly SCU since, at present,",
            "there is no need to include it multiple times.",
            "We keep it for uniformity with mldsa_native.c only.",
            "",
            "NOTE: To avoid having to distinguish between which headers are included",
            "from the assembly files, we #undef the same set of directives",
            "as in mldsa_native.c",
        ]
        yield from gen_macro_undefs(extra_notes=extra)

    update_file("mldsa/mldsa_native.S", "\n".join(gen()), force_format=True)


def get_config_options():
    content = read_file("mldsa/mldsa_native_config.h")
    config_pattern = r"Name:\s*(MLD_CONFIG_\w+)"

    configs = re.findall(config_pattern, content)

    configs += [
        "MLD_FORCE_AARCH64",
        "MLD_FORCE_AARCH64_EB",
        "MLD_FORCE_X86_64",
        "MLD_FORCE_PPC64LE",
        "MLD_FORCE_RISCV64",
        "MLD_FORCE_RISCV32",
        "MLD_SYS_AARCH64_SLOW_BARREL_SHIFTER",
        "MLDSA_DEBUG",  # TODO: Rename?
        "MLD_BREAK_PCT",  # Use in PCT breakage test
        "MLD_CHECK_APIS",
        "MLD_CONFIG_API_XXX",
        "MLD_ERR_XXX",
        "MLD_USE_NATIVE_XXX",
        "MLD_CONFIG_XXX",
        "MLD_CONFIG_API_CONSTANTS_ONLY",
        "MLD_PREHASH_",
    ]

    return configs


def check_macro_typos_in_file(filename, macro_check):
    """Checks for typos in MLD_XXX and MLDSA_XXX identifiers."""
    content = read_file(filename)

    # Separate check for wrongly ported MLK/MLKEM macros
    for m in _RE_MLKEM_MACRO_CHECK.finditer(content):
        txt = m.group(1)
        line_no = content[: m.start()].count("\n") + 1
        if filename != "scripts/autogen":
            raise Exception(
                f"Likely typo {txt} in {filename}:{line_no}? wrongly ported MLK_XXX / MLKEM_XXX macros from mlkem-native."
            )

    # Check MLD/MLDSA macros
    for m in _RE_MACRO_CHECK.finditer(content):
        txt = m.group(1)
        rest = m.group(2)
        if macro_check(txt, rest, filename) is False:
            line_no = content[: m.start()].count("\n") + 1
            raise Exception(
                f"Likely typo {txt} in {filename}:{line_no}? Not a defined macro."
            )


def get_syscaps():
    return ["MLD_SYS_CAP_AVX2", "MLD_SYS_CAP_SHA3", "MLD_SYS_CAP_DUMMY"]


def check_macro_typos():
    files = get_all_files()
    syscaps = get_syscaps()

    macros = set(map(lambda t: t[1], get_defines(all=True)))

    # Add configuration options to the list of allowed macro names
    macros.update(get_config_options())

    def macro_check(m, rest, filename):
        if m in macros:
            return True

        # Ignore alloc macros only defined in mldsa_native.h
        if m.startswith("MLD_TOTAL_ALLOC") or m.startswith("MLD_MAX_TOTAL_ALLOC"):
            return True

        is_autogen = filename == "scripts/autogen"

        # Exclude system capabilities, which are enum values
        if m in syscaps:
            return True

        #
        # Register some file-specific exceptions
        #

        # 1. Makefiles use MLD_SOURCE_XXX to list source files
        if is_autogen or filename.endswith("/Makefile"):
            if m.startswith("MLD_SOURCE") or m.startswith("MLD_OBJ"):
                return True

        # 2. libOQS specific identifier
        if is_autogen or filename.startswith("integration/liboqs"):
            if m.startswith("MLDSA_NATIVE_MLDSA") or m in ["MLDSA_DIR"]:
                return True

        # 3. Exclude HOL-Light proof scripts
        if is_autogen or filename.startswith("proofs/hol_light"):
            if filename.endswith(".ml"):
                return True

        # 4. Exclude regexp patterns in `autogen`
        if is_autogen:
            if rest.startswith("\\") or m in ["MLD_XXX", "MLD_SOURCE_XXX"]:
                return True

        # 5. AWS-LC importer patch
        if is_autogen or filename == "integration/awslc/awslc.patch":
            return True

        if is_autogen or filename == "mldsa/src/common.h":
            if m == "MLD_CONTEXT_PARAMETERS_n":
                return True

        return False

    run_parallel(
        list(files), partial(check_macro_typos_in_file, macro_check=macro_check)
    )


def check_asm_register_aliases_for_file(filename):
    """Checks that `filename` has no mismatching or dangling register aliases"""

    def get_alias_def(l):
        s = list(filter(lambda s: s != "", l.strip().split(" ")))
        if len(s) < 3 or s[1] != ".req":
            return None
        return s[0]

    def get_alias_undef(l):
        if l.strip().startswith(".unreq") is False:
            return None
        return list(filter(lambda s: s != "", l.strip().split(" ")))[1]

    content = read_file(filename)
    aliases = {}
    # Report line numbers 1-based
    for i, l in enumerate(content.split("\n"), start=1):
        alias_def = get_alias_def(l)
        alias_undef = get_alias_undef(l)
        if alias_def is not None:
            if alias_def in aliases.keys():
                raise Exception(
                    f"Invalid assembly file {filename}: Duplicate .req directive for {alias_def} at line {i}"
                )
            aliases[alias_def] = i
        elif alias_undef is not None:
            if alias_undef not in aliases.keys():
                raise Exception(
                    f"Invalid assembly file {filename}: .unreq without prior .req for {alias_undef} at line {i}"
                )
            del aliases[alias_undef]

    if len(aliases) > 0:
        fixup_suggestion = [
            "/****************** REGISTER DEALLOCATIONS *******************/"
        ]
        dangling = list(aliases.items())
        # Sort by line number of .req
        dangling.sort(key=lambda s: s[1])

        for a, _ in dangling:
            fixup_suggestion.append(f"    .unreq {a}")
        fixup_suggestion.append("")
        fixup_suggestion = "\n".join(fixup_suggestion)

        raise Exception(
            f"Invalid assembly file {filename}: Dangling .req directives {aliases}.\n\nTry adding this?\n\n{fixup_suggestion}"
        )


def check_asm_register_aliases():
    run_parallel(get_asm_source_files(), check_asm_register_aliases_for_file)


def check_asm_loop_labels_for_file(filename):
    """Checks that all labels in `filename` are prefixed according to the function defined in the file"""

    content = read_file(filename)

    # Find function symbol name
    res = _RE_FUNC_SYMBOL.search(content)
    if res is None:
        raise Exception(f"Could not find function symbol in assembly file {filename}")
    funcname = res.group(1)
    lbl_prefix = funcname.replace("_asm", "") + "_"
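    # e.g. (illustrative) funcname "mld_ntt_asm" requires all labels to be
    # prefixed with "mld_ntt_"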
    lbl_pattern = r"^(\w+):"
    for m in re.finditer(lbl_pattern, content, flags=re.M):
        lbl = m.group(1)
        if not lbl.startswith(lbl_prefix):
            raise Exception(
                f"Please change label {lbl} in {filename} to be prefixed by {lbl_prefix}"
            )


def check_asm_loop_labels():
    # Operate on assembly files in dev/ only. The ones in mldsa/ are autogenerated
    # from those and don't have the original MLD_ASM_FN_SYMBOL marker anymore.
    files = list(filter(lambda s: s.startswith("dev/"), get_asm_source_files()))
    run_parallel(files, check_asm_loop_labels_for_file)


def normalize_comma_separated_args(args_str):
    """Convert whitespace-separated args to comma-separated, add spaces after commas"""
    # Extract and preserve comment
    match = _RE_ARGS_COMMENT.match(args_str)
    args_only = match.group(1).rstrip()
    comment = match.group(2) or ""

    # If already has commas, just normalize spacing
    if "," in args_only:
        result = re.sub(r",(?! )", ", ", args_only)
    else:
        # Split on whitespace and join with commas
        args = args_only.split()
        if not args:
            return args_str
        result = ", ".join(args)

    return result + comment


def normalize_asm_macro_syntax_for_file(filename):
    """Normalize macro definitions and invocations to use commas with spaces"""

    content = read_file(filename)
    lines = content.split("\n")

    # First pass: collect macro names
    macro_names = set()
    for line in lines:
        macro_match = _RE_MACRO_DEF.match(line)
        if macro_match:
            macro_names.add(macro_match.group(1))

    # Second pass: normalize syntax
    new_lines = []
    for line in lines:
        # Normalize .macro definitions
        macro_def_match = _RE_MACRO_DEF_ARGS.match(line)
        if macro_def_match:
            prefix = macro_def_match.group(1)
            args_with_space = macro_def_match.group(2)
            # Preserve leading whitespace, normalize the rest
            leading_space = _RE_LEADING_SPACE.match(args_with_space).group(1)
            args = args_with_space.lstrip()
            normalized_args = normalize_comma_separated_args(args)
            line = prefix + leading_space + normalized_args
        else:
            # Normalize macro invocations
            for macro_name in macro_names:
                # Match: whitespace + macro_name + whitespace + args
                pattern = r"^(\s*" + re.escape(macro_name) + r")(\s+.*)$"
                invocation_match = re.match(pattern, line)
                if not invocation_match:
                    continue
                prefix = invocation_match.group(1)
                args_with_space = invocation_match.group(2)
                # Preserve leading whitespace, normalize the rest
                leading_space = _RE_LEADING_SPACE.match(args_with_space).group(1)
                args = args_with_space.lstrip()
                normalized_args = normalize_comma_separated_args(args)
                line = prefix + leading_space + normalized_args
                break

        new_lines.append(line)

    update_file(filename, "\n".join(new_lines))


def normalize_asm_macro_syntax():
    """Normalize macro syntax in all assembly files"""
    # Operate on assembly files in dev/ only. The ones in mldsa/ are autogenerated.
    files = list(filter(lambda s: s.startswith("dev/"), get_asm_source_files()))
    files += get_files("dev/**/*.inc")
    run_parallel(files, normalize_asm_macro_syntax_for_file)


def update_via_simpasm(
    infile_full,
    outdir,
    outfile=None,
    cflags=None,
    preserve_header=True,
    force_cross=False,
    x86_64_syntax="att",
):

    _, infile = os.path.split(infile_full)
    if outfile is None:
        outfile = infile
    outfile_full = os.path.join(outdir, outfile)

    if cflags is None:
        cflags = ""
    cflags += " -Imldsa"

    # Check if we need to use a cross-compiler
    if "aarch64" in infile_full:
        source_arch = "aarch64"
    elif "x86_64" in infile_full:
        source_arch = "x86_64"
    else:
        raise Exception(f"Could not detect architecture of source file {infile_full}.")
    # Check native architecture
    if platform.machine().lower() in ["arm64", "aarch64"]:
        native_arch = "aarch64"
    else:
        native_arch = "x86_64"

    if native_arch != source_arch:
        cross_prefix = f"{source_arch}-unknown-linux-gnu-"
        cross_gcc = cross_prefix + "gcc"
        # Check if cross-compiler is present
        if shutil.which(cross_gcc) is None:
            if force_cross is False:
                return
            raise Exception(f"Could not find cross toolchain {cross_prefix}")
    else:
        cross_prefix = None

    with tempfile.NamedTemporaryFile(suffix=".S") as tmp:
        try:
            # Architecture was already determined from the filename above
            arch = source_arch

            cmd = [
                "./scripts/simpasm",
                "--objdump=llvm-objdump",
                "--cfify",
                "--arch=" + arch,
                "-i",
                infile_full,
                "-o",
                tmp.name,
            ]
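            # A typical resulting invocation looks like this (sketch; the
            # input path and temporary output file are illustrative):
            #   ./scripts/simpasm --objdump=llvm-objdump --cfify \
            #     --arch=aarch64 -i dev/aarch64_opt/src/ntt.S -o /tmp/tmpX.S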
            if cross_prefix is not None:
                # Stick with llvm-objdump for disassembly
                cmd += ["--cc", cross_prefix + "gcc"]
                cmd += ["--nm", cross_prefix + "nm"]
            cmd += [f'--cflags="{cflags}"']
            if preserve_header is True:
                cmd += ["-p"]
            # Add syntax option for x86_64
            if arch == "x86_64" and x86_64_syntax != "att":
                cmd += ["--syntax", x86_64_syntax]
            r = subprocess.run(
                cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                check=True,
                text=True,
            )
        except subprocess.CalledProcessError as e:
            print(f"Command failed: {' '.join(cmd)}")
            print(f"Exit code: {e.returncode}")
            print(f"stderr: {e.stderr}")
            raise Exception("Failed to run simpasm") from e
        tmp.seek(0)
        new_contents = tmp.read().decode()

    update_file(outfile_full, new_contents)


def gen_hol_light_asm_file(job):
    infile, outfile, indir, cflags, arch = job
    update_via_simpasm(
        f"{indir}/{infile}",
        "proofs/hol_light/" + arch + "/mldsa",
        outfile=outfile,
        cflags=cflags,
        preserve_header=False,
    )


def gen_hol_light_asm():
    x86_64_flags = "-mavx2 -mbmi2 -msse4 -fcf-protection=full"
    joblist_x86_64 = [
        (
            "ntt.S",
            "mldsa_ntt.S",
            "dev/x86_64/src",
            f"-Imldsa/src/native/x86_64/src -Imldsa/src/common.h {x86_64_flags}",
            "x86_64",
        ),
    ]

    run_parallel(joblist_x86_64, gen_hol_light_asm_file)


def update_via_copy(infile_full, outfile_full, transform=None):

    content = read_file(infile_full)

    if transform is not None:
        content = transform(content)

    update_file(outfile_full, content)


def update_via_remove(filename):
    update_file(filename, None)


def synchronize_file(f, in_dir, out_dir, delete=False, no_simplify=False, **kwargs):

    # Only synchronize sources; skip README.md, Makefile, and so on
    extensions = (".c", ".h", ".i", ".inc", ".S")

    if not f.endswith(extensions):
        return None

    basename = os.path.basename(f)

    if delete is True:
        return basename

    if no_simplify is False and f.endswith(".S"):
        update_via_simpasm(f, out_dir, **kwargs)
    else:
        # Update via copy
        _, infile = os.path.split(f)
        outfile_full = os.path.join(out_dir, infile)
        # The header guards will also be checked later, but if we
        # don't do it here, the dry-run would fail because of a
        # mismatching intermediate file
        if f.endswith(".h"):
            transform = lambda c: adjust_header_guard_for_filename(c, outfile_full)
        else:
            transform = None
        update_via_copy(f, outfile_full, transform=transform)

    return basename


def synchronize_backend(in_dir, out_dir, delete=False, no_simplify=False, **kwargs):
    copied = []

    files = get_files(os.path.join(in_dir, "*"))
    pool_results = run_parallel(
        files,
        partial(
            synchronize_file,
            in_dir=in_dir,
            out_dir=out_dir,
            delete=delete,
            no_simplify=no_simplify,
            **kwargs,
        ),
    )

    copied = [r for r in pool_results if r is not None]

    if delete is False:
        return

    # Check for files in the target directory that have not been copied
    for f in get_files(os.path.join(out_dir, "*")):
        if os.path.basename(f) in copied:
            continue
        # Otherwise, remove it
        update_via_remove(f)


def synchronize_backends(
    *,
    force_cross=False,
    clean=False,
    delete=False,
    no_simplify=False,
    x86_64_syntax="att",
):
    if clean is False:
        ty = "opt"
    else:
        ty = "clean"

    if delete is False:
        # We may switch the AArch64 arithmetic backend, so adjust the metadata file
        update_via_copy(
            f"dev/aarch64_{ty}/meta.h",
            "mldsa/src/native/aarch64/meta.h",
            transform=lambda c: adjust_header_guard_for_filename(
                c, "mldsa/src/native/aarch64/meta.h"
            ),
        )

        update_via_copy(
            f"dev/x86_64/meta.h",
            "mldsa/src/native/x86_64/meta.h",
            transform=lambda c: adjust_header_guard_for_filename(
                c, "mldsa/src/native/x86_64/meta.h"
            ),
        )

    synchronize_backend(
        f"dev/aarch64_{ty}/src",
        "mldsa/src/native/aarch64/src",
        delete=delete,
        force_cross=force_cross,
        no_simplify=no_simplify,
        cflags="-Imldsa/src/native/aarch64/src",
    )
    synchronize_backend(
        "dev/fips202/aarch64/src",
        "mldsa/src/fips202/native/aarch64/src",
        delete=delete,
        force_cross=force_cross,
        no_simplify=no_simplify,
        cflags="-Imldsa/src/fips202/native/aarch64/src -march=armv8.4-a+sha3",
    )
    synchronize_backend(
        "dev/fips202/aarch64",
        "mldsa/src/fips202/native/aarch64",
        delete=delete,
        force_cross=force_cross,
        no_simplify=no_simplify,
        cflags="-Imldsa/src/fips202/native/aarch64 -march=armv8.4-a+sha3",
    )
    synchronize_backend(
        "dev/x86_64/src",
        "mldsa/src/native/x86_64/src",
        delete=delete,
        force_cross=force_cross,
        no_simplify=no_simplify,
        x86_64_syntax=x86_64_syntax,
        # Turn off control-flow protection (CET) explicitly. Newer versions of
        # clang turn it on by default and insert endbr64 instructions at every
        # global symbol. We insert endbr64 instructions manually via the
        # MLD_ASM_FN_SYMBOL macro, which would lead to duplicate endbr64
        # instructions, causing a failure when comparing the object code
        # before and after simplification.
        cflags="-Imldsa/src/native/x86_64/src/ -mavx2 -mbmi2 -msse4 -fcf-protection=none",
    )


def adjust_header_guard_for_filename(content, header_file):

    content = content.split("\n")
    exceptions = {
        "mldsa/mldsa_native.h": "MLD_H",
        "mldsa/mldsa_native_config.h": "MLD_CONFIG_H",
    }

    # Use full filename as the header guard, with '/' and '.' replaced by '_'
    guard_name = (
        header_file.removeprefix("mldsa/src/")
        .replace("/", "_")
        .replace(".", "_")
        .upper()
    )
    guard_name = "MLD_" + guard_name

    if header_file in exceptions.keys():
        guard_name = exceptions[header_file]

    def gen_guard():
        yield f"#ifndef {guard_name}"
        yield f"#define {guard_name}"

    def gen_footer():
        yield f"#endif"
        yield ""

    guard = list(gen_guard())
    footer = list(gen_footer())

    # Skip over initial commentary
    insert_at = None
    for i, l in enumerate(content):
        if l.strip() == "" or l.startswith(("/*", " *")):
            continue
        insert_at = i
        break

    i = insert_at
    while content[i].strip() == "":
        i += 1
    # Check if header file has some guard -- if so, drop it
    if content[i].strip().startswith("#if !defined") or content[i].strip().startswith(
        "#ifndef"
    ):
        del content[i]
        if content[i].strip().startswith("#define"):
            del content[i]
        has_guard = True
    else:
        has_guard = False
    # Add standardized guard
    content = content[:i] + guard + content[i:]
    # Check if header has some footer
    if (
        has_guard is True
        and content[-1] == ""
        and content[-2].strip().startswith("#endif")
    ):
        del content[-2:]
    # Add standardized footer
    content = content + footer

    return "\n".join(content)


def gen_header_guard(header_file):
    content = read_file(header_file)
    new_content = adjust_header_guard_for_filename(content, header_file)
    update_file(header_file, new_content)


def gen_header_guards():
    run_parallel(get_header_files(main_only=True), gen_header_guard)


def gen_source_undefs(source_file):

    # Get the list of macros #define'd in this source file (drop the filename
    # component of each entry)
    undef_list = list(map(lambda c: c[1], get_defines_from_file(source_file)))
    if not undef_list:
        return

    # Get define clauses from header files, as dict
    header_defs = {d: c for (c, d) in get_defines()}

    undefs = []
    ignored = []
    for d in undef_list:
        if d not in header_defs.keys():
            undefs.append(f"#undef {d}")
        else:
            ignored.append((d, header_defs[d]))

    if len(ignored) != 0:
        undefs.append(
            "/* Some macros are kept because they are also defined in a header. */"
        )
        for d, c in ignored:
            undefs.append(f"/* Keep: {d} ({c.split('/')[-1]}) */")

    content = read_file(source_file).split("\n")

    # Find simpasm footer if present (search from end)
    footer_start = None
    if source_file.endswith(".S"):
        footer_start_marker = "simpasm: footer-start"
        for i in range(len(content) - 1, -1, -1):
            if footer_start_marker in content[i]:
                footer_start = i
                break

    if footer_start is not None:
        simpasm_footer = content[footer_start:]
        content = content[:footer_start]
    else:
        simpasm_footer = []

    # Strip trailing undefs and empty lines
    while content and (
        content[-1].startswith("#undef")
        or content[-1].startswith("/* Keep:")
        or content[-1].startswith("/* Some macros")
        or content[-1] == ""
    ):
        content.pop()

    footer = [
        "",
        "/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.",
        " * Don't modify by hand -- this is auto-generated by scripts/autogen. */",
    ]

    # Remove existing footer if present
    if len(content) >= len(footer) and content[-len(footer) :] == footer:
        content = content[: -len(footer)]

    content.extend(footer)
    content.extend(undefs)
    content.append("")

    new_content = "\n".join(content + simpasm_footer)
    update_file(source_file, new_content)


def gen_undefs():
    files = get_c_source_files(core_only=True) + get_asm_source_files(core_only=True)
    run_parallel(files, gen_source_undefs)


def gen_slothy(funcs):
    if not isinstance(funcs, list):
        return

    targets = list(map(lambda s: s + ".S", funcs))

    for t in targets:

        if t.startswith("keccak"):
            base = "dev/fips202/aarch64/src"
        else:
            base = "dev/aarch64_opt/src"

        # Remove the file to be re-generated
        subprocess.run(["rm", "-f", f"{base}/{t}"])

        p = subprocess.run(["make", t] + ["-C", base])
        if p.returncode != 0:
            print(f"Failed to run SLOTHY on {t}!")
            sys.exit(1)


class BibliographyEntry:
    def __init__(self, raw_dict):
        self._raw = raw_dict
        self._usages = []

    def register_usages(self, lst):
        self._usages += lst

    @property
    def usages(self):
        return self._usages

    @property
    def name(self):
        return self._raw["name"]

    @property
    def short(self):
        if "short" in self._raw.keys():
            return self._raw["short"]
        return self.name

    @property
    def id(self):
        return self._raw["id"]

    @property
    def url(self):
        return self._raw["url"]

    @staticmethod
    def full_name(name):
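        # e.g. "Doe, Jane" -> "Jane Doe" (illustrative); names without a
        # comma are returned unchanged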
        if "," not in name:
            return name
        surname, forename = name.split(",")
        return forename.strip() + " " + surname.strip()

    @property
    def authors(self):
        authors = self._raw["author"]
        if not isinstance(authors, list):
            authors = [authors]
        authors = list(map(BibliographyEntry.full_name, authors))
        return authors

    @property
    def authors_text(self):
        authors = self._raw["author"]
        if not isinstance(authors, list):
            authors = [authors]

        def surname(name):
            return name.split(",")[0].strip()

        if len(authors) > 1:
            authors = ", ".join(map(surname, authors))
        else:
            authors = BibliographyEntry.full_name(authors[0])
        return authors


def gen_markdown_citations_for(filename, bibliography):

    # Skip BIBLIOGRAPHY.md
    if filename == "BIBLIOGRAPHY.md":
        return

    content = read_file(filename)
    if not _RE_MARKDOWN_CITE.search(content):
        return

    content = content.split("\n")

    # Lookup all citations in style `[^ID]`
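    # e.g. a line containing "... see [^FIPS204] ..." (illustrative id)
    # records one usage of the entry FIPS204 and later receives a matching
    # footnote at the bottom of the file.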
    citations = {}
    for i, l in enumerate(content):
        for m in _RE_MARKDOWN_CITE.finditer(l):
            cite_id = m.group("id")
            uses = citations.get(cite_id, [])
            uses.append((filename, i))
            citations[cite_id] = uses

    # Find and remove any existing citation footnotes
    footnote_footer_start = "<!--- bibliography --->"
    try:
        i = content.index(footnote_footer_start)
        content = content[:i]
    except ValueError:
        pass

    # Add footnotes for all citations found
    if len(citations) > 0:
        content.append(footnote_footer_start)
    cite_ids = list(citations.keys())
    cite_ids.sort()
    for cite_id in cite_ids:
        uses = citations[cite_id]
        entry = bibliography.get(cite_id, None)
        if entry is None:
            raise Exception(
                f"Could not find bibliography entry {cite_id} referenced in {filename}. Known entries: {list(bibliography.keys())}"
            )
        content.append(
            f"[^{cite_id}]: {entry.authors_text}: {entry.name}, [{entry.url}]({entry.url})"
        )

        # Remember this usage of the bibliography entry
        entry.register_usages(uses)

    if len(citations) > 0:
        content.append("")

    update_file(filename, "\n".join(content))


def gen_c_citations_for(filename, bibliography):

    content = read_file(filename)

    if not _RE_C_CITE.search(content):
        return

    references_start = [
        "/* References",
        " * ==========",
    ]
    references_end = [" */"]

    # Find and remove any existing reference section
    ref_pattern = r"/\* (# )?References.*?\*/\n+"
    content = re.sub(ref_pattern, "", content, flags=re.DOTALL)

    content = content.split("\n")

    # Lookup all citations in style `@[ID]`
    citations = {}
    for i, l in enumerate(content):
        for m in _RE_C_CITE.finditer(l):
            cite_id = m.group("id")
            uses = citations.get(cite_id, [])
            # Remember usage. +1 because line counting starts at 1
            uses.append((filename, i + 1))
            citations[cite_id] = uses

    # Add references section
    references = []
    references += references_start

    cite_ids = list(citations.keys())
    cite_ids.sort()
    for cite_id in cite_ids:
        uses = citations[cite_id]
        entry = bibliography.get(cite_id, None)
        if entry is None:
            raise Exception(
                f"Could not find bibliography entry {cite_id} referenced in {filename}"
            )
        references.append(f" *")
        references.append(f" * - [{cite_id}]")
        prefix = " *   "
        # Wrap long lines at 80 chars
        for line in [entry.name, entry.authors_text, entry.url]:
            if len(line) + len(prefix) <= 80:
                references.append(f"{prefix}{line}")
            else:
                words = line.split()
                current = prefix
                for word in words:
                    if len(current) + len(word) <= 80 or current == prefix:
                        current += word + " "
                    else:
                        references.append(current.rstrip())
                        current = prefix + word + " "
                if current.rstrip() != prefix.rstrip():
                    references.append(current.rstrip())

    references += references_end

    # Prepend a blank line to separate the references from the preceding text
    references = [""] + references

    if len(cite_ids) > 0:
        # Add references to file after initial header section
        # Skip over copyright
        insert_at = None
        for i, l in enumerate(content):
            if l.startswith(("/*", " *")):
                continue
            insert_at = i
            break
        content = content[:insert_at] + references + content[insert_at:]

    # Remember uses -- needs to happen after insertion of references
    # since we need to adjust the line count
    for cite_id in cite_ids:
        uses = citations[cite_id]
        entry = bibliography.get(cite_id, None)

        # Adjust line count after insertion of references
        def bump_line_count(x):
            return (x[0], x[1] + len(references))

        uses = list(map(bump_line_count, uses))

        # Remember this usage of the bibliography entry
        entry.register_usages(uses)

    update_file(filename, "\n".join(content))


def gen_citations_for(filename, bibliography):
    if filename.endswith(".md"):
        gen_markdown_citations_for(filename, bibliography)
    elif filename.endswith((".c", ".h", ".S")):
        gen_c_citations_for(filename, bibliography)
    else:
        raise Exception(f"Unexpected file extension in {filename}")


def gen_bib_file(bibliography):

    content = [
        "[//]: # (SPDX-License-Identifier: CC-BY-4.0)",
        "[//]: # (This file is auto-generated from BIBLIOGRAPHY.yml)",
        "[//]: # (Do not modify it directly)",
        "",
        "# Bibliography",
        "",
        "This file lists the citations made throughout the mldsa-native ",
        "source code and documentation.",
        "",
    ]

    cite_ids = list(bibliography.keys())
    cite_ids.sort()

    for cite_id in cite_ids:
        entry = bibliography[cite_id]
        content.append(f"### `{cite_id}`")
        content.append("")
        content.append(f"* {entry.name}")
        content.append(f"* Author(s):")
        for author in entry.authors:
            content.append(f"  - {author}")
        content.append(f"* URL: {entry.url}")
        content.append(f"* Referenced from:")
        # Usages are pairs of (filename, line_count)
        # Ignore line_count for now, as keeping it accurate would require
        # re-running `autogen` after every change to the source files.
        files = list(set(map(lambda x: x[0], entry.usages)))
        files.sort()
        for filename in files:
            content.append(f"  - [{filename}]({filename})")
        content.append("")

    update_file("BIBLIOGRAPHY.md", "\n".join(content))


def get_oqs_shared_sources(backend):
    """Get shared source files for OQS integration"""
    mldsa_dir = "mldsa/src/"

    # add files from mldsa/src/*
    sources = [
        f"mldsa/src/{f}"
        for f in os.listdir(mldsa_dir)
        if os.path.isfile(f"{mldsa_dir}/{f}")
        and not f.endswith(".o")
        and not f == "mldsa_native.h"
    ]

    if backend != "ref":
        # add files from mldsa/src/native/* (API definitions)
        sources += [
            f"mldsa/src/native/{f}"
            for f in os.listdir(f"{mldsa_dir}/native")
            if os.path.isfile(f"{mldsa_dir}/native/{f}")
        ]
    # Add FIPS202 glue code
    sources += [
        "integration/liboqs/fips202_glue.h",
        "integration/liboqs/fips202x4_glue.h",
    ]
    # Add custom config
    if backend == "ref":
        backend = "c"
    sources.append(f"integration/liboqs/config_{backend.lower()}.h")

    return sources


def get_oqs_native_sources(backend):
    """Get native source files for OQS integration"""
    return [f"mldsa/src/native/{backend}"]


def gen_oqs_meta_file(filename):
    """Generate OQS META.yml file with updated source lists"""

    content = read_file(filename)

    # Parse YAML while preserving structure
    yml_data = yaml.safe_load(content)

    for impl in yml_data["implementations"]:
        name = impl["name"]

        sources = get_oqs_shared_sources(name)

        # NOTE: Sorting after the native sources have been added causes the
        # libOQS importer to fail; the native directory apparently must not
        # appear too early in the list. Hence, sort the shared sources first
        # and append the native sources afterwards.
        sources.sort()

        if name != "ref":
            sources += get_oqs_native_sources(name)
        impl["sources"] = " ".join(sources)

    # Convert back to YAML string with standard copyright header
    yaml_header = "\n".join(gen_yaml_header())

    new_content = yaml.dump(
        yml_data,
        default_flow_style=False,
        sort_keys=False,
        allow_unicode=True,
        encoding=None,
    )

    # Combine copyright header with new YAML content
    new_content = yaml_header + new_content

    update_file(filename, new_content)


def gen_oqs_meta_files():
    """Generate all OQS META.yml files"""
    meta_files = [
        "integration/liboqs/ML-DSA-44_META.yml",
        "integration/liboqs/ML-DSA-65_META.yml",
        "integration/liboqs/ML-DSA-87_META.yml",
    ]

    for meta_file in meta_files:
        gen_oqs_meta_file(meta_file)


def gen_citations():
    # Load bibliography
    with open("BIBLIOGRAPHY.yml", "r") as f:
        bibliography_raw = yaml.safe_load(f.read())

    bibliography = {}
    for r in bibliography_raw:
        cite_id = r["id"]
        bibliography[cite_id] = BibliographyEntry(r)

    files = (
        get_markdown_files()
        + get_asm_source_files()
        + get_c_source_files()
        + get_header_files()
    )
    run_parallel(files, partial(gen_citations_for, bibliography=bibliography))

    # Check that every bibliography entry has been used at least once
    for e in bibliography.values():
        if len(e.usages) == 0:
            raise Exception(
                f"Bibliography entry {e.id} is unused! "
                "Add a citation or remove from BIBLIOGRAPHY.yml."
            )

    gen_bib_file(bibliography)


def extract_bytecode_from_output(output_text):
    """Convert output of proofs/hol_light/x86_64/proofs/dump_bytecode.native
    into a dictionary mapping function names to byte code strings."""
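    # Sketch of the expected dump format (markers as matched above and below;
    # the object name is illustrative):
    #   === bytecode start: x86_64/mldsa/mldsa_ntt.o
    #   ...bytecode lines...
    #   ==== bytecode end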
    bytecode_dict = {}

    lines = output_text.split("\n")
    i = 0

    while i < len(lines):
        line = lines[i]
        match = _RE_BYTECODE_START.search(line)
        if match:
            filename = match.group(1)

            # Collect bytecode until end marker
            bytecode_lines = []
            i += 1
            while i < len(lines) and "==== bytecode end" not in lines[i]:
                bytecode_lines.append(lines[i])
                i += 1

            bytecode = "\n".join(bytecode_lines).strip()
            bytecode_dict[filename] = bytecode
        i += 1

    return bytecode_dict


def update_bytecode_in_proof_script(filepath, bytecode):
    content = read_file(filepath)

    # Check if markers exist
    start_marker = "(*** BYTECODE START ***)"
    end_marker = "(*** BYTECODE END ***)"

    if start_marker not in content or end_marker not in content:
        raise Exception(f"Could not find BYTECODE START/END markers in {filepath}")

    # Replace content between markers
    pattern = rf"{re.escape(start_marker)}.*?{re.escape(end_marker)}"
    replacement = f"{start_marker}\n{bytecode}\n{end_marker}"

    updated_content = re.sub(pattern, replacement, content, flags=re.DOTALL)

    update_file(filepath, updated_content)


def update_hol_light_bytecode_for_arch(arch, force_cross=False):
    source_arch = arch
    if platform.machine().lower() in ["arm64", "aarch64"]:
        native_arch = "aarch64"
    else:
        native_arch = "x86_64"

    if native_arch != source_arch:
        cross_prefix = f"{source_arch}-unknown-linux-gnu-"
        cross_gcc = cross_prefix + "gcc"
        # Check if cross-compiler is present
        if shutil.which(cross_gcc) is None:
            if force_cross is False:
                return
            raise Exception(f"Could not find cross toolchain {cross_prefix}")

    # Run make to get bytecode output
    result = subprocess.run(
        ["make", "-C", "proofs/hol_light/" + arch, "dump_bytecode"],
        capture_output=True,
        text=True,
        check=True,
    )
    output_text = result.stdout

    # Extract bytecode
    bytecode_dict = extract_bytecode_from_output(output_text)

    # Update each .ml file
    for obj_name, bytecode in bytecode_dict.items():
        ml_file = "proofs/hol_light/" + arch + "/proofs/" + obj_name + ".ml"
        update_bytecode_in_proof_script(ml_file, bytecode)


def update_hol_light_bytecode(force_cross=False):
    """Update HOL Light proof files with bytecode from make dump_bytecode."""
    # NOTE: The following line is commented out until there are hol_light aarch64 proofs.
    # update_hol_light_bytecode_for_arch("aarch64", force_cross=force_cross)
    update_hol_light_bytecode_for_arch("x86_64", force_cross=force_cross)


def gen_test_config(config_path, config_spec, default_config_content):
    """Generate a config file by modifying the default config."""

    # Start with the default config
    lines = default_config_content.split("\n")

    # Find copyright and reference header
    references_start = None
    references_end = None

    # NOTE: This needs further work if any custom config contains citations
    # not included in the default configuration. In this case, the reference
    # section needs to be updated.

    for i, line in enumerate(lines):
        if "/* References" in line:
            references_start = i
        elif (
            references_start is not None
            and references_end is None
            and line.strip() == "*/"
        ):
            references_end = i + 1
            break

    header = lines[:references_end]
    header += list(gen_autogen_warning())

    header.append("")
    header.append("/*")
    header.append(f" * Test configuration: {config_spec['description']}")
    header.append(" *")
    header.append(
        " * This configuration differs from the default mldsa/mldsa_native_config.h in the following places:"
    )

    def spec_has_value(opt_value):
        if not isinstance(opt_value, dict):
            return True
        else:
            return "value" in opt_value.keys() or "content" in opt_value.keys()

    for opt_name, opt_value in config_spec["defines"].items():
        if not spec_has_value(opt_value):
            continue
        header.append(f" *   - {opt_name}")
    header.append(" */")
    header.append("")

    # Combine: new header + config body
    lines = header + lines[references_end:]

    def locate_config_option(lines, opt_name):
        """Locate configuration option in lines. Returns (start, end) line indices"""
        i = 0
        while i < len(lines):
            if re.search(rf"\* Name:\s+{re.escape(opt_name)}\b", lines[i]):
                # Skip to the end of this comment block (find the closing */)
                while i < len(lines) and not lines[i].strip().endswith("*/"):
                    i += 1
                i += 1  # Skip the closing */

                start = i
                # Find the next config option (starts with /*****)
                while i < len(lines):
                    if lines[i].strip().startswith("/****"):
                        # Back up to exclude empty line before next option
                        while i > start and lines[i - 1].strip() == "":
                            i -= 1
                        return (start, i)
                    i += 1
                # If no next option found, go to end
                return (start, len(lines))
            i += 1
        raise Exception(f"Could not find config option {opt_name} in default config")

    def replace_config_option(lines, opt_name, opt_value):
        """Replace config option with new value"""
        block_start, block_end = locate_config_option(lines, opt_name)

        def content_from_value(value):
            if value is True:
                return [f"#define {opt_name}"]
            elif value is False:
                return [f"/* #define {opt_name} */"]
            else:
                return [f"#define {opt_name} {str(value)}"]

        if isinstance(opt_value, dict):
            if "content" in opt_value:
                content = opt_value.get("content").split("\n")
            elif "value" in opt_value:
                content = content_from_value(opt_value["value"])
            else:
                # Use original content
                content = lines[block_start:block_end]
            if "comment" in opt_value:
                comment = opt_value.get("comment").split("\n")
            else:
                comment = []
        else:
            content = content_from_value(opt_value)
            comment = []

        lines[block_start:block_end] = comment + content

    # Apply modifications for each defined option
    for opt_name, opt_value in config_spec["defines"].items():
        replace_config_option(lines, opt_name, opt_value)

    content = "\n".join(lines)
    update_file(config_path, content)


def gen_test_configs():
    """Generate all test configuration files from metadata."""
    # Load metadata
    metadata = yaml.safe_load(read_file("test/configs/configs.yml"))

    # Load default config
    default_config = read_file("mldsa/mldsa_native_config.h")

    # Generate each test config
    for config_spec in metadata["configs"]:
        gen_test_config(config_spec["path"], config_spec, default_config)


def _main():
    slothy_choices = [
        "ntt",
        "intt",
        "mld_polyvecl_pointwise_acc_montgomery_l4",
        "mld_polyvecl_pointwise_acc_montgomery_l5",
        "mld_polyvecl_pointwise_acc_montgomery_l7",
        "pointwise_montgomery",
        "poly_caddq_asm",
        "poly_chknorm_asm",
        "poly_decompose_32_asm",
        "poly_decompose_88_asm",
        "poly_use_hint_32_asm",
        "poly_use_hint_88_asm",
        "polyz_unpack_17_asm",
        "polyz_unpack_19_asm",
        "rej_uniform_asm",
        "rej_uniform_eta2_asm",
        "rej_uniform_eta4_asm",
    ]

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--dry-run", default=False, action="store_true")
    parser.add_argument(
        "--update-hol-light-bytecode", default=False, action="store_true"
    )
    parser.add_argument("--slothy", nargs="*", default=None, choices=slothy_choices)
    parser.add_argument("--aarch64-clean", default=False, action="store_true")
    parser.add_argument("--no-simplify", default=False, action="store_true")
    parser.add_argument("--force-cross", default=False, action="store_true")
    parser.add_argument(
        "--x86-64-syntax",
        type=str,
        choices=["att", "intel"],
        default="att",
        help="Assembly syntax for x86_64 disassembly output (att or intel)",
    )

    args = parser.parse_args()

    os.chdir(os.path.join(os.path.dirname(__file__), ".."))

    if args.slothy == []:
        args.slothy = slothy_choices

    def sync_backends():
        synchronize_backends(
            clean=args.aarch64_clean,
            no_simplify=args.no_simplify,
            force_cross=args.force_cross,
            x86_64_syntax=args.x86_64_syntax,
        )

    def sync_backends_final():
        synchronize_backends(
            clean=args.aarch64_clean,
            delete=True,
            force_cross=args.force_cross,
            no_simplify=args.no_simplify,
            x86_64_syntax=args.x86_64_syntax,
        )

    def gen_zeta_tables():
        gen_c_zeta_file()
        gen_aarch64_zeta_file()
        gen_aarch64_rej_uniform_table()
        gen_aarch64_rej_uniform_eta_table()
        gen_avx2_hol_light_zeta_file()
        gen_avx2_zeta_file()
        gen_avx2_rej_uniform_table()

    def gen_monolithic():
        gen_monolithic_source_file()
        gen_monolithic_asm_file()

    hol_light_asm_supported = platform.machine().lower() in ["x86_64"]

    # Build step list: (description, function, enabled)
    # If enabled is False, step is skipped
    steps = [
        ("Generate citations", gen_citations),
        ("Generate OQS META.yml files", gen_oqs_meta_files),
        ("Normalize assembly macro syntax", normalize_asm_macro_syntax),
        (
            "Generate SLOTHY optimized assembly",
            lambda: gen_slothy(args.slothy),
            args.slothy is not None and not args.dry_run,
        ),
        ("Check assembly register aliases", check_asm_register_aliases),
        ("Check assembly loop labels", check_asm_loop_labels),
        ("Generate zeta and lookup tables", gen_zeta_tables),
        ("Generate HOL Light assembly", gen_hol_light_asm, hol_light_asm_supported),
        ("Synchronize backends", sync_backends),
        ("Generate header guards", gen_header_guards),
        ("Complete final backend synchronization", sync_backends_final),
        (
            "Update HOL Light bytecode",
            partial(update_hol_light_bytecode, force_cross=args.force_cross),
            args.update_hol_light_bytecode,
        ),
        ("Generate monolithic source files", gen_monolithic),
        ("Generate undefs", gen_undefs),
        ("Generate test configs", gen_test_configs),
        ("Check macro typos", check_macro_typos),
        ("Generate preprocessor comments", gen_preprocessor_comments),
        # Formatting should be the last step
        ("Format files", lambda: format_files(args.dry_run)),
    ]

    global _progress, _main_task
    with Progress(
        BarColumn(),
        TaskProgressColumn(),
        TimeElapsedColumn(),
        TextColumn("{task.description}"),
        console=console,
    ) as progress:
        _progress = progress
        _main_task = progress.add_task("autogen", total=len(steps) + 1)

        for step in steps:
            desc, func = step[0], step[1]
            enabled = step[2] if len(step) > 2 else True
            high_level_task(desc)
            if enabled:
                func()
            high_level_status(desc, skipped=not enabled)

        high_level_task("Write files")
        finalize(args.dry_run)
        txt = (
            "Finalize and check files are up to date"
            if args.dry_run
            else "Finalize and write files"
        )
        high_level_status(txt)

        _progress.update(_main_task, description="[green] Done ✓[/]")
        _progress = None

    return print_check_errors()


if __name__ == "__main__":
    sys.exit(0 if _main() else 1)
