#!/usr/bin/env python3
# Copyright (c) The mlkem-native project authors
# Copyright (c) The mldsa-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

#
# Looks for CBMC contracts without proof
#

import re
import sys
import subprocess
import pathlib

def get_c_source_files():
    """Return the paths of all ``*.c`` files matched below ``mldsa/``.

    The pattern is handed to get_files(), which resolves it relative to
    the current working directory.
    """
    c_pattern = "mldsa/**/*.c"
    return get_files(c_pattern)

def get_header_files():
    """Return the paths of all ``*.h`` files matched below ``mldsa/``.

    The pattern is handed to get_files(), which resolves it relative to
    the current working directory.
    """
    header_pattern = "mldsa/**/*.h"
    return get_files(header_pattern)

def get_files(pattern):
    """Return the files matching a glob pattern as a list of strings.

    Args:
        pattern: Glob pattern, interpreted relative to the current
            working directory (``**`` matches nested directories).

    Returns:
        A list with one path string per match; empty if nothing matches.
    """
    cwd = pathlib.Path()
    return [str(match) for match in cwd.glob(pattern)]

def gen_proofs():
    """Return the entries printed by ``./proofs/cbmc/list_proofs.sh``.

    The script is run relative to the current working directory and its
    standard output is split on newlines.

    Returns:
        A lazy iterable over the output lines that are not blank. Lines
        are passed through as printed (they are not stripped).
    """
    # NOTE: the exit status of the script is not inspected here; a failing
    # script simply yields whatever it wrote to stdout.
    completed = subprocess.run(
        ["./proofs/cbmc/list_proofs.sh"],
        capture_output=True,
        universal_newlines=False,
    )
    listing = completed.stdout.decode()
    return (entry for entry in listing.split("\n") if entry.strip() != "")

def gen_contracts():
    """Yield every function that carries a ``__contract__`` annotation.

    All files returned by get_c_source_files() and get_header_files() are
    searched for a function name followed by a parenthesised parameter
    list and then ``__contract__``.

    Yields:
        Tuples ``(filename, line, funcname)`` where ``line`` is the
        1-based line number on which the function name starts and
        ``funcname`` is the name with a leading ``mld_`` prefix removed.
    """
    # Compiled once; the same pattern is applied to every file.
    # Limitation: `[^)]*` stops at the first `)`, so a parameter list that
    # itself contains parentheses is not recognised.
    contract_re = re.compile(r'(\w+)\s*\([^)]*\)\s*\n?\s*__contract__')

    for filename in get_c_source_files() + get_header_files():
        # Read with a fixed encoding so results do not depend on the locale.
        with open(filename, "r", encoding="utf-8") as f:
            content = f.read()

        for m in contract_re.finditer(content):
            # Counting the newlines before the match gives a 0-based index;
            # add 1 so the reported `file:line` matches what editors show.
            line = content.count('\n', 0, m.start()) + 1
            yield (filename, line, m.group(1).removeprefix("mld_"))

def is_exception(funcname):
    """Tell whether a function is known to have a contract but no proof.

    Args:
        funcname: Function name to look up.

    Returns:
        True if the name is on the list of accepted exceptions below,
        False otherwise.
    """
    # CBMC proofs are axiomatized against contracts of the backends
    if funcname.endswith(("_native", "_asm")):
        return True

    known_unproven = (
        "poly_permute_bitrev_to_custom",
        # As documented in the code, this contract is treated as an axiom
        "ct_get_optblocker_u64",
        # External functions
        "memcmp",
        "randombytes",
        # Implemented using inline ASM or external functions
        "zeroize",
    )
    return funcname in known_unproven

def check_contracts():
    """Report every contract that has no matching proof.

    Each ``(filename, line, funcname)`` produced by gen_contracts() is
    compared against the names produced by gen_proofs(). Contracts without
    a proof are reported on stdout when is_exception() accepts them, and
    on stderr (counted as failures) otherwise.

    Returns:
        True if no failure was reported, False otherwise.
    """
    # Deduplicate, then sort: iterating the bare set would print the report
    # in a different order from run to run because of string hash
    # randomisation.
    contracts = sorted(set(gen_contracts()))
    proofs = set(gen_proofs())

    all_ok = True

    # Print contracts without proofs
    for filename, line, funcname in contracts:
        if funcname in proofs:
            continue

        if is_exception(funcname):
            print(f"{filename}:{line}:{funcname} has contract but no proof, "
                  "but is listed as exception")
            continue

        print(f"{filename}:{line}:{funcname}: has contract but no proof. FAIL",
              file=sys.stderr)
        all_ok = False

    return all_ok

def _main():
    """Run the contract check; exit with status 1 if it reports a failure."""
    if not check_contracts():
        sys.exit(1)

# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    _main()
