#!/usr/bin/env python3
# Copyright (c) The mlkem-native project authors
# Copyright (c) The mldsa-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

#
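# Looks for magic numbers (integer literals of four or more digits) in the
# mldsa/ C sources and headers that lack a verified check-magic explanation.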
#

import math
import pathlib
import re
import sys

from sympy import simplify, sympify, Function, Rational

def get_c_source_files():
    """Return the paths (as strings) of every `.c` file below `mldsa/`."""
    c_source_glob = "mldsa/**/*.c"
    return get_files(c_source_glob)

def get_header_files():
    """Return the paths (as strings) of every `.h` file below `mldsa/`."""
    header_glob = "mldsa/**/*.h"
    return get_files(header_glob)

def get_files(pattern):
    """Return the paths matching glob `pattern`, relative to the current directory.

    `pattern` is a pathlib glob, so `**` matches any number of nested
    directories. The paths are returned as strings in sorted order. This
    makes files get processed, and the first failure get reported, in the
    same order on every platform and filesystem.
    """
    return sorted(str(path) for path in pathlib.Path().glob(pattern))

# ANSI escape sequences used to colour the report output
GREEN = "\033[32m"
RED = "\033[31m"
BLUE = "\033[94m"
BOLD = "\033[1m"
NORMAL = "\033[0m"

# Status glyphs printed at the start of each report line:
# a checked literal, a failure, and a verified (remembered) annotation.
CHECKED = GREEN + "✓" + NORMAL
FAIL = RED + "✗" + NORMAL
REMEMBERED = BLUE + "⊢" + NORMAL

def check_magic_numbers():
    """Exit with status 1 if an mldsa C source/header uses an unexplained magic number.

    A magic number is an integer literal of four or more digits, optionally
    negative, that is not part of an identifier or a suffixed literal.
    Every such literal must satisfy one of two conditions:
    - it is listed in `exceptions`, or
    - an annotation earlier in the same file explains it:

        /* check-magic: VALUE == EXPRESSION */

    EXPRESSION is evaluated with sympy (see `evaluate_magic`) and must equal
    VALUE. The following are skipped:
    - lines between a `check-magic: off` marker and a `check-magic: on` marker;
    - files that contain the autogen marker.

    Every literal and every annotation on a line is checked. On the first
    mismatching annotation or unexplained literal, a diagnostic is printed
    and the process exits with status 1.
    """
    mldsa_q = 8380417
    exceptions = [mldsa_q]
    enable_marker = "check-magic: on"
    disable_marker = "check-magic: off"
    autogen_marker = "This file is auto-generated from scripts/autogen"

    # Annotation explaining a magic value: /* check-magic: VALUE == EXPR */
    magic_annotation = re.compile(r'/\* check-magic:\s+([-]?\d{4,})\s*==\s*(.*?) \*/')
    # Use negative lookbehind and lookahead to exclude numbers
    # that occur as part of identifiers (e.g. layer12345 or 199901L)
    magic_literal = re.compile(r'(?<![0-9a-zA-Z/_-])([-]?\d{4,})(?![0-9a-zA-Z_-])')

    files = get_c_source_files() + get_header_files()

    def is_exception(filename, l, magic):
        return magic in exceptions

    def get_magics(l):
        """Split all check-magic annotations off line `l`.

        Returns the line with every annotation removed, so the annotated
        values are not themselves reported as magic. Also returns the list
        of (value, expression) pairs found, in order of appearance.
        """
        annotations = [(int(m.group(1)), m.group(2))
                       for m in magic_annotation.finditer(l)]
        return magic_annotation.sub('', l), annotations

    def get_define(l):
        m = re.search(r'#define\s+(\w+)', l)
        if m is not None:
            return m.group(1)
        return None

    def evaluate_magic(m, known_magics):
        """Evaluate the annotation expression `m` with sympy.

        In addition to sympy syntax, the expression may use:
        - `signed_mod` and `unsigned_mod`;
        - `pow(x, y, m)`: modular exponentiation, reduced with signed_mod;
        - `round`: rejects exact .5 ties;
        - `intdiv`: rejects non-exact division;
        - every name in `known_magics`.

        Raises ValueError for an ambiguous `round` or an inexact `intdiv`.
        """
        def unsigned_mod(x,y):
            return x % y
        def signed_mod(x,y):
            # NOTE(review): for odd y this maps y//2 itself to a negative
            # value, so the result range is not symmetric around 0. It is kept
            # as-is because existing annotations are verified against it.
            r = unsigned_mod(x,y)
            if r >= y // 2:
                r -= y
            return r
        def pow_mod(x,y,m):
            x = int(x)
            y = int(y)
            m = int(m)
            return signed_mod(pow(x,y,m),m)
        def safe_round(x):
            if x - math.floor(x) == Rational(1, 2):
                raise ValueError(f"Ambiguous rounding: {x} is an odd multiple of 0.5 and it is unclear if round-up or round-down is desired")
            return round(x)
        def safe_floordiv(x, y):
            x = int(x)
            y = int(y)
            if x % y != 0:
                raise ValueError(f"Non-integral division: {x} // {y} has remainder {x % y}")
            return x // y
        locals_dict = {'signed_mod': signed_mod,
                       'unsigned_mod': unsigned_mod,
                       'pow': pow_mod,
                       'round': safe_round,
                       'intdiv': safe_floordiv }
        locals_dict.update(known_magics)
        return sympify(m, locals=locals_dict)

    for filename in files:
        with open(filename, "r", encoding="utf-8") as f:
            content = f.read()
        if autogen_marker in content:
            continue
        enabled = True
        # Names usable in annotation expressions; grows with #defines below
        magic_dict = {'MLDSA_Q': mldsa_q, 'MLD_REDUCE32_DOMAIN_MAX': 2143289343}
        # Explanations are per-file: an annotation must precede its use
        verified_magics = {}
        for i, l in enumerate(content.split("\n")):
            if enabled and disable_marker in l:
                enabled = False
                continue
            if not enabled:
                if enable_marker in l:
                    enabled = True
                continue

            l, annotations = get_magics(l)
            for magic_val, magic_expr in annotations:
                magic_val_check = evaluate_magic(magic_expr, magic_dict)
                if magic_val != magic_val_check:
                    print(f"{FAIL}:{filename}:{i+1}: Mismatching magic annotation: {magic_val} != {magic_expr} (= {magic_val_check})")
                    sys.exit(1)
                print(f"{REMEMBERED}:{filename}:{i+1}: Verified explanation {magic_val} == {magic_expr}")
                verified_magics[magic_val] = magic_expr

            # Check every literal on the line, not just the first one. An
            # exception listed first (e.g. MLDSA_Q) must not hide later values.
            checked = []
            for found in magic_literal.finditer(l):
                magic = int(found.group())
                if is_exception(filename, l, magic):
                    continue
                explanation = verified_magics.get(magic)
                if explanation is None:
                    print(f"{FAIL}:{filename}:{i+1}: No explanation for magic value {magic}")
                    sys.exit(1)
                print(f"{CHECKED}:{filename}:{i+1}: {magic} previously explained as {explanation}")
                checked.append(magic)

            # If this is a #define's clause, remember its (first checked)
            # value so later annotations can refer to it by name
            if checked:
                define = get_define(l)
                if define is not None:
                    magic_dict[define] = checked[0]

def _main():
    """Command-line entry point: run the magic-number check over mldsa/."""
    return check_magic_numbers()


if __name__ == "__main__":
    _main()
