#!/usr/bin/env python3
# Copyright (c) The mlkem-native project authors
# Copyright (c) The mldsa-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

import subprocess
import argparse
import logging
import pathlib
import tempfile
import platform
import sys
import os
import re


def patchup_disasm(asm, cfify=False):
    """Convert disassembler output text into plain assembly source lines.

    Everything up to and including the first label line (`ADDRESS <name>:`)
    is discarded. Of the remaining lines:

    - blank lines are kept as empty strings,
    - label lines become `name:` (or `Lname:` when `cfify` is set),
    - instruction lines lose their `;` comment, have label references of the
      form `0xADDR <name>` replaced by `name` (or `Lname`), and are stripped
      of the leading address and byte code, then indented by 8 spaces.

    Args:
        asm: Disassembly as a single newline-separated string.
        cfify: If true, prefix every label and label reference with `L`.

    Returns:
        List of output lines (without trailing newlines).

    Raises:
        Exception: If a non-blank, non-label line does not have the form
            `ADDRESS: BYTECODE INSTRUCTION`.
    """
    label_def_pattern = r"^\s*[0-9a-fA-F]+\s*<([a-zA-Z0-9_]+)>:\s*$"
    # Label references are assumed to have the form `0xDEADBEEF <label>`
    label_ref_pattern = r"(0x)?[0-9a-fA-F]+\s+<(?P<label>[a-zA-Z0-9_]+)>"
    inst_pattern = (
        r"^\s*[0-9a-fA-F]+:\s+([0-9a-fA-F][0-9a-fA-F][ ]*)+\s+(?P<inst>.*)$"
    )
    label_prefix = "L" if cfify else ""
    indent = " " * 8

    def label_name(line):
        # Name of the label defined on this line, or None if it is not a label
        found = re.search(label_def_pattern, line)
        return None if found is None else found.group(1)

    lines = asm.split("\n")

    # Index of the first label line; falls back to the last line when there
    # is no label at all, in which case nothing remains after slicing.
    first_label = next(
        (idx for idx, line in enumerate(lines) if label_name(line) is not None),
        len(lines) - 1,
    )

    result = []
    for line in lines[first_label + 1 :]:
        if not line.strip():
            result.append("")
            continue

        name = label_name(line)
        if name is not None:
            # Re-format labels as assembly labels
            result.append(label_prefix + name + ":")
            continue

        # Drop comments, then rewrite references to labels
        code = line.split(";")[0]
        code = re.sub(label_ref_pattern, label_prefix + r"\g<label>", code)

        # Drop address and byte code from line
        parsed = re.search(inst_pattern, code)
        if parsed is None:
            raise Exception(
                f'The following does not seem to be an assembly line of the expected format `ADDRESS: BYTECODE INSTRUCTION`:\n"{code}"'
            )
        result.append(indent + parsed.group("inst"))

    return result


def find_header_footer(asm, filename):
    """Split assembly lines into header, body and footer at the simpasm markers.

    The header is everything before the first line containing
    `simpasm: header-end`; the footer is everything after the first line
    containing `simpasm: footer-start`; the body is what lies strictly
    between the two marker lines.

    Args:
        asm: List of source lines.
        filename: Name of the source file, used only in error messages.

    Returns:
        Tuple `(header, body, footer)` of line lists; the marker lines
        themselves are not included in any of them.

    Raises:
        Exception: If either marker is missing.
    """
    header_end_marker = "simpasm: header-end"
    footer_start_marker = "simpasm: footer-start"

    def first_line_containing(marker):
        # Index of the first line containing `marker`, or None if absent
        return next((idx for idx, line in enumerate(asm) if marker in line), None)

    header_end = first_line_containing(header_end_marker)
    if header_end is None:
        raise Exception(
            f"Could not find header-end marker {header_end_marker} in {filename}"
        )

    footer_start = first_line_containing(footer_start_marker)
    if footer_start is None:
        raise Exception(
            f"Could not find footer-start marker {footer_start_marker} in {filename}"
        )

    return (
        asm[:header_end],
        asm[header_end + 1 : footer_start],
        asm[footer_start + 1 :],
    )


def find_globals(asm):
    """Return the operands of all `.global` directives in `asm`.

    Args:
        asm: List of source lines.

    Returns:
        For each line of the form `.global <rest>` (optionally indented),
        the `<rest>` part, in order of appearance.
    """
    directive = re.compile(r"^\s*\.global\s+(.*)$")
    hits = (directive.search(line) for line in asm)
    return [hit.group(1) for hit in hits if hit is not None]


