#!/usr/bin/env python3
# Copyright (c) The mlkem-native project authors
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT

import sys
import re
import argparse


# Callee-saved register pairs recognized in AArch64 save/restore sequences,
# listed in the order the consecutive stp/ldp instructions must appear.
_AARCH64_SIMD_PAIRS = ((8, 9), (10, 11), (12, 13), (14, 15))
_AARCH64_GPR_PAIRS = ((19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30))


def _pair_sequence_regex(is_save, reg_prefix, pairs):
    """Compile a regex for consecutive register-pair stores/loads relative to sp.

    The pattern spans len(pairs) newline-joined lines, one
    ``stp|ldp <p><lo>, <p><hi>, [sp{, #off}]`` per line, in the order of
    *pairs* (``stp`` when *is_save*, else ``ldp``). Group 1 is the leading
    whitespace of the first line. For saves, group j + 2 holds the raw offset
    text of line j (None when the ``#off`` part is omitted).
    """
    mnemonic = "stp" if is_save else "ldp"
    offset = r"(?:,\s*#([^]]+))?" if is_save else r"(?:,\s*#[^]]+)?"
    per_line = [
        rf"{mnemonic}\s+{reg_prefix}{lo},\s*{reg_prefix}{hi},\s*\[sp{offset}\]"
        for lo, hi in pairs
    ]
    # Lines are separated by optional trailing/leading whitespace around "\n".
    return re.compile(r"(\s*)" + r"\s*\n\s*".join(per_line), re.IGNORECASE)


# (is_save, register prefix, register pairs, compiled pattern), tried in order:
# SIMD save, SIMD restore, GPR save, GPR restore.
_AARCH64_PAIR_SEQUENCES = [
    (is_save, prefix, pairs, _pair_sequence_regex(is_save, prefix, pairs))
    for prefix, pairs in (("d", _AARCH64_SIMD_PAIRS), ("x", _AARCH64_GPR_PAIRS))
    for is_save in (True, False)
]

# add/sub sp, sp, #imm with the optional "lsl #shift" of the A64 immediate form.
_AARCH64_SP_ADJUST = re.compile(
    r"(\s*)(add|sub)\s+sp,\s*sp,\s*#(0x[0-9a-fA-F]+|\d+)(?:\s*,\s*lsl\s*#(\d+))?",
    re.IGNORECASE,
)
_AARCH64_RET = re.compile(r"(\s*)ret\s*$", re.IGNORECASE)

_X86_LABEL = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*):$")
# The target is captured in a lookahead so a match only consumes "callq<ws>"
# and never swallows text that could begin another call site.
_X86_CALLQ_TARGET = re.compile(r"callq\s+(?=(\w+))", re.IGNORECASE)
_X86_RSP_ADJUST = re.compile(
    r"(\s*)(subq|addq)\s+\$(0x[0-9a-fA-F]+|\d+),\s*%rsp", re.IGNORECASE
)
_X86_RET = re.compile(r"(\s*)retq?\s*$", re.IGNORECASE)


def _parse_imm(literal):
    """Parse an immediate written either as 0x-prefixed hex or as decimal."""
    if literal.lower().startswith("0x"):
        return int(literal, 16)
    return int(literal)


def _rel_offset_directives(indent, reg_prefix, lo, hi, raw_offset):
    """Return the two .cfi_rel_offset lines for a saved register pair.

    The second register of the pair lives 8 bytes above the first. Numeric
    offsets are re-emitted in signed hex (e.g. -0x10); anything ``int(x, 0)``
    cannot parse is passed through verbatim as an assembler expression.
    """
    try:
        value = int(raw_offset, 0)
    except ValueError:
        return [
            f"{indent}.cfi_rel_offset {reg_prefix}{lo}, {raw_offset}",
            f"{indent}.cfi_rel_offset {reg_prefix}{hi}, ({raw_offset}+8)",
        ]
    return [
        f"{indent}.cfi_rel_offset {reg_prefix}{lo}, {value:#x}",
        f"{indent}.cfi_rel_offset {reg_prefix}{hi}, {value + 8:#x}",
    ]


def _aarch64_directives(lines, i, line, out):
    """Handle the AArch64 construct starting at lines[i].

    On a match, append the original line(s) (right-stripped) plus CFI
    directives to *out* and return the number of input lines consumed.
    Return 0 and append nothing when no rule applies.
    """
    for is_save, reg_prefix, pairs, pattern in _AARCH64_PAIR_SEQUENCES:
        count = len(pairs)
        if i + count > len(lines):
            continue
        match = pattern.match("\n".join(lines[i : i + count]))
        if not match:
            continue
        indent = match.group(1)
        for j, (lo, hi) in enumerate(pairs):
            out.append(lines[i + j].rstrip())
            if is_save:
                raw_offset = match.group(j + 2) or "0"
                out.extend(_rel_offset_directives(indent, reg_prefix, lo, hi, raw_offset))
            else:
                out.append(f"{indent}.cfi_restore {reg_prefix}{lo}")
                out.append(f"{indent}.cfi_restore {reg_prefix}{hi}")
        return count

    match = _AARCH64_SP_ADJUST.match(line)
    if match:
        indent, op, imm, shift = match.groups()
        amount = _parse_imm(imm) << (int(shift) if shift else 0)
        # sub grows the frame (CFA offset increases); add releases it.
        sign = "-" if op.lower() == "add" else ""
        out.append(line)
        out.append(f"{indent}.cfi_adjust_cfa_offset {sign}{amount:#x}")
        return 1

    match = _AARCH64_RET.match(line)
    if match:
        out.append(line)
        out.append(f"{match.group(1)}.cfi_endproc")
        return 1

    return 0


