#!/usr/bin/env python3
# Copyright (c) The mlkem-native project authors
# Copyright (c) The mldsa-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

# This script runs nm on the object files (excluding test objects) and checks that all exported
# symbols are properly namespaced.
# It assumes that object files are present under test/build/mldsa{44,65,87}/mldsa.
# Objects under test/build/fips202 are not currently checked by run().

# The checked namespaces are
# PQCP_MLDSA_NATIVE_FIPS202_ for FIPS202 code
# PQCP_MLDSA_NATIVE_MLDSA44_ for MLDSA44 code
# PQCP_MLDSA_NATIVE_MLDSA65_ for MLDSA65 code
# PQCP_MLDSA_NATIVE_MLDSA87_ for MLDSA87 code

import subprocess
import os


def check_file(file_path, namespaces, nm="nm"):
    """Check that every exported symbol defined in an object file is namespaced.

    Runs ``nm -g`` on ``file_path`` and inspects each externally visible
    symbol that the object file itself defines: text (T), data (D),
    read-only data (R), other-section (S), BSS (B) and common (C) symbols.
    Undefined references are ignored because they name symbols provided
    by other objects.

    Args:
        file_path: path of the object file to inspect.
        namespaces: a single namespace prefix, or an iterable of accepted
            prefixes. A symbol is accepted if it starts with one of them,
            either directly or after a single leading underscore (presumably
            to cover platforms that prepend ``_`` to C symbol names).
        nm: the ``nm`` tool to run. Pass either a single executable name/path
            or a sequence of command-line words, e.g. ``["xcrun", "nm"]``.

    Raises:
        subprocess.CalledProcessError: if ``nm`` exits with a non-zero status.
        AssertionError: if any defined exported symbol lacks a namespace.
    """
    # Avoid iterating a bare string character by character.
    if isinstance(namespaces, str):
        namespaces = [namespaces]
    prefixes = tuple(p for ns in namespaces for p in (ns, "_" + ns))

    # Symbol types (nm letters) that denote a global symbol defined in this
    # object. Weak symbols (W/V) are deliberately not included.
    defined_types = {"T", "D", "R", "S", "B", "C"}

    nm_cmd = [nm] if isinstance(nm, str) else list(nm)
    command = [*nm_cmd, "-g", file_path]

    print("checking namespacing: {}".format(file_path))

    # stderr is kept apart from stdout. If the two were merged and the exit
    # status ignored, a failing nm would look like an object with no
    # symbols, and the check would pass silently. check=True raises instead.
    result = subprocess.run(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
    )
    output = result.stdout.decode("utf-8", errors="replace")

    non_namespaced = []
    for line in output.splitlines():
        fields = line.split()
        # Defined symbols print as "<address> <type> <name>". Undefined
        # references ("U <name>") and per-file header lines have fewer fields.
        if len(fields) < 3:
            continue
        symtype, symbol = fields[-2], fields[-1]
        if symtype in defined_types and not symbol.startswith(prefixes):
            non_namespaced.append(symbol)

    if non_namespaced:
        print("Missing namespace literal from {}".format(namespaces))
        for symbol in non_namespaced:
            print("\tsymbol: {}".format(symbol))
        # Raise explicitly rather than via `assert`, so that the check
        # still fires when Python runs with -O.
        raise AssertionError("Literals with missing namespaces")


def check_folder(folder, namespace):
    """Run ``check_file`` on every ``.o`` file below ``folder``.

    The directory tree is walked recursively in sorted order, so the
    progress output is the same from run to run.

    Args:
        folder: root directory to search for object files.
        namespace: accepted namespace prefix(es), forwarded unchanged to
            ``check_file``.

    Raises:
        AssertionError: if no object file is found (for example, because the
            build directory does not exist), or if ``check_file`` rejects a
            symbol.
    """
    checked = 0
    for root, dirnames, filenames in os.walk(folder):
        # Sorting in place also fixes the order in which os.walk descends.
        dirnames.sort()
        for filename in sorted(filenames):
            if filename.endswith(".o"):
                check_file(os.path.join(root, filename), namespace)
                checked += 1
    print("Checked {} files".format(checked))
    # os.walk silently yields nothing for a missing directory. Raise
    # explicitly rather than via `assert`, so this is still reported under -O.
    if checked == 0:
        raise AssertionError("No object files found under {}".format(folder))


def make_mldsa_namespace(lvl):
    """Build the accepted-prefix list for the ML-DSA parameter set ``lvl``.

    Returns a one-element list, because ``check_file`` accepts any number
    of prefixes.
    """
    prefix = "PQCP_MLDSA_NATIVE_MLDSA{}".format(lvl)
    return [prefix]


def run():
    """Check the library object files of each ML-DSA parameter set in turn.

    Levels are checked in the order 44, 65, 87. Each level's objects must
    carry that level's namespace prefix.
    """
    for level in (44, 65, 87):
        check_folder(f"test/build/mldsa{level}/mldsa", make_mldsa_namespace(level))


if __name__ == "__main__":
    # The folder paths in run() are relative to the repository root, which
    # is the parent of this script's directory.
    repo_root = os.path.join(os.path.dirname(__file__), "..")
    os.chdir(repo_root)
    run()