# Converts `#if ...` statements into `#if 1` in header to avoid having
# to specify various `-D...` in the CFLAGS. The original header will be
# reinstated in the final assembly, so the output is subject to the same
# guards as the input.
def drop_if_from_header(header):
    """Return a copy of `header` with every `#if...` directive replaced by `#if 1`.

    A directive is any line whose stripped text starts with `#if`. If the
    directive line ends in a backslash, the continuation lines that follow
    (up to and including the first line not ending in a backslash) are
    dropped together with it.

    Args:
        header: List of header lines.

    Returns:
        New list of lines; all other lines are passed through unchanged.
    """
    result = []
    remaining = iter(header)
    for line in remaining:
        if not line.strip().startswith("#if"):
            result.append(line)
            continue

        result.append("#if 1")
        # Swallow the backslash-continued remainder of the directive
        while line is not None and line.endswith("\\"):
            line = next(remaining, None)

    return result


def drop_preprocessor_directives(header):
    """Return a copy of `header` with all preprocessor directives removed.

    A directive is any line whose stripped text starts with `#`. If the
    directive line ends in a backslash, the continuation lines that follow
    (up to and including the first line not ending in a backslash) are
    removed as well.

    Args:
        header: List of header lines.

    Returns:
        New list containing only the non-directive lines, in order.
    """
    kept = []
    remaining = iter(header)
    for line in remaining:
        if not line.strip().startswith("#"):
            kept.append(line)
            continue

        # Swallow the backslash-continued remainder of the directive
        while line is not None and line.endswith("\\"):
            line = next(remaining, None)

    return kept