def _x86_64_directives(line, call_targets, out):
    """Handle one x86_64 line; same return contract as _aarch64_directives.

    *call_targets* holds every word that follows a ``callq`` anywhere in the
    input; a label is treated as a called function if it matches one of them
    case-insensitively.
    """
    match = _X86_LABEL.match(line)
    if match:
        label_pattern = re.compile(re.escape(match.group(1)), re.IGNORECASE)
        if any(label_pattern.fullmatch(target) for target in call_targets):
            out.append(line)
            out.append("        .cfi_startproc")
            return 1

    match = _X86_RSP_ADJUST.match(line)
    if match:
        indent, op, imm = match.groups()
        # subq grows the frame (CFA offset increases); addq releases it.
        sign = "-" if op.lower() == "addq" else ""
        out.append(line)
        out.append(f"{indent}.cfi_adjust_cfa_offset {sign}{_parse_imm(imm):#x}")
        return 1

    match = _X86_RET.match(line)
    if match:
        out.append(line)
        out.append(f"{match.group(1)}.cfi_endproc")
        return 1

    return 0


def add_cfi_directives(text, arch):
    """Return *text* with CFI directives inserted after recognized instructions.

    The input is split on "\\n"; every line is right-stripped and the result is
    re-joined with "\\n". Unrecognized lines are copied through unchanged.

    aarch64:
      * four consecutive ``stp d8, d9`` .. ``stp d14, d15`` to ``[sp{, #off}]``
        -> ``.cfi_rel_offset`` for each register (second one at off + 8);
        the matching ``ldp`` sequence -> ``.cfi_restore`` for each register;
      * the same for the six ``x19, x20`` .. ``x29, x30`` pairs;
      * ``sub/add sp, sp, #imm{, lsl #s}`` -> ``.cfi_adjust_cfa_offset +/-(imm << s)``;
      * ``ret`` -> ``.cfi_endproc``.
    x86_64:
      * ``name:`` where ``name`` is the target of some ``callq`` in *text*
        -> ``.cfi_startproc``;
      * ``subq/addq $imm, %rsp`` -> ``.cfi_adjust_cfa_offset +/-imm``;
      * ``ret``/``retq`` -> ``.cfi_endproc``.
    Any other *arch* value copies all lines through (right-stripped).
    """
    lines = text.split("\n")
    out = []

    call_targets = set()
    if arch == "x86_64":
        # Collect call targets once instead of rescanning the text per label.
        call_targets = {m.group(1) for m in _X86_CALLQ_TARGET.finditer(text)}

    i = 0
    while i < len(lines):
        line = lines[i].rstrip()
        if arch == "aarch64":
            consumed = _aarch64_directives(lines, i, line, out)
        elif arch == "x86_64":
            consumed = _x86_64_directives(line, call_targets, out)
        else:
            consumed = 0

        if consumed:
            i += consumed
        else:
            out.append(line)
            i += 1

    return "\n".join(out)


def main():
    """Command-line entry point: read assembly, add CFI directives, write it out.

    Input is read from ``-i/--input`` (default: stdin) and the processed text
    is written to ``-o/--output`` (default: stdout). ``--arch`` selects the
    rule set passed to add_cfi_directives.
    """
    parser = argparse.ArgumentParser(
        description="Add CFI directives to AArch64 or x86_64 assembly"
    )
    parser.add_argument("-i", "--input", help="Input file (default: stdin)")
    parser.add_argument("-o", "--output", help="Output file (default: stdout)")
    parser.add_argument(
        "--emit-cfi-proc-start",
        action="store_true",
        help="Emit .cfi_startproc as first line",
    )
    parser.add_argument(
        "--arch",
        choices=["aarch64", "x86_64"],
        default="aarch64",
        help="Target architecture (default: aarch64)",
    )
    args = parser.parse_args()

    # Read (and close) the input completely before touching the output.
    # Opening the output with "w" truncates it, so opening it first would
    # wipe the input when both name the same file (in-place rewrite). It would
    # also leave an empty output behind if reading or processing failed.
    if args.input:
        with open(args.input, "r") as input_file:
            text = input_file.read()
    else:
        text = sys.stdin.read()

    # Add initial .cfi_startproc if requested
    if args.emit_cfi_proc_start:
        text = "        .cfi_startproc\n" + text

    result = add_cfi_directives(text, args.arch)

    if args.output:
        with open(args.output, "w") as output_file:
            output_file.write(result)
    else:
        sys.stdout.write(result)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