def simplify(logger, args, asm_input, asm_output=None):
    """Replace an assembly file by its disassembly and check the byte code is unchanged.

    The input is split into header/body/footer at the simpasm markers,
    assembled with `args.cc`, disassembled with `args.objdump`, and the
    cleaned-up disassembly is written to `asm_output` between an
    auto-generated preamble and (optionally) the original header/footer.
    The result is then re-assembled and compared against the original
    object file.

    Args:
        logger: Logger used for debug/info/error output.
        args: Parsed command-line namespace. Reads `cc`, `nm`, `objdump`,
            `cflags`, `cfify`, `arch`, `syntax` and
            `preserve_preprocessor_directives`.
        asm_input: Path of the assembly file to simplify.
        asm_output: Path to write the result to; defaults to `asm_input`
            (in-place simplification).

    Raises:
        Exception: If a marker is missing, if the file does not declare
            exactly one `.global` symbol located at address 0, if an
            external tool fails, or if the re-assembled byte code differs
            from the original. In the last case the temporary object files
            are left on disk for inspection.
    """

    def run_cmd(cmd, input=None):
        # Run an external tool, capturing its output as text. Failures are
        # logged with exit code and stderr and re-raised as a generic error.
        logger.debug(f"Running command: {' '.join(cmd)}")
        try:
            r = subprocess.run(
                cmd, capture_output=True, input=input, text=True, check=True
            )
            return r
        except subprocess.CalledProcessError as e:
            logger.error(f"Command failed: {' '.join(cmd)}")
            logger.error(f"Exit code: {e.returncode}")
            logger.error(f"stderr: {e.stderr}")
            raise Exception("simpasm failed") from e

    if asm_output is None:
        asm_output = asm_input

    # Load input assembly
    with open(asm_input, "r") as f:
        orig_asm = f.read().split("\n")

    header, body, footer = find_header_footer(orig_asm, asm_input)

    # Extract unique global symbol from assembly
    syms = find_globals(orig_asm)
    if len(syms) != 1:
        logger.error(
            f"Expected exactly one global symbol in {asm_input}, but found {syms}"
        )
        raise Exception("simpasm failed")
    sym = syms[0]

    # Split on arbitrary whitespace so that an empty string or repeated
    # spaces do not produce empty command-line arguments.
    if args.cflags is not None:
        cflags = args.cflags.split()
    else:
        cflags = []

    # Create temporary object files for byte code before/after simplification.
    # They are created with delete=False so they survive a failed comparison.
    with tempfile.NamedTemporaryFile(
        suffix=".o", delete=False
    ) as tmp0, tempfile.NamedTemporaryFile(suffix=".o", delete=False) as tmp1:

        tmp_objfile0 = tmp0.name
        tmp_objfile1 = tmp1.name

        cmd = (
            [args.cc, "-c", "-x", "assembler-with-cpp"]
            + cflags
            + ["-o", tmp_objfile0, "-"]
        )
        logger.debug(f"Assembling {asm_input} ...")
        # Neutralize the header's `#if` guards so no -D flags are needed
        asm_no_if = "\n".join(drop_if_from_header(header) + body + footer)
        run_cmd(cmd, input=asm_no_if)

        # Remember the binary contents of the object file for later.
        # Read back by path rather than through the handle opened above, so
        # the content is seen even if the tool replaced the file instead of
        # rewriting it in place.
        with open(tmp_objfile0, "rb") as f:
            orig_obj = f.read()

        # Check that there is exactly one global symbol at location 0
        cmd = [args.nm, "--extern-only", tmp_objfile0]
        logger.debug(
            f"Extracting symbols from temporary object file {tmp_objfile0} ..."
        )
        r = run_cmd(cmd)

        nm_output = r.stdout.split("\n")
        nm_output = list(filter(lambda s: s.strip() != "", nm_output))
        if len(nm_output) == 0:
            logger.error(
                f"Found one .global annotation in {asm_input}, but no external symbols in object file {tmp_objfile0} -- should not happen?"
            )
            logger.error(asm_no_if)
            raise Exception("simpasm failed")
        elif len(nm_output) > 1:
            logger.error(
                f"Found only one .global annotation in {asm_input}, but multiple external symbols {nm_output} in object file -- should not happen?"
            )
            raise Exception("simpasm failed")
        # Expecting format "ADDRESS T SYMBOL"
        sym_info = nm_output[0].split()
        # nm prints addresses in hexadecimal
        sym_addr = int(sym_info[0], 16)
        if sym_addr != 0:
            logger.error(
                f"Global sym {sym} not at address 0 (instead at address {hex(sym_addr)}) -- please reorder the assembly to start with the global function symbol"
            )
            raise Exception("simpasm failed")

        # If we don't preserve preprocessor directives, use the raw global symbol name instead;
        # otherwise, end up emitting a namespaced symbol without including the header needed to
        # make sense of it.
        if args.preserve_preprocessor_directives is False:
            sym = sym_info[2]
            logger.debug(f"Using raw global symbol {sym} going forward ...")

        cmd = [args.objdump, "--disassemble", tmp_objfile0]
        if platform.system() == "Darwin" and args.arch == "aarch64":
            cmd += ["--triple=aarch64"]

        # Add syntax option if specified
        if args.syntax and args.syntax.lower() != "att":
            cmd += [f"--x86-asm-syntax={args.syntax}"]

        logger.debug(f"Disassembling temporary object file {tmp_objfile0} ...")
        disasm = run_cmd(cmd).stdout

        logger.debug("Patching up disassembly ...")
        simplified = patchup_disasm(disasm, cfify=args.cfify)

        autogen_header = [
            "",
            "/*",
            " * WARNING: This file is auto-derived from the mldsa-native source file",
            f" *   {asm_input} using scripts/simpasm. Do not modify it directly.",
            " */",
            "",
            "#if defined(__ELF__)",
            '.section .note.GNU-stack,"",@progbits',
            "#endif",
            "",
            ".text",
            ".balign 4",
        ]

        # Add syntax specifier for Intel syntax
        if args.syntax and args.syntax.lower() == "intel":
            autogen_header.append(".intel_syntax noprefix")

        if args.preserve_preprocessor_directives is False:
            # On Darwin, strip a leading underscore from the raw symbol; it
            # is re-added under __APPLE__ below.
            if platform.system() == "Darwin" and sym[0] == "_":
                sym = sym[1:]
            autogen_header += [
                "#ifdef __APPLE__",
                f".global _{sym}",
                f"_{sym}:",
                "#else",
                f".global {sym}",
                f"{sym}:",
                "#endif",
                "",
            ]
            simplified_header = drop_preprocessor_directives(header)
            simplified_footer = []
        else:
            autogen_header += [
                f".global {sym}",
                f"{sym.replace('MLD_ASM_NAMESPACE', 'MLD_ASM_FN_SYMBOL')}",
                "",
            ]
            simplified_header = header
            simplified_footer = footer

        # Apply CFI directives if requested
        if args.cfify:
            logger.debug("Applying CFI directives...")
            cfify_script = os.path.join(os.path.dirname(__file__), "cfify")
            cfify_cmd = [cfify_script, "--emit-cfi-proc-start", "--arch", args.arch]
            cfify_result = run_cmd(cfify_cmd, input="\n".join(simplified))
            simplified = cfify_result.stdout.split("\n")

        # Write simplified assembly file
        full_simplified = (
            simplified_header + autogen_header + simplified + simplified_footer
        )

        logger.debug(f"Writing simplified assembly to {asm_output} ...")
        with open(asm_output, "w+") as f:
            f.write("\n".join(full_simplified))

        cmd = (
            [args.cc, "-c", "-x", "assembler-with-cpp"]
            + cflags
            + ["-o", tmp_objfile1, "-"]
        )

        # Re-assemble with the same `#if` neutralization as the original
        new_asm = "\n".join(
            drop_if_from_header(simplified_header)
            + autogen_header
            + simplified
            + simplified_footer
        )

        logger.debug("Assembling simplified assembly ...")
        logger.debug(new_asm)
        logger.debug(f"Command: {' '.join(cmd)}")
        run_cmd(cmd, input=new_asm)

        # Get binary contents of re-assembled object file (read by path, as
        # for the original object file)
        with open(tmp_objfile1, "rb") as f:
            simplified_obj = f.read()

        logger.debug("Checking that byte-code is unchanged ...")

        # When CFI is enabled, compare only the __text section content
        if args.cfify:
            logger.debug("Comparing __text section content for CFI comparison...")

            # Extract __text section from both files
            orig_text_cmd = [
                args.objdump,
                "-s",
                "--section=__text",
                "--section=.text",
                tmp_objfile0,
            ]
            simplified_text_cmd = [
                args.objdump,
                "-s",
                "--section=__text",
                "--section=.text",
                tmp_objfile1,
            ]

            orig_text_result = run_cmd(orig_text_cmd)
            simplified_text_result = run_cmd(simplified_text_cmd)

            # Remove filename lines from both outputs
            orig_lines = [
                line
                for line in orig_text_result.stdout.split("\n")
                if tmp_objfile0 not in line
            ]
            simplified_lines = [
                line
                for line in simplified_text_result.stdout.split("\n")
                if tmp_objfile1 not in line
            ]

            if orig_lines != simplified_lines:
                logger.error(
                    f"__text section content of {tmp_objfile0} and {tmp_objfile1} before/after simplification differs -- aborting"
                )
                logger.error("I'll keep them around for you to have a look")
                raise Exception("simpasm failed")
        else:
            if orig_obj != simplified_obj:
                logger.error(
                    f"Object files {tmp_objfile0} and {tmp_objfile1} before/after simplification are not byte-wise equal -- aborting"
                )
                logger.error("I'll keep them around for you to have a look")
                raise Exception("simpasm failed")

    # Only reached on success; on failure the object files are kept
    os.unlink(tmp_objfile0)
    os.unlink(tmp_objfile1)

    logger.info(f"Simplified {asm_input} -> {asm_output} (same byte code)")


def _main():
    """Command-line entry point: parse arguments and simplify the requested files.

    Changes the working directory to the parent of this script's directory,
    configures logging to stdout, and then simplifies `--input` (writing to
    `--output`, or in place if omitted) and/or every `*.S` file directly
    inside `--directory` (in place).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-d",
        "--directory",
        type=str,
        help="Directory whose *.S assembly files are simplified in place",
    )
    parser.add_argument("-i", "--input", type=str, help="Input assembly file")
    parser.add_argument("-o", "--output", type=str, help="Output assembly file")
    parser.add_argument(
        "-p", "--preserve-preprocessor-directives", default=False, action="store_true"
    )
    parser.add_argument("-v", "--verbose", default=False, action="store_true")
    parser.add_argument(
        "--cc", type=str, default="gcc" if platform.system() != "Darwin" else "clang"
    )
    parser.add_argument("--nm", type=str, default="nm")
    parser.add_argument("--objdump", type=str, default="objdump")
    # NOTE(review): --strip is accepted but not read anywhere in this file
    parser.add_argument("--strip", type=str, default="llvm-strip")
    parser.add_argument("--cflags", type=str)
    parser.add_argument("--cfify", action="store_true", help="Apply CFI directives")
    parser.add_argument(
        "--arch",
        type=str,
        default="aarch64",
        help="Target architecture for CFI directives",
    )

    parser.add_argument(
        "--syntax",
        type=str,
        choices=["att", "intel"],
        default="att",
        help="Assembly syntax for disassembly output (att or intel)",
    )

    args = parser.parse_args()

    # Relative paths on the command line are resolved against the parent of
    # the script directory from here on.
    os.chdir(os.path.join(os.path.dirname(__file__), ".."))

    # Strip one pair of literal double quotes surrounding the whole value
    if (
        args.cflags is not None
        and args.cflags.startswith('"')
        and args.cflags.endswith('"')
    ):
        args.cflags = args.cflags[1:-1]

    logging.basicConfig(stream=sys.stdout, format="%(name)s: %(message)s")

    logger = logging.getLogger("simpasm")
    if args.verbose is True:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    if args.input is not None:
        simplify(logger, args, args.input, args.output)
    if args.directory is not None:
        # Simplify all files in directory, in sorted order so that runs are
        # deterministic
        asm_files = sorted(pathlib.Path(args.directory).glob("*.S"))
        for f in asm_files:
            simplify(logger, args, f)


# Run the command-line entry point only when executed as a script
if __name__ == "__main__":
    _main()
