udiboy1209 committed on
Commit
7145fd6
·
1 Parent(s): e78b7eb

Add REMEND python module

Browse files
pyproject.toml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Build configuration for the REMEND package (hatchling backend).
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "remend"
version = "1.0"
authors = [{name="Meet Udeshi", email="[email protected]"}]
description = "Neural Decompilation for Reverse Engineering Math Equations from Binary Executables"
readme = "README.md"
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent",
]
requires-python = ">=3.9"
dependencies = [
    "networkx",
    "capstone",
    "Levenshtein",
    "tqdm",
    "numpy",
    "sympy",
    "fairseq",
    "torch",
    "matplotlib",
    "tokenizers"
]

# Ship only the `remend` package directory in the wheel.
[tool.hatch.build.targets.wheel]
packages = ["remend"]
remend/__init__.py ADDED
File without changes
remend/bpe.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tokenizers import pre_tokenizers, Tokenizer
2
+ from tokenizers.models import BPE
3
+ from tokenizers.trainers import BpeTrainer
4
+ from tokenizers.pre_tokenizers import Whitespace, PreTokenizer
5
+ import random
6
+ import os
7
+ from tqdm import tqdm
8
+ import itertools as it
9
+
10
class ImmPreTokenizer:
    """Custom pre-tokenizer that splits immediate values into characters.

    Hexadecimal immediates ("0x...") and plain decimal integers are broken
    up into single-character sub-tokens; every other token passes through
    unchanged.
    """

    def pre_tokenize(self, pretok):
        # Delegate per-token splitting to hex_imm_split.
        pretok.split(self.hex_imm_split)

    def hex_imm_split(self, i, norm_str):
        """Split one normalized token; returns a list of sub-token slices."""
        token = str(norm_str)
        if not (token.startswith("0x") or token.isdigit()):
            return [norm_str]
        # Immediate value: emit one single-character slice per position.
        return [norm_str[pos:pos + 1] for pos in range(len(token))]
19
+
20
def get_asm_tok(files, save):
    """Train a BPE tokenizer on the given asm files and save it to `save`.

    Returns the trained Tokenizer with the custom immediate-splitting
    pre-tokenizer installed.
    """
    def _custom_pretok():
        # Whitespace split followed by per-character immediate splitting.
        return pre_tokenizers.Sequence(
            [Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])

    tokenizer = Tokenizer(BPE(unk_token="@@UNK@@"))
    tokenizer.pre_tokenizer = _custom_pretok()
    trainer = BpeTrainer(special_tokens=["@@UNK@@"])
    tokenizer.train(files, trainer)

    # Custom (Python) pre-tokenizers cannot be serialized, so temporarily
    # swap in a plain Whitespace pre-tokenizer for saving, then restore.
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.save(save)
    tokenizer.pre_tokenizer = _custom_pretok()

    return tokenizer
31
+
32
def load_asm_tok(load):
    """Load a saved tokenizer and re-attach the custom immediate pre-tokenizer."""
    tokenizer = Tokenizer.from_file(load)
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
        [Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
    return tokenizer
36
+
37
+
38
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Train the tokenizer and tokenize the asm")
    # Fixed help text: "-i" is the INPUT directory (was wrongly "output directory").
    parser.add_argument("-i", "--indir", required=True, help="input directory")
    parser.add_argument("-o", "--outdir", default="tokenized", help="output directory")
    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    injoin = lambda p: os.path.join(args.indir, p)
    pjoin = lambda p: os.path.join(args.outdir, p)
    max_asm_toks = 0

    # Train only on train+valid; the test split is tokenized but never trained on.
    asm_tok = get_asm_tok([injoin("train.asm"), injoin("valid.asm")], pjoin("asm_tokens.json"))
    for split in ["train", "valid", "test"]:
        asmfile = split + ".asm"
        with open(injoin(asmfile), "r") as asmf, open(pjoin(asmfile), "w") as asmtokf:
            for asm in tqdm(asmf, desc=f"Tokenizing {split}"):
                asm = asm.strip()
                asm_enc = asm_tok.encode(asm)
                # Track the longest tokenized sequence (model max-length hint).
                max_asm_toks = max(max_asm_toks, len(asm_enc.tokens))
                asm_seq = " ".join(asm_enc.tokens)
                asmtokf.write(asm_seq + "\n")

    print("Maximum tokens:", max_asm_toks)

    # After this, run command:
    # fairseq-preprocess -s asm -t eqn --trainpref {OUTDIR}/train --validpref {OUTDIR}/valid --testpref {OUTDIR}/test --destdir {OUTDIR}
remend/bpe_apply.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+
3
+ from .bpe import load_asm_tok
4
+
5
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Tokenize using existing tokenizer")
    parser.add_argument("-t", "--tokenizer", required=True, help="existing tokenizer")
    parser.add_argument("-i", "--input", required=True, help="input file")
    parser.add_argument("-o", "--output", required=True, help="output file")
    args = parser.parse_args()

    max_asm_toks = 0
    asm_tok = load_asm_tok(args.tokenizer)

    # NOTE: run as `python -m remend.bpe_apply` so the relative import works.
    with open(args.input, "r") as asmf, open(args.output, "w") as asmtokf:
        # was an f-string with no placeholders; plain string is equivalent
        for asm in tqdm(asmf, desc="Tokenizing"):
            asm = asm.strip()
            asm_enc = asm_tok.encode(asm)
            # Track the longest tokenized sequence (model max-length hint).
            max_asm_toks = max(max_asm_toks, len(asm_enc.tokens))
            asm_seq = " ".join(asm_enc.tokens)
            asmtokf.write(asm_seq + "\n")

    print("Maximum tokens:", max_asm_toks)
25
+
remend/change_eqn_format.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .parser import isint, OPERATORS
2
+
3
def prefix_to_brackets(eqn):
    """Convert a prefix token list to a fully-parenthesized string.

    eqn: list of tokens in prefix order.  Integer literals are encoded as a
    sign token ("INT+"/"INT-") followed by one or more digit tokens; these
    are kept together as one unit.  Returns a single string such as
    "( add INT+ 5 x )".
    """
    stack = []
    # Each entry: (stack position of an operator, its arity); used to detect
    # when all operands of the most recent operator are on the stack.
    lastop = []
    intunit = []
    N = len(eqn)
    i = 0
    while i < N:
        # print("Stack", stack)
        val = eqn[i]
        if val.startswith("INT"):
            # Collect the sign token plus all following digit tokens as one unit.
            intunit.append(val)
            i += 1
            while i < N and isint(eqn[i]):
                intunit.append(eqn[i])
                i += 1
            stack.append(" ".join(intunit))
            intunit = []
            i -= 1  # compensate for the unconditional i += 1 below
        elif val in OPERATORS:
            _, numops = OPERATORS[val]
            lastop.append((len(stack), numops))
            stack.append(val)
        else:
            stack.append(val)

        # Reduce: while the most recent operator has all its operands on the
        # stack, fold operator + operands into one parenthesized string.
        while len(lastop) > 0 and len(stack) > lastop[-1][0] + lastop[-1][1]:
            # Combine op
            # print(lastop[-1], stack[lastop[-1][0]:])
            op = " ".join(stack[lastop[-1][0]:])
            del stack[lastop[-1][0]:]
            lastop.pop()
            stack.append(f"( {op} )")
        i += 1
    assert(len(stack) == 1)  # a well-formed equation reduces to one element
    return stack[0]
38
+
39
def prefix_to_postfix(eqn):
    """Convert a prefix token list to postfix order.

    Returns (postfix_tokens, remaining_tokens).  Integer literals (a sign
    token "INT+"/"INT-" followed by digit tokens) stay together as one unit
    and are never reordered internally.

    Bug fix: the original computed the remainder from a stale `for` loop
    index, so when the integer unit ran to the very end of `eqn` it returned
    a spurious leftover token, and a bare trailing "INT..." token raised
    NameError.  The explicit end index below handles both cases.
    """
    if eqn[0].startswith("INT"):
        # Consume the sign token plus every following digit token.
        end = 1
        while end < len(eqn) and isint(eqn[end]):
            end += 1
        return eqn[:end], eqn[end:]
    elif eqn[0] in OPERATORS:
        _, numops = OPERATORS[eqn[0]]
        remeqn = eqn[1:]
        ops = []
        for _ in range(numops):
            op, remeqn = prefix_to_postfix(remeqn)
            ops.extend(op)
        ops.append(eqn[0])  # operator goes after its operands in postfix
        return ops, remeqn
    else:
        # Plain operand (variable or constant symbol).
        return [eqn[0]], eqn[1:]
58
+
59
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Change equation format from prefix to other")
    parser.add_argument("--eqn", required=True)
    parser.add_argument("--out", required=True)
    args = parser.parse_args()

    # Convert every input line from prefix order to postfix order.
    with open(args.eqn, "r") as inf, open(args.out, "w") as outf:
        for eqn in inf:
            postfix, _ = prefix_to_postfix(eqn.strip().split(" "))
            outf.write(" ".join(postfix) + "\n")

    # Manual examples (kept for debugging):
    # eqn = "div mul x add INT+ 5 add mul INT+ 3 x mul pow x INT+ 2 add INT- 5 add mul INT- 3 x mul x mul add INT+ 1 mul k0 pow x INT+ 3 add INT+ 4 x add INT+ 5 mul INT+ 3 x"
    # eqn = "div add mul INT+ 3 x pow x INT- 4 mul sub x k0 add mul INT+ 5 x k1"
    # print(" ".join(prefix_to_postfix(eqn.split(" "))[0]))
    # print(prefix_to_brackets(eqn.split(" ")))
79
+
remend/check_generated.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sympy as sp
2
+ import sys
3
+ import re
4
+ from tqdm import tqdm
5
+ from Levenshtein import distance
6
+ import networkx as nx
7
+ from networkx import graph_edit_distance
8
+
9
+ from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint
10
+
11
def percent(a, n):
    """Format `a` out of `n` as a percentage string, e.g. percent(1, 4) -> '25.0%'."""
    ratio = a / n
    return "{:0.1f}%".format(ratio * 100)
13
+
14
def do_simplify_match(orig_expr, gen_expr):
    """Return True when both expressions simplify to the same sympy form."""
    simplified_orig = sp.simplify(orig_expr)
    simplified_gen = sp.simplify(gen_expr)
    return simplified_orig == simplified_gen
20
+
21
def do_structure_match(orig_toks, gen_toks):
    """Check whether two token sequences share the same structure.

    A pair of tokens "matches" when they are identical, or both are constant
    symbols (c<N>), both variables (x<N>), both integer digit tokens, or
    both INT sign markers.  Returns False on any length or token mismatch.
    """
    const_pat = re.compile(r"c[0-9]+")
    var_pat = re.compile(r"x[0-9]+")

    def _pair_matches(a, b):
        if a == b:
            return True
        if const_pat.match(a) and const_pat.match(b):
            return True
        if var_pat.match(a) and var_pat.match(b):
            return True
        if isint(a) and isint(b):
            return True
        return a.startswith("INT") and b.startswith("INT")

    if len(orig_toks) != len(gen_toks):
        return False
    return all(_pair_matches(a, b) for a, b in zip(orig_toks, gen_toks))
38
+
39
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-r", required=True, help="Results file")
    parser.add_argument("--simplify", action="store_true", default=False)
    parser.add_argument("--postfix", action="store_true", default=False)
    args = parser.parse_args()

    # Read fairseq-generate output: "T-<n>\t<target>" and "H-<n>\t<score>\t<hyp>".
    orig_list = []
    gen_list = []
    with open(args.g, 'r') as f:
        for line in tqdm(f, desc="Reading file"):
            comps = line.strip().split("\t")
            if line[0] == 'T':
                num = int(comps[0][2:])
                tokens = comps[1].split(" ")
                orig_list.append((num, tokens))
            elif line[0] == 'H':
                num = int(comps[0][2:])
                tokens = comps[2].split(" ")
                gen_list.append((num, tokens))

    N = len(orig_list)
    gen_errors = []        # ids whose hypothesis failed to parse
    parsed = []            # ids parsed successfully
    exact_match = []       # token-identical or expression-equal ids
    structure_match = []   # ids matching up to constant/variable identity
    simplify_match = []    # ids equal only after sympy simplification

    orig_exprs = {}
    gen_exprs = {}

    all_aed = []           # normalized token edit distances
    # all_ged = []

    results = []

    for (orig_num, orig_toks), (gen_num, gen_toks) in tqdm(zip(orig_list, gen_list), desc="Parsing expressions", total=N):
        assert orig_num == gen_num
        # Normalized token edit distance in [0, 1).
        aed = distance(orig_toks, gen_toks) / (len(orig_toks) + len(gen_toks))
        all_aed.append(aed)
        res = {"id": gen_num, "aed": aed, "matched": False, "parsed": False}

        if aed == 0:
            # Token-identical output: trivially parsed / exact / structure match.
            parsed.append(orig_num)
            exact_match.append(orig_num)
            structure_match.append(orig_num)
            res["parsed"] = True
            res["matched"] = "Exact"
            results.append(res)
            continue

        if do_structure_match(orig_toks, gen_toks):
            structure_match.append(orig_num)
            res["matched"] = "Structure"

        if "<<unk>>" in orig_toks:
            # Reference itself contains an unknown token; cannot be evaluated.
            res["parsed"] = False
            res["matched"] = False
            results.append(res)
            continue

        if args.postfix:
            orig_expr = parse_postfix_to_sympy(orig_toks)
        else:
            orig_expr = parse_prefix_to_sympy(orig_toks)
        try:
            if args.postfix:
                gen_expr = parse_postfix_to_sympy(gen_toks)
            else:
                gen_expr = parse_prefix_to_sympy(gen_toks)
            res["parsed"] = True
        except Exception:
            # Was a bare `except:`; narrowed so Ctrl-C still interrupts.
            gen_errors.append(gen_num)
            results.append(res)
            continue

        parsed.append(gen_num)
        orig_exprs[gen_num] = orig_expr
        gen_exprs[gen_num] = gen_expr

        if orig_expr == gen_expr:
            exact_match.append(gen_num)
            res["matched"] = "Exact"
        elif args.simplify and do_simplify_match(orig_expr, gen_expr):
            simplify_match.append(gen_num)
            res["matched"] = "Simplify"
        results.append(res)

    # Per-sample lines followed by aggregate statistics.
    with open(args.r, "w") as resf:
        for res in results:
            resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
        resf.write("\n")
        print("Total", N, file=resf)
        print("Parse error", len(gen_errors), percent(len(gen_errors), N), file=resf)
        print("Exact match", len(exact_match), percent(len(exact_match), N), file=resf)
        print("Structure match", len(structure_match), percent(len(structure_match), N), file=resf)
        if args.simplify:
            print("Simplify match", len(simplify_match), percent(len(simplify_match), N), file=resf)
        print("Avg SED", sum(all_aed) / len(all_aed), max(all_aed), file=resf)
        # print("Avg GED", sum(all_ged) / len(all_ged), max(all_ged), file=resf)
143
+
remend/compile_dataset.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import random
3
+ import sympy as sp
4
+ import json
5
+ import subprocess as sproc
6
+ from os.path import realpath, dirname, join as pjoin
7
+ from os import makedirs
8
+ import multiprocessing as mp
9
+ from time import sleep
10
+ import logging
11
+
12
+ from .implementation import Implementor
13
+ from .parser import parse_prefix_to_sympy, sympy_to_prefix, constant_fold
14
+ from .disassemble import DisassemblerARM32, DisassemblerAArch64, DisassemblerX64
15
+ from .util import DecodeError, timeout, sympy_expr_ok
16
+
17
# Helper shell script (next to this module) that cross-compiles one equation.
SCRIPT = pjoin(dirname(realpath(__file__)), "compile_eqn.sh")

# Sentinel pushed onto each worker queue to signal end-of-input.
QUEUE_END = "QUEUE_END_SENTINEL"
20
+
21
def compile_c(code, elf, arch="arm32", src="/tmp/myfunc.c", opt=0):
    """Compile C `code` into an ELF binary at `elf` via compile_eqn.sh.

    arch selects the cross-compiler mode; opt is the -O level.
    Raises DecodeError on compiler failure, including the compiler's stderr
    so the cause is visible in the worker error log.
    """
    with open(src, "w") as f:
        f.write(code)
    ret = sproc.run(["bash", "-e", SCRIPT, arch+"-c", src, elf, f"-O{opt}"], capture_output=True)
    if ret.returncode != 0:
        # Surface the diagnostics instead of a bare "compile failed".
        raise DecodeError("compile failed: " + ret.stderr.decode(errors="replace"))
27
+
28
def compile_fortran(code, elf, arch="arm32", src="/tmp/myfunc.f95", opt=0):
    """Compile Fortran `code` into an ELF binary at `elf` via compile_eqn.sh.

    arch selects the cross-compiler mode; opt is the -O level.
    Raises DecodeError on compiler failure, including the compiler's stderr
    so the cause is visible in the worker error log.
    """
    with open(src, "w") as f:
        f.write(code)
    ret = sproc.run(["bash", "-e", SCRIPT, arch+"-fortran", src, elf, f"-O{opt}"], capture_output=True)
    if ret.returncode != 0:
        # Surface the diagnostics instead of a bare "compile failed".
        raise DecodeError("compile failed: " + ret.stderr.decode(errors="replace"))
34
+
35
class EquationCompiler:
    """Worker that compiles queued equations and writes dataset files.

    Consumes (n, expr, expr_const, pref) tuples from a multiprocessing queue
    until QUEUE_END is seen; implements each expression as C or Fortran,
    compiles it for the target architecture, disassembles the result with
    constant matching, and appends aligned lines to parallel
    .asm / .eqn / .src / .const.jsonl output files.
    """
    def __init__(self, q, arch, impl, opt, outdir, prefix, dtype="double"):
        """q: input queue; arch: arm32/aarch64/x64; impl: {dag,cse}_{c,fortran};
        opt: optimization level; outdir/prefix: output location; dtype: float type."""
        if "fortran" in impl:
            self.compiler = compile_fortran
        else:
            self.compiler = compile_c

        if arch == "arm32":
            self.disassembler = DisassemblerARM32
        elif arch == "aarch64":
            self.disassembler = DisassemblerAArch64
        elif arch == "x64":
            self.disassembler = DisassemblerX64
        else:
            raise DecodeError("arch not supported: " + arch)

        self.q = q
        self.impl = impl
        self.opt = opt
        self.outdir = outdir
        self.prefix = prefix
        self.dtype = dtype
        self.arch = arch

    def run(self):
        """Process queue items until QUEUE_END; intended to run in a worker process."""
        outdir = pjoin(self.outdir, f"O{self.opt}", self.impl)
        makedirs(outdir, exist_ok=True)
        # One output file per dataset column, plus an error log.
        outfiles = {
            "asm": open(pjoin(outdir, self.prefix + ".asm"), "w"),
            "eqn": open(pjoin(outdir, self.prefix + ".eqn"), "w"),
            "src": open(pjoin(outdir, self.prefix + ".src"), "w"),
            "const": open(pjoin(outdir, self.prefix + ".const.jsonl"), "w"),
            "err": open(pjoin(outdir, self.prefix + ".error"), "w")
        }
        l = 0  # number of samples successfully written so far
        # Scratch paths include impl/opt/prefix so parallel workers don't clobber.
        tmpsrc = f"/tmp/myfunc_{self.impl}_{self.opt}_{self.prefix}"
        if "fortran" in self.impl:
            tmpsrc += ".f95"
            # presumably the trailing underscore comes from Fortran symbol
            # name mangling -- TODO confirm against the disassembler output
            func = "myfunc_"
        else:
            tmpsrc += ".c"
            func = "myfunc"
        tmpelf = f"/tmp/myfunc_{self.arch}_{self.impl}_{self.opt}_{self.prefix}.elf"

        while True:
            data = self.q.get()
            if data == QUEUE_END:
                # Queue is closed, break from inf loop
                break
            n, expr, expr_const, pref = data
            impl = Implementor(expr, constants=expr_const, dtype=self.dtype)
            try:
                code = impl.implement(self.impl)
                self.compiler(code, tmpelf, arch=self.arch, src=tmpsrc, opt=self.opt)
                disasm = self.disassembler(tmpelf, expr_constants=expr_const,
                                           match_constants=True)
                asm = disasm.disassemble(func)
                if len(disasm.constants) < len(expr_const):
                    # Some symbolic constants never showed up in the binary;
                    # skip the sample rather than emit a mismatched pair.
                    print(n, "constants not identified", disasm.constants, expr_const,
                          file=outfiles["err"])
                    continue
            except DecodeError as e:
                print(n, "impl error", e, expr, expr_const, pref, file=outfiles["err"])
                continue

            # Write one aligned line to each output column.
            outfiles["asm"].write(asm + "\n")
            outfiles["eqn"].write(pref + "\n")
            outfiles["src"].write(f"==== pick={n} line={l} ====\n" + code + "\n")
            outfiles["const"].write(json.dumps(expr_const) + "\n")
            l += 1

        for f in outfiles:
            outfiles[f].close()
108
+
109
+
110
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Compile prefix to asm->eqn dataset")
    parser.add_argument("-f", "--file", required=True, help="Input file")
    parser.add_argument("--outdir", required=True, help="Output directory")
    parser.add_argument("--prefix", required=True, help="File prefix")
    parser.add_argument("--impl", nargs="+", required=True,
                        choices=["dag_c", "cse_c", "dag_fortran", "cse_fortran"])
    parser.add_argument("--pick", type=float, required=True,
                        help="Ratio of samples to pick (0 to 1)")
    parser.add_argument("--start", type=int, default=0, help="Start from index")
    parser.add_argument("--count", type=int, default=0, help="Process only these many")
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--min-tokens", help="Minimum tokens in equations", type=int, default=5)
    parser.add_argument("--min-ops", help="Minimum ops in equations", type=int, default=5)
    parser.add_argument("--dtype", help="Implementation datatype", type=str,
                        choices=["double", "float"], default="double")
    parser.add_argument("--arch", help="Target architecture", type=str,
                        choices=["arm32", "aarch64", "x64"], default="arm32")
    parser.add_argument("-O", "--opt", nargs="+", type=int, choices=[0, 1, 2, 3], default=[0],
                        help="Optimization level (s)")

    # Dont show warnings from the cle binary loader
    logging.getLogger("cle").setLevel(logging.ERROR)

    args = parser.parse_args()
    random.seed(args.seed)

    # One worker (queue + process) per (implementation, optimization) pair;
    # every accepted equation is fanned out to all workers.
    eqcompilers = [EquationCompiler(mp.Queue(), args.arch, impl, opt, args.outdir, args.prefix, dtype=args.dtype)
                   for impl in args.impl
                   for opt in args.opt]
    pool = [mp.Process(target=eqc.run, args=()) for eqc in eqcompilers]
    for proc in pool:
        proc.start()

    count = 0
    prefixf = open(args.file, "r")
    for n, line in tqdm(enumerate(prefixf), desc="Parsing file"):
        # Skip for start lines and with some probability
        if n < args.start or random.random() > args.pick:
            continue
        comps = line.strip().split("\t")
        # Extract the prefix equation after the "Y'" marker in the line.
        pref = comps[0][comps[0].find("Y'")+3:]
        prefl = pref.split(" ")
        # pref = comps[1].split(" ")
        if len(prefl) < args.min_tokens:
            continue
        try:
            expr = parse_prefix_to_sympy(prefl)
            with timeout(10):
                # Bound simplification time; pathological expressions can hang.
                expr = sp.simplify(expr)
            if not sympy_expr_ok(expr):
                # Simplified is bad
                continue
            expr, expr_const = constant_fold(expr)
            pref = " ".join(sympy_to_prefix(expr))
        except:
            # NOTE(review): bare except deliberately skips any unparsable or
            # unsimplifiable equation (best-effort filter), but it also
            # swallows KeyboardInterrupt -- consider `except Exception`.
            continue

        if sp.count_ops(expr) < args.min_ops:
            continue

        for eqc in eqcompilers:
            # Simple backpressure: wait for the worker queue to drain a bit.
            while eqc.q.qsize() > 5:
                sleep(1)
            eqc.q.put((n, expr, expr_const, pref))
        count += 1
        if args.count > 0 and count >= args.count:
            break

    # Signal end-of-input to every worker, then wait for them to finish.
    for eqc in eqcompilers:
        eqc.q.put(QUEUE_END)
    for proc in pool:
        proc.join()
remend/compile_eqn.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Fixed shebang: was "#!bin/bash", which is not a valid interpreter path.
#
# Cross-compile an equation implementation into an ELF binary.
# Usage: compile_eqn.sh <mode> <src> <elf> [opt-flags]
#   mode: {arm32,aarch64,x64}-{c,fortran}

MODE=$1
SRC=$2
ELF=$3
OPT=$4

if [ ! -f "$SRC" ]
then
    echo "Please provide source file"
    exit 1
fi

if [ "$ELF" == "" ]
then
    echo "Please provide elf file path"
    exit 1
fi

# $OPT stays unquoted on purpose: when empty it must disappear entirely
# instead of being passed to the compiler as an empty argument.
if [ "$MODE" == "arm32-c" ]
then
    arm-linux-gnueabihf-gcc $OPT "$SRC" -lm -o "$ELF"
elif [ "$MODE" == "arm32-fortran" ]
then
    arm-linux-gnueabihf-gfortran -std=gnu $OPT "$SRC" -o "$ELF"
elif [ "$MODE" == "aarch64-c" ]
then
    aarch64-linux-gnu-gcc $OPT "$SRC" -lm -o "$ELF"
elif [ "$MODE" == "aarch64-fortran" ]
then
    aarch64-linux-gnu-gfortran -std=gnu $OPT "$SRC" -o "$ELF"
elif [ "$MODE" == "x64-c" ]
then
    gcc $OPT "$SRC" -lm -o "$ELF"
elif [ "$MODE" == "x64-fortran" ]
then
    gfortran -std=gnu $OPT "$SRC" -o "$ELF"
else
    echo "Incorrect mode: $MODE. Choose from: {arm32,aarch64,x64}-{c,fortran}"
    exit 1
fi

# arm-linux-gnueabihf-objdump --no-show-raw-insn --no-addresses -d $1.elf | sed -n -e 's/\s;\s.*$//' -e "/myfunc>:$/,/^$/p" | sed '1d;$d' | tr '\n' ' '
remend/convert_generated.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .parser import parse_prefix_to_sympy
2
+
3
+ if __name__ == "__main__":
4
+ import argparse
5
+ parser = argparse.ArgumentParser("Parse result prefix to equation")
6
+ parser.add_argument("--input", required=True, help="Input result file")
7
+ args = parser.parse_args()
8
+
9
+ res_list = []
10
+
11
+ with open(args.input, 'r') as f:
12
+ for line in f:
13
+ comps = line.strip().split("\t")
14
+ if line[0] == 'H':
15
+ num = int(comps[0][2:])
16
+ tokens = comps[2].split(" ")
17
+ res_list.append((num, tokens))
18
+
19
+ for n, toks in res_list:
20
+ try:
21
+ ex = parse_prefix_to_sympy(toks)
22
+ print(n, ex)
23
+ except Exception as e:
24
+ print(n, "could not parse:", str(e))
remend/deduplicate_split.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import random
3
+ import os
4
+ import re
5
+ from tqdm import tqdm
6
+
7
def filter_poly(asm, eqn):
    """Return True when either side contains a non-polynomial token.

    Used by --filter=poly to keep only polynomial equations.
    """
    rejects = {"ln", "exp", "sin", "cos", "sqrt", "tan", "asin", "acos", "atan", "E", "pi", "cot"}
    tokens = asm.strip().split(" ") + eqn.strip().split(" ")
    return any(tok in rejects for tok in tokens)
11
+
12
def filter_bigint(asm, eqn):
    """Return True when the asm contains a constant of four or more digits.

    Used by --filter=bigint; `eqn` is accepted for interface symmetry with
    filter_poly but not inspected.
    """
    return re.search(r"CONST=[0-9]{4,}", asm) is not None
16
+
17
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Deduplicate ASM and split files into train/test/valid")
    parser.add_argument("--inprefix", required=True, help="Prefix of input files")
    parser.add_argument("--outdir", required=True)
    parser.add_argument("--split", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--filter", choices=["poly", "bigint"], default=None)
    parser.add_argument("--no-separate-eqn", action="store_true")

    args = parser.parse_args()

    # NOTE(review): args.outdir is assumed to already exist -- the open()
    # calls below will fail otherwise.
    eq_mapped = {}     # eqn line -> list of (index, asm, const) sharing it
    combined_ds = []   # flat dataset used with --no-separate-eqn
    asm_hash = set()   # hashes of asm lines already seen (dedup)
    removed = 0

    with open(args.inprefix + ".asm", "r") as asmf, \
            open(args.inprefix + ".eqn", "r") as eqnf, \
            open(args.inprefix + ".const.jsonl", "r") as constf:
        for i, (asm, eqn, const) in tqdm(enumerate(zip(asmf, eqnf, constf)),
                                         desc="Read files", leave=False):
            h = hash(asm)
            if h in asm_hash:
                # Skip this repeated line
                removed += 1
                continue

            if re.search(r"[0-9]\.[0-9]", eqn):
                # Float not represented, remove
                removed += 1
                continue

            # Optional content filters (see filter_poly / filter_bigint).
            if args.filter == "poly" and filter_poly(asm, eqn):
                removed += 1
                continue
            if args.filter == "bigint" and filter_bigint(asm, eqn):
                removed += 1
                continue

            asm_hash.add(h)
            if args.no_separate_eqn:
                combined_ds.append((i, asm, eqn, const))
            else:
                # Group all asm variants under their identical equation so a
                # given equation lands in exactly one split (no leakage).
                if eqn not in eq_mapped:
                    eq_mapped[eqn] = []
                eq_mapped[eqn].append((i, asm, const))

    print("Removed", removed)

    if args.no_separate_eqn:
        dataset = combined_ds
    else:
        dataset = list(eq_mapped.keys())

    random.seed(args.seed)
    random.shuffle(dataset)

    # Hold out a `split` fraction each for valid and test.
    N = len(dataset)
    Ntest = int(N * args.split)

    splits = {
        "train": dataset[:N-2*Ntest],
        "valid": dataset[N-2*Ntest:N-Ntest],
        "test": dataset[N-Ntest:]
    }
    splitidxs = {s: [] for s in splits}

    # splits.txt records, per split, the original line index of each sample.
    idxf = open(os.path.join(args.outdir, "splits.txt"), "w")
    for s in splits:
        asmfn = os.path.join(args.outdir, f"{s}.asm")
        eqnfn = os.path.join(args.outdir, f"{s}.eqn")
        constfn = os.path.join(args.outdir, f"{s}.const.jsonl")
        with open(asmfn, "w") as asmf, open(eqnfn, "w") as eqnf, \
                open(constfn, "w") as constf:
            if args.no_separate_eqn:
                for i, asm, eqn, const in splits[s]:
                    asmf.write(asm)
                    eqnf.write(eqn)
                    constf.write(const)
                    splitidxs[s].append(i)
            else:
                for eqn in splits[s]:
                    for i, asm, const in eq_mapped[eqn]:
                        asmf.write(asm)
                        eqnf.write(eqn)
                        constf.write(const)
                        splitidxs[s].append(i)
        print("Split", s, len(splitidxs[s]))
        idxf.write(f"==== {s} ====\n")
        for j, i in enumerate(splitidxs[s]):
            idxf.write(f"{j}: {i}\n")
        idxf.write("\n")
    idxf.close()
111
+
remend/disassemble.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from capstone import *
2
+ from capstone.arm import *
3
+ from capstone.arm64 import *
4
+ from capstone.x86 import *
5
+ import cle
6
+ import struct
7
+ from math import e as CONST_E, pi as CONST_PI
8
+ import sympy as sp
9
+
10
+ from .util import DecodeError
11
+
12
def int2fp32(v):
    """Reinterpret a raw 32-bit integer as a little-endian IEEE-754 float.

    Non-int values are returned unchanged (assumed already decoded).
    """
    if isinstance(v, int):  # was `type(v) == int`; isinstance is idiomatic
        v = struct.unpack("<f", v.to_bytes(4, "little"))[0]
    return v

def int2fp64(v):
    """Reinterpret a raw 64-bit integer as a little-endian IEEE-754 double.

    Non-int values are returned unchanged (assumed already decoded).
    """
    if isinstance(v, int):
        v = struct.unpack("<d", v.to_bytes(8, "little"))[0]
    return v

def align4(v):
    """Round an address down to a 4-byte boundary."""
    return v & 0xFFFFFFFC
25
+
26
class DisassemblerBase:
    """Common state and constant-tracking logic shared by all disassemblers.

    Subclasses set `self.loader` (a cle.Loader) and implement disassemble().

    expr_constants: mapping of symbolic constant names (e.g. "c0") to their
        numeric values, used to recognize those values in the binary.
    match_constants: when True, floats that match no known or simply
        expressible constant raise DecodeError instead of being assigned a
        fresh CSYM symbol.
    """
    def __init__(self, expr_constants=None, match_constants=False):
        self.loader = None  # Load in child class
        self.reg_values = {}
        self.constidx = 0
        self.constants = {}
        self.constaddrs = set()
        # Fixed mutable-default-argument pitfall: each instance now gets its
        # own dict when no mapping is supplied (was `expr_constants={}`).
        self.expr_constants = {} if expr_constants is None else expr_constants
        self.match_constants = match_constants

    def get_function_bytes(self, funcname):
        """Locate `funcname` in the binary and return (address, raw bytes).

        Also records self.funcrange = (start, end).  Raises DecodeError when
        the symbol is missing.
        """
        func = self.loader.find_symbol(funcname)
        if not func:
            raise DecodeError(f"Function {funcname} not found in binary")
        faddr = func.rebased_addr
        if (not isinstance(self, DisassemblerX64)) and faddr % 2 == 1:
            # Unaligned (odd) symbol address; align down before loading.
            faddr = faddr - 1
        fbytes = self.loader.memory.load(faddr, func.size)
        self.funcrange = faddr, faddr + func.size
        return faddr, fbytes

    def find_constant(self, constants, value):
        """Match `value` against `constants`, allowing negation/inversion.

        Returns (name, prefix) with prefix one of "", "1/", "-", "-1/", or
        False when nothing matches within a 1e-5 tolerance.
        """
        if value == 0:
            # Guard the 1/value checks below against ZeroDivisionError.
            return False
        for ec in constants:
            if abs(value - constants[ec]) < 1e-5:
                return ec, ""
            elif abs(1/value - constants[ec]) < 1e-5:
                return ec, "1/"
            elif abs(-value - constants[ec]) < 1e-5:
                return ec, "-"
            elif abs(-1/value - constants[ec]) < 1e-5:
                return ec, "-1/"
        return False

    def add_constant(self, value, addr=0, size=0):
        """Map a float found in the binary to a token name (CONST=.../CSYM...).

        addr/size, when given, record which byte addresses hold the constant
        so repeated loads from the same location are recognized.  Raises
        DecodeError when match_constants is set and the value can be neither
        matched nor expressed as a simple rational.
        """
        # Don't map known constants like e, pi, 0
        if value == 0:
            cname = "CONST=0"
        elif abs(value - CONST_E) < 1e-7:
            cname = "CONST=E"
        elif abs(value - CONST_PI) < 1e-7:
            cname = "CONST=pi"
        elif self.match_constants and \
                (ecmatch := self.find_constant(self.expr_constants, value)):
            # Gives the name and expression of the matched constant
            ecname, ecxpr = ecmatch
            cname = f"{ecxpr}CSYM{ecname[1:]}"
            self.constants[ecname] = value
        elif size > 0 and addr in self.constaddrs and \
                (smatch := self.find_constant(self.constants, value)):
            # Value re-loaded from a known constant address.
            sname, sxpr = smatch
            cname = f"{sxpr}CSYM{sname}"
        else:
            # Try to express the value as a small rational (or e/pi multiple).
            rep = sp.nsimplify(value, [sp.E, sp.pi], tolerance=1e-7)
            if isinstance(rep, sp.Integer) or \
                    (isinstance(rep, sp.Rational) and rep.q <= 16):
                cname = f"CONST={rep}"
            elif not self.match_constants:
                cname = f"CSYM{self.constidx}"
                self.constants[self.constidx] = value
                self.constidx += 1
            else:
                raise DecodeError(f"Cannot represent unmatched float {value}")

        if size > 0:
            # Remember every byte address this constant occupies.
            self.constaddrs |= {addr+i for i in range(size)}
        return cname

    def disassemble(self, function):
        raise NotImplementedError("Call disassemble on child classes, not base")
97
+
98
+
99
+ class DisassemblerARM32(DisassemblerBase):
100
    def __init__(self, binpath, expr_constants={}, match_constants=False):
        """Disassembler for 32-bit ARM (Thumb mode) binaries at `binpath`."""
        super().__init__(expr_constants=expr_constants, match_constants=match_constants)
        # Capstone in Thumb mode with full instruction detail (operand info).
        self.md = Cs(CS_ARCH_ARM, CS_MODE_THUMB)
        self.md.detail = True
        self.loader = cle.Loader(binpath)
105
+
106
    def check_mov_imm(self, insn):
        """If `insn` moves an immediate into a register, return (reg, imm).

        Handles mov/movw/movt/adr and returns False otherwise.  The
        immediate is normalized to an unsigned 32-bit value; for `adr` the
        PC-relative offset is resolved to an absolute address.
        """
        if insn.id not in {ARM_INS_MOV, ARM_INS_MOVW,
                           ARM_INS_MOVT, ARM_INS_ADR}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2:
            return False
        if ops[0].type != ARM_OP_REG or ops[1].type != ARM_OP_IMM:
            return False
        imm = ops[1].value.imm
        if imm < 0:
            imm = 2**32 + imm # 2's complement
        if insn.id == ARM_INS_ADR:
            # Add PC value (in Thumb, PC reads as instruction address + 4)
            imm += insn.address + 4
        return ops[0].value.reg, imm
122
+
123
    def check_float_store(self, insn):
        """If `insn` stores a tracked float constant, return its value.

        str/strd of registers whose integer contents were previously
        recorded in self.reg_values are reinterpreted as float32/float64.
        Values with magnitude outside [1e-3, 100] are rejected; returns
        False when not applicable.
        """
        if insn.id not in {ARM_INS_STR, ARM_INS_STRD}:
            return False
        ops = list(insn.operands)
        if insn.id == ARM_INS_STRD:
            # Double split across two registers: ops[0] low word, ops[1] high word.
            dest = ops[0].value.reg
            dest2 = ops[1].value.reg
            if dest not in self.reg_values or dest2 not in self.reg_values:
                return False
            fval = int2fp64((self.reg_values[dest2]<<32) + self.reg_values[dest])
        else:
            dest = ops[0].value.reg
            if dest not in self.reg_values:
                return False
            fval = int2fp32(self.reg_values[dest])
        if abs(fval) < 1e-3 or abs(fval) > 100:
            return False
        return fval
141
+
142
    def check_ldrd(self, insn):
        """If `insn` is an ldrd loading a double, return (value, addr, 8).

        Parses the textual operand string; supports pc-relative addressing
        and register bases whose values are already in self.reg_values.
        Returns False when not applicable.

        NOTE(review): the parsing assumes space-separated operand text of
        the form "[<reg> + #<offset>]"; verify against capstone's actual
        op_str format (commas after the register would break both the "pc"
        check and the int() parse below).
        """
        if insn.id != ARM_INS_LDRD:
            return False
        ops = insn.op_str.split(", ")
        if len(ops) != 3:
            return False
        mem = ops[2] # format: [<reg> + #<offset>]
        if mem[0] != "[" or mem[-1] != "]":
            return False
        memcomps = mem[1:-1].split(" ")
        if memcomps[0] == "pc":
            # In Thumb, PC reads as insn address + 4, aligned down to 4 bytes.
            base = align4(insn.address + 4)
        else:
            # HACK: derive the capstone register id from the textual name
            # "rN"; assumes ARM_REG_R0.. are contiguous and the operand is a
            # plain rN register -- may malfunction for sp/lr.
            basereg = ARM_REG_R0 + int(memcomps[0][1:])
            if basereg not in self.reg_values:
                return False
            base = align4(self.reg_values[basereg])
        if len(memcomps) == 3:
            offset = int(memcomps[2][1:])
        else:
            offset = 0
        addr = base + offset
        fhex = self.loader.memory.load(addr, 8)
        fval = struct.unpack("d", fhex)[0]
        return fval, addr, 8
167
+
168
+ def check_vldr(self, insn):
169
+ if insn.id != ARM_INS_VLDR:
170
+ return False
171
+ ops = list(insn.operands)
172
+ dest = ops[0]
173
+ if ops[1].type != ARM_OP_MEM:
174
+ return False
175
+ mem = ops[1].value.mem
176
+ if mem.base == ARM_REG_PC:
177
+ # Align4(PC) + Imm
178
+ # For whatever reason, in Thumb PC=addr+4
179
+ addr = align4(insn.address + 4) + mem.disp
180
+ elif mem.base in self.reg_values:
181
+ addr = align4(self.reg_values[mem.base]) + mem.disp
182
+ else:
183
+ return False
184
+ if addr < self.loader.min_addr or addr + 8 > self.loader.max_addr:
185
+ # Out of bounds
186
+ return False
187
+ if dest.value.reg >= ARM_REG_D0 and dest.value.reg <= ARM_REG_D31:
188
+ size = 8
189
+ fhex = self.loader.memory.load(addr, 8)
190
+ fval = struct.unpack("d", fhex)[0]
191
+ else:
192
+ size = 4
193
+ fhex = self.loader.memory.load(addr, 4)
194
+ fval = struct.unpack("f", fhex)[0]
195
+ return fval, addr, size
196
+
197
+ def check_vmov(self, insn):
198
+ # fconsts/d == vmov.f32/f64 (old/new names)
199
+ if insn.id not in {ARM_INS_FCONSTS, ARM_INS_FCONSTD}:
200
+ return False
201
+ ops = list(insn.operands)
202
+ if len(ops) != 2 or ops[1].type != ARM_OP_FP:
203
+ return False
204
+ fval = ops[1].value.fp
205
+ destname = insn.reg_name(ops[0].value.reg)
206
+ asm = f"{insn.mnemonic} {destname}, {fval}"
207
+ return asm, fval
208
+
209
+ def check_branch_symbol(self, insn):
210
+ if insn.id not in {ARM_INS_B, ARM_INS_BL, ARM_INS_BLX}:
211
+ return False
212
+ ops = list(insn.operands)
213
+ if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
214
+ return False
215
+ addr = ops[0].value.imm
216
+ if addr > self.funcrange[0] and addr < self.funcrange[1]:
217
+ # Self-branch
218
+ func = f"SELF+{hex(addr - self.funcrange[0])}"
219
+ else:
220
+ func = self.loader.find_plt_stub_name(addr)
221
+ if func is None:
222
+ # Some tail call optimized PLT stubs have extra instructions
223
+ # that are not identified by CLE, so check with offset of 4 also.
224
+ func = self.loader.find_plt_stub_name(addr + 4)
225
+ if func is None:
226
+ return False
227
+ asm = f"{insn.mnemonic} <{func}>"
228
+ return asm
229
+
230
+ def get_function_bytes(self, funcname):
231
+ func = self.loader.find_symbol(funcname)
232
+ if not func:
233
+ raise DecodeError(f"Function {funcname} not found in binary")
234
+ faddr = func.rebased_addr
235
+ if faddr % 2 == 1:
236
+ # Unaligned address, aligning
237
+ faddr = faddr - 1
238
+ fbytes = self.loader.memory.load(faddr, func.size)
239
+ self.funcrange = faddr, faddr + func.size
240
+ return faddr, fbytes
241
+
242
+ def disassemble(self, funcname):
243
+ funcaddr, funcbytes = self.get_function_bytes(funcname)
244
+ disassm = []
245
+
246
+ for insn in self.md.disasm(funcbytes, funcaddr):
247
+ if insn.address in self.constaddrs:
248
+ # Skip if this is a constant value and not instruction
249
+ continue
250
+
251
+ cname = None
252
+ asm = None
253
+
254
+ if vldr := self.check_vldr(insn):
255
+ fval, faddr, fsize = vldr
256
+ cname = self.add_constant(fval, faddr, fsize)
257
+ elif ldrd := self.check_ldrd(insn):
258
+ fval, faddr, fsize = ldrd
259
+ cname = self.add_constant(fval, faddr, fsize)
260
+ elif strfloat := self.check_float_store(insn):
261
+ fval = strfloat
262
+ cname = self.add_constant(fval)
263
+ elif vmovfloat := self.check_vmov(insn):
264
+ asm, fval = vmovfloat
265
+ cname = self.add_constant(fval)
266
+ elif branch := self.check_branch_symbol(insn):
267
+ asm = branch
268
+
269
+ # Maintain values of immediate moves.
270
+ # Needs to be done after processing current instruction.
271
+ if movimm := self.check_mov_imm(insn):
272
+ reg, imm = movimm
273
+ if insn.id == ARM_INS_MOVT:
274
+ if reg not in self.reg_values:
275
+ self.reg_values[reg] = 0
276
+ self.reg_values[reg] += imm << 16
277
+ else:
278
+ self.reg_values[reg] = imm
279
+ else:
280
+ reads, writes = insn.regs_access()
281
+ for r in writes:
282
+ # Remove this reg if written to
283
+ if r in self.reg_values:
284
+ del self.reg_values[r]
285
+
286
+ if not asm:
287
+ asm = f"{insn.mnemonic} {insn.op_str}"
288
+ if cname:
289
+ asm += f", {cname}"
290
+ disassm.append(asm)
291
+
292
+ fulldiss = "; ".join(disassm)
293
+ return fulldiss
294
+
295
class DisassemblerAArch64(DisassemblerBase):
    """Disassembler for 64-bit ARM (AArch64) binaries.

    Same contract as DisassemblerARM32.disassemble(): one ';'-joined asm
    string with float constants lifted to symbolic names and PLT branches
    rewritten to `<symbol>` form.  Relies on get_function_bytes —
    presumably provided by DisassemblerBase (not defined here; ARM32
    overrides it) — TODO confirm.
    """

    def __init__(self, binpath, expr_constants={}, match_constants=False):
        super().__init__(expr_constants=expr_constants, match_constants=match_constants)
        self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def reg_size_type(self, reg):
        """Return (bit_width, int|float) for a register id, or (0, None)."""
        # Bit width and datatype of register
        if reg >= ARM64_REG_W0 and reg <= ARM64_REG_W30:
            return 32, int
        elif reg >= ARM64_REG_X0 and reg <= ARM64_REG_X30:
            return 64, int
        elif reg >= ARM64_REG_S0 and reg <= ARM64_REG_S31:
            return 32, float
        elif reg >= ARM64_REG_D0 and reg <= ARM64_REG_D31:
            return 64, float
        return 0, None

    def check_mov_imm(self, insn):
        """Detect adrp/adr/mov/movk with an immediate; return (reg, imm) or False.

        For movk the 16-bit chunk is merged into any previously tracked
        value of the destination register.
        """
        if insn.id not in {ARM64_INS_ADRP, ARM64_INS_ADR, ARM64_INS_MOV, ARM64_INS_MOVK}:
            return False

        ops = insn.operands
        if len(ops) != 2:
            return False
        if ops[0].type != ARM64_OP_REG or ops[1].type != ARM64_OP_IMM:
            return False

        imm = ops[1].value.imm
        if ops[1].shift.type == 1:  # LSL
            imm <<= ops[1].shift.value
        mask = 0xFFFF << ops[1].shift.value

        if insn.id == ARM64_INS_ADRP:
            # imm -= 0x400000 # Subtract global offset for some reason
            # imm = ((insn.address + 4) & (~4095)) + imm
            # Really confused about this, maybe I can use the imm directly
            pass
        elif insn.id == ARM64_INS_ADR:
            imm -= 0x400000  # Subtract global offset for some reason
            imm += insn.address + 4
        elif insn.id == ARM64_INS_MOVK:
            # load previous reg value
            if ops[0].value.reg in self.reg_values:
                curr = self.reg_values[ops[0].value.reg]
                imm = (imm & mask) | (curr & (~mask))

        return ops[0].value.reg, imm

    def check_fmov(self, insn):
        """Detect fmov of an FP immediate or of an int-register bit pattern.

        Returns (asm, float_value) or False.  Magnitudes outside
        [1e-5, 1e5] are rejected.
        NOTE(review): the FP-immediate branch builds a decoded-value asm
        string which is then unconditionally overwritten with the raw
        op_str before returning — looks like an oversight, confirm intent.
        NOTE(review): if ops[1] is neither FP nor REG, fval is unbound and
        this raises NameError instead of returning False.
        """
        if insn.id != ARM64_INS_FMOV:
            return False
        ops = insn.operands
        if len(ops) != 2:  # or ops[1].type != ARM64_OP_FP:
            return False

        destsize, _ = self.reg_size_type(ops[0].value.reg)
        destname = insn.reg_name(ops[0].value.reg)
        if ops[1].type == ARM64_OP_FP:
            fval = ops[1].value.fp
            asm = f"{insn.mnemonic} {destname}, {fval}"
        elif ops[1].type == ARM64_OP_REG:
            reg = ops[1].value.reg
            if reg not in self.reg_values:
                return False
            # TODO datatype
            # Reinterpret the tracked integer bit pattern as a float of
            # the destination register's width.
            fhex = self.reg_values[reg]
            if destsize == 64:
                if fhex < 0:
                    fhex += 2**64
                fval = int2fp64(fhex)
            elif destsize == 32:
                if fhex < 0:
                    fhex += 2**32
                fval = int2fp32(fhex)
            else:
                return False

        if abs(fval) < 1e-5 or abs(fval) > 1e5:
            return False
        asm = f"{insn.mnemonic} {insn.op_str}"
        return asm, fval

    def check_ldr(self, insn):
        """Detect `ldr` of a float/double from a tracked base register.

        Parses the memory operand from the textual op_str; stack-relative
        loads are ignored.  Returns (float_value, address, size) or False.
        """
        if insn.id != ARM64_INS_LDR:
            return False
        # op_str looks like "d0, [x1, #0x10]"; drop the trailing ']'.
        ops = insn.op_str[:-1].split(", ")
        destsize, desttype = self.reg_size_type(insn.operands[0].value.reg)
        if len(ops) < 2 or desttype != float:
            return False
        reg = ops[1]
        if reg[0] != "[" or "sp" in reg:
            return False
        basereg = ARM64_REG_X0 + int(reg[2:])  # Shitty hack, may malfunction
        if basereg not in self.reg_values:
            return False
        base = align4(self.reg_values[basereg])
        if len(ops) == 3:
            # ops[2] looks like "#0x10" or "#16"; strip the '#'.
            offset = ops[2][1:]
            if offset.startswith("0x"):
                offset = int(offset[2:], base=16)
            else:
                offset = int(offset)
        else:
            offset = 0
        addr = base + offset
        if destsize == 64:
            fhex = self.loader.memory.load(addr, 8)
            fval = struct.unpack("d", fhex)[0]
            return fval, addr, 8
        elif destsize == 32:
            fhex = self.loader.memory.load(addr, 4)
            fval = struct.unpack("f", fhex)[0]
            return fval, addr, 4
        else:
            return False


    def check_branch_symbol(self, insn):
        """Rewrite b/bl targets to `<symbol>` form; see ARM32 counterpart.

        NOTE(review): compares against ARM_OP_IMM (the 32-bit ARM constant)
        rather than ARM64_OP_IMM — works only if the two enum values
        coincide in this capstone version; confirm and prefer ARM64_OP_IMM.
        """
        if insn.id not in {ARM64_INS_BL, ARM64_INS_B}:
            return False
        ops = insn.operands
        if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
            return False
        addr = ops[0].value.imm
        if addr > self.funcrange[0] and addr < self.funcrange[1]:
            # Self-branch
            func = f"SELF+{hex(addr - self.funcrange[0])}"
        else:
            func = self.loader.find_plt_stub_name(addr)
            if func is None:
                # Some tail call optimized PLT stubs have extra instructions
                # that are not identified by CLE, so check with offset of 4 also.
                func = self.loader.find_plt_stub_name(addr + 4)
            if func is None:
                return False
        asm = f"{insn.mnemonic} <{func}>"
        return asm

    def disassemble(self, funcname):
        """Disassemble `funcname` into a single ';'-joined asm string.

        Unlike the ARM32 version, register tracking for immediate moves is
        updated BEFORE the constant checks, so fmov/ldr see the current
        instruction's own move already applied.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                # Skip if this is a constant value and not instruction
                continue

            cname = None
            asm = None
            # Maintain values of immediate moves
            if movimm := self.check_mov_imm(insn):
                reg, imm = movimm
                self.reg_values[reg] = imm
            else:
                reads, writes = insn.regs_access()
                for r in writes:
                    # Remove this reg if written to
                    if r in self.reg_values:
                        del self.reg_values[r]

            if fmov := self.check_fmov(insn):
                asm, fval = fmov
                cname = self.add_constant(fval)
            elif ldr := self.check_ldr(insn):
                fval, faddr, fsize = ldr
                cname = self.add_constant(fval, faddr, fsize)
            elif branch := self.check_branch_symbol(insn):
                asm = branch

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
                if cname:
                    asm += f", {cname}"
            disassm.append(asm)

        fulldiss = "; ".join(disassm)
        return fulldiss
474
+
475
class DisassemblerX64(DisassemblerBase):
    """Disassembler for x86-64 binaries.

    Simpler than the ARM variants: float constants on x64 are loaded
    RIP-relative, so no register-value tracking is needed.
    """

    def __init__(self, binpath, expr_constants={}, match_constants=False):
        super().__init__(expr_constants=expr_constants, match_constants=match_constants)
        self.md = Cs(CS_ARCH_X86, CS_MODE_64)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def check_call_symbol(self, insn):
        """Rewrite direct `call` targets to `<symbol>` via the PLT.

        Returns the rewritten asm string, or False.
        NOTE(review): the operand-type check uses ARM_OP_IMM (author's own
        TODO below) — should be X86_OP_IMM; works only if the enum values
        happen to coincide in this capstone version.
        """
        if insn.id != X86_INS_CALL:
            return False
        ops = insn.operands
        # TODO check this ARM_OP
        if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
            return False
        addr = ops[0].value.imm
        func = self.loader.find_plt_stub_name(addr)
        if func is None:
            return False
        asm = f"{insn.mnemonic} <{func}>"
        return asm

    def check_fload(self, insn):
        """Detect any instruction with a RIP-relative memory operand.

        Reads the referenced 4/8 bytes and decodes a float/double.
        Returns (float_value, address, size) or False.
        """
        # Cannot rely on ID because any instruction
        # can access memory.
        ops = insn.operands
        memops = [op for op in ops
                  if (op.type == X86_OP_MEM and
                      op.value.mem.base == X86_REG_RIP)]
        if len(memops) != 1:
            return False
        mem, size = memops[0].value.mem, memops[0].size
        if size > 8:
            return False
        # RIP-relative: effective address is next-instruction + disp.
        addr = insn.address + insn.size + mem.disp
        fhex = self.loader.memory.load(addr, size)
        fval = struct.unpack("f" if size == 4 else "d", fhex)[0]
        return fval, addr, size

    def disassemble(self, funcname):
        """Disassemble `funcname` into a single ';'-joined asm string."""
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            asm = None
            cname = None
            if fload := self.check_fload(insn):
                fval, faddr, fsize = fload
                cname = self.add_constant(fval, faddr, fsize)
            elif call := self.check_call_symbol(insn):
                asm = call

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
                if cname:
                    asm += f", {cname}"
            disassm.append(asm)

        fulldiss = "; ".join(disassm)
        return fulldiss
534
+
535
+
536
# Regular
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
    parser.add_argument("--bin", required=True)
    parser.add_argument("--func", required=True)
    # BUG FIX: previously an unrecognized --arch fell through all the
    # if/elif branches and crashed with NameError on `D`; restricting the
    # choices makes argparse reject bad values with a proper usage error.
    parser.add_argument("--arch", required=True, choices=["arm32", "aarch64", "x64"])
    args = parser.parse_args()

    # Pick the architecture-specific disassembler.
    if args.arch == "arm32":
        D = DisassemblerARM32(args.bin)
    elif args.arch == "aarch64":
        D = DisassemblerAArch64(args.bin)
    else:
        D = DisassemblerX64(args.bin)
    # Print the constant-annotated disassembly and the recovered constants.
    diss = D.disassemble(args.func)
    print(diss)
    print(D.constants)
remend/edit_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Parameter names of the six dictionary-model encoder layers that must be
# stripped from a fairseq checkpoint before it can be loaded.  Generated
# in the exact order the original hand-written list used.
removed = [
    f"encoder.layers.{layer}.{part}_{kind}"
    for layer in range(6)
    for part in ("in_proj", "out_proj", "fc1", "fc2")
    for kind in ("weight", "bias")
]

if __name__ == "__main__":
    import argparse
    import torch

    cli = argparse.ArgumentParser("Edit the checkpoint to remove extra dict weights")
    cli.add_argument("-c", "--checkpoint", required=True, help="Input checkpoint")
    cli.add_argument("-e", "--edited", required=True, help="Edited checkpoint")
    opts = cli.parse_args()

    # Load the full checkpoint dict, drop any of the obsolete weights that
    # are present, and write the cleaned checkpoint back out.
    state = torch.load(opts.checkpoint, weights_only=False)
    model_state = state['model']
    for name in removed:
        if name in model_state:
            del model_state[name]
    torch.save(state, opts.edited)
remend/eval_generated.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sympy as sp
2
+ import numpy as np
3
+ import warnings
4
+ from sympy.abc import x
5
+ import sys
6
+ import json
7
+ from tqdm import tqdm
8
+
9
+ from .parser import parse_prefix_to_sympy, isint
10
+
11
+ # Ignore sympy lambda warnings.
12
+ warnings.simplefilter("ignore")
13
+
14
def percent(a, n):
    """Format the fraction a/n as a percentage string with one decimal."""
    ratio = 100.0 * a / n
    return "%0.1f%%" % ratio
16
+
17
def do_eval_match(orig_expr, gen_expr):
    """Numerically compare two sympy expressions of x over [0.2, 1).

    Both expressions are lambdified and sampled at 80 grid points.
    Samples where either side is non-finite are skipped; any remaining
    sample with relative error above 1e-5 fails the match.  At least 5
    valid samples are required, so expressions that are undefined almost
    everywhere cannot match vacuously.

    Returns:
        True iff the expressions agree numerically on enough samples.
    """
    try:
        origl = sp.lambdify(x, orig_expr)
        genl = sp.lambdify(x, gen_expr)
        count = 0

        for v in np.arange(0.2, 1, 0.01):
            o = origl(v)
            g = genl(v)
            # BUG FIX: the previous checks `o == float('nan')` were always
            # False (NaN never compares equal to anything), so NaN samples
            # slipped through and — since `nan > 1e-5` is also False —
            # were silently counted as matches.  np.isfinite correctly
            # rejects NaN, +inf AND -inf (the old code never checked -inf).
            if not np.isfinite(o) or not np.isfinite(g):
                continue
            if abs((o - g) / o) > 1e-5:
                return False
            count += 1
    except Exception:
        # lambdify/evaluation can fail in many ways (domain errors,
        # complex results, ZeroDivisionError when o == 0); treat them all
        # as "no match" instead of crashing the evaluation run.  Narrowed
        # from a bare `except:` so KeyboardInterrupt still propagates.
        return False
    return count >= 5
39
+
40
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Check generated expressions")
    parser.add_argument("-g", required=True, help="Generated expressions file")
    parser.add_argument("-c", required=True, help="Constants file")
    parser.add_argument("-e", required=True, help="Equations file")
    parser.add_argument("-r", required=True, help="Results file")
    args = parser.parse_args()

    # Read fairseq-generate output: only 'H-<id>' hypothesis lines are
    # used.  The equations and constants files are consumed in lockstep,
    # one line per hypothesis, so their ordering must match the order in
    # which hypotheses appear in the generation file.
    gens = []
    with open(args.g, 'r') as genf, open(args.c) as constf, open(args.e) as eqnf:
        for line in tqdm(genf, desc="Reading file"):
            comps = line.strip().split("\t")
            if line[0] == 'H':
                num = int(comps[0][2:])  # strip the 'H-' prefix
                tokens = comps[2].split(" ")
                eqn = next(eqnf)
                const = next(constf)
                const = json.loads(const.strip())
                gens.append((num, tokens, eqn.strip(), const))

    parsed = []
    matched = []
    results = []

    for n, toks, eqn, const in tqdm(gens, desc="Evaluating expressions"):
        res = {"id": n, "parsed": False, "matched": False, "orig": "", "gen": ""}
        if "<<unk>>" in toks:
            # Not parsed
            results.append(res)
            continue
        try:
            gen_expr = parse_prefix_to_sympy(toks)
        except Exception as e:
            # Not parsed
            results.append(res)
            continue

        res["parsed"] = True
        parsed.append(n)

        # Substitute the recovered numeric values for the predicted
        # constant placeholder symbols (k<name>) before comparing.
        gen_expr = gen_expr.subs([(sp.Symbol("k"+c), const[c]) for c in const])
        orig_expr = sp.parse_expr(eqn, local_dict={"x0":x})
        res["orig"] = str(orig_expr)
        res["gen"] = str(gen_expr)

        if not do_eval_match(orig_expr, gen_expr):
            results.append(res)
            continue
        res["matched"] = True
        matched.append(n)
        results.append(res)

    # Per-hypothesis result lines followed by summary counts.
    with open(args.r, "w") as resf:
        for res in results:
            resf.write("{id} {parsed} {matched} \"{orig}\" \"{gen}\"\n".format(**res))
        resf.write("\n")
        N = len(gens)
        print("Total", N, file=resf)
        print("Parsed", len(parsed), percent(len(parsed), N), file=resf)
        print("Matched", len(matched), percent(len(matched), N), file=resf)
remend/experiment.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sympy as sp
2
+ import random
3
+
4
+ from .parser import parse_prefix_to_sympy
5
+
6
def isconst(e):
    """Return True when sympy expression *e* contains no free symbols."""
    # Converted from a lambda assignment for PEP 8 compliance; same behavior.
    return not any(c.is_symbol for c in e.atoms())


def constfold(expr):
    """Fold constant subexpressions of *expr* into named constants.

    Walks the expression tree breadth-first.  Every maximal constant
    subtree is evaluated numerically: if the value nsimplifies to an
    integer or a small rational (denominator <= 16, allowing E and pi)
    it is substituted exactly; otherwise it is replaced by a fresh symbol
    ``k<i>`` whose float value is recorded.

    Returns:
        (folded_expr, {name: float_value})
    """
    q = [expr]
    cidx = 0
    subsmap = {}
    constmap = {}

    while len(q) > 0:
        curr_expr = q.pop(0)
        # BUG FIX: this previously called isconst(e) with an undefined
        # name `e`, raising NameError for any non-Number node.
        if isinstance(curr_expr, sp.Number) or isconst(curr_expr):
            const_expr = curr_expr.evalf()
            rep = sp.nsimplify(const_expr, [sp.E, sp.pi])
            if isinstance(rep, sp.Integer) or \
                    (isinstance(rep, sp.Rational) and rep.q <= 16):
                subsmap[curr_expr] = rep
            else:
                subsmap[curr_expr] = sp.Symbol(f"k{cidx}")
                constmap[f"k{cidx}"] = float(const_expr)
                cidx += 1
        else:
            # Not constant: descend into the children.
            for child in curr_expr.args:
                q.append(child)

    return expr.subs(subsmap), constmap
30
+
31
def replace_const(expr):
    """Substitute float literals in *expr* with symbolic constants.

    A float that nsimplifies to an integer or a small rational
    (denominator at most 16) is replaced by that exact value; every
    other float is swapped for a fresh symbol ``c<i>`` whose numeric
    value is recorded in the returned map.

    Returns:
        (new_expr, {name: float_value})
    """
    substitutions = {}
    values = {}
    counter = 0

    for node in sp.preorder_traversal(expr):
        if not isinstance(node, sp.Float):
            continue
        simplified = sp.nsimplify(node)
        exact = isinstance(simplified, sp.Integer) or (
            isinstance(simplified, sp.Rational) and simplified.q <= 16
        )
        if exact:
            substitutions[node] = simplified
        else:
            name = f"c{counter}"
            substitutions[node] = sp.Symbol(name)
            values[name] = float(node)
            counter += 1
    return expr.subs(substitutions), values
47
+
48
+
49
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Random experiments")
    parser.add_argument("-f", required=True)
    parser.add_argument("-p", type=float, default=0.1)
    parser.add_argument("-n", type=int, default=20)
    args = parser.parse_args()

    # Fixed seed so the same random subset of lines is sampled every run.
    random.seed(1225)

    count = 0
    with open(args.f, "r") as f:
        for line in f:
            # Sub-sample lines with probability p.
            if random.random() > args.p:
                continue

            prefl = line.strip().split(" ")

            orig = parse_prefix_to_sympy(prefl)
            # simp = sp.simplify(expr)
            # BUG FIX: constfold returns (expr, constmap); the old code
            # passed the whole tuple straight into replace_const, which
            # would crash on the first sampled line.
            expr, kconsts = constfold(orig)
            expr, consts = replace_const(expr)
            # Report both folded (k*) and literal (c*) constants.
            print(orig, expr, {**kconsts, **consts})

            count += 1
            if count == args.n:
                break
remend/find_duplicates.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from tqdm import tqdm
3
+ from Levenshtein import distance
4
+
5
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Find duplicates in the dataset ASM")
    parser.add_argument("--train", required=True)
    # parser.add_argument("--valid", required=True)
    parser.add_argument("--test", required=True)
    parser.add_argument("--result", required=False)
    parser.add_argument("--distance", action="store_true", default=False)
    args = parser.parse_args()

    train = []
    train_hash = {}
    # valid = []
    test = []
    # Hash whole raw lines for O(1) exact-duplicate lookup.  Python str
    # hashes are salted per process, but both files are hashed within this
    # same run, so the comparison stays consistent.
    with open(args.train, "r") as tf:
        for idx, line in tqdm(enumerate(tf), desc="Read train", leave=False):
            train_hash[hash(line)] = idx
            comps = line.strip().split(" ")
            train.append(comps)
    # with open(args.valid, "r") as tf:
    #     for line in tqdm(tf, desc="Read valid", leave=False):
    #         valid.append(line.strip().split(" "))
    with open(args.test, "r") as tf:
        for line in tqdm(tf, desc="Read test", leave=False):
            test.append(line)

    # When train and test are the same file every line trivially matches
    # itself; selfcheck suppresses counting those self-matches.
    selfcheck = args.train == args.test
    if args.result:
        # NOTE(review): rf is never closed explicitly; relies on
        # interpreter exit to flush — consider a with-block.
        rf = open(args.result, "w")
        searchdist = args.distance
    else:
        searchdist = False  # Dont compute if no result file
        rf = None

    def reswrite(s):
        # Write to the result file only when one was requested.
        if rf:
            rf.write(s)

    exact = 0
    for i, testline in tqdm(enumerate(test), desc="Test", total=len(test)):
        testl = testline.strip().split(" ")
        htest = hash(testline)
        if htest in train_hash:
            # Found exact match
            j = train_hash[htest]
            if not selfcheck or j != i:
                exact += 1
                reswrite(f"{i} {j} 0 0.0\n")
            continue

        # If not, then search
        # Linear scan for the nearest-edit-distance training line.
        if searchdist:
            minavgdist, mindist, minj = 100, 100, -1
            for j, trainl in enumerate(train):
                if abs(len(trainl) - len(testl)) > 10:
                    dist = abs(len(trainl) - len(testl)) * 2  # HACK to speed it up
                else:
                    dist = distance(trainl, testl)
                avgdist = dist / (len(trainl) + len(testl))
                if mindist > dist:
                    minavgdist, mindist, minj = avgdist, dist, j

            reswrite(f"{i} {minj} {mindist} {minavgdist}\n")

    print("Exact duplicates:", exact)
remend/implementation.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sympy as sp
2
+ from sympy.codegen import ast
3
+ import itertools as it
4
+ import networkx as nx
5
+
6
+ from .parser import OPERATORS, sympy_to_dag
7
+ from .util import DecodeError
8
+
9
def isnum(s):
    """Return True when *s* parses as a floating-point literal."""
    try:
        float(s)
    except ValueError:
        return False
    return True
15
+
16
class Implementor:
    """Generate compilable C or Fortran programs for a sympy expression.

    Each implementation strategy emits a complete program containing a
    ``myfunc(x)`` definition plus a ``main`` that reads x from stdin and
    prints the result: either by traversing the expression DAG so shared
    subexpressions become temporaries, or via sympy's CSE + code printers.
    """

    def __init__(self, expr, constants={}, dtype="double"):
        self.expr = expr
        self.constants = constants  # {name: float} free constants used by expr
        self.cdtype = dtype  # C type name: "double" or "float"
        self.cpf = "lf" if dtype == "double" else "f"  # printf/scanf conversion
        self.fdtype = "double precision" if dtype == "double" else "real"

    def implement(self, impl):
        """Dispatch on strategy name; returns None for unknown names."""
        if impl == "dag_c":
            return self.dag_to_c_impl()
        elif impl == "cse_c":
            return self.sympy_cse_c_impl()
        elif impl == "dag_fortran":
            return self.dag_to_fortran_impl()
        elif impl == "cse_fortran":
            return self.sympy_cse_fortran_impl()


    def op_c_impl(self, f, children):
        """Render one operator node as a C expression over child strings.

        Unary operators fall back to libm names, with the `f`-suffixed
        single-precision variants when cdtype is float.
        Raises DecodeError for unhandled operators.
        """
        if f == "add":
            return " + ".join(children);
        elif f == "mul":
            return " * ".join(children);
        elif f == "pow":
            assert len(children) == 2
            if self.cdtype == "double":
                return f"pow({children[0]}, {children[1]})"
            else:
                return f"powf({children[0]}, {children[1]})"
        elif f == "ln":
            # sympy's natural log is spelled `log` in libm.
            assert len(children) == 1
            if self.cdtype == "double":
                return f"log({children[0]})"
            else:
                return f"logf({children[0]})"
        else:
            if f in OPERATORS and OPERATORS[f][1] == 1:
                assert len(children) == 1
                if self.cdtype == "double":
                    return f"{f}({children[0]})"
                else:
                    return f"{f}f({children[0]})"
            else:
                raise DecodeError(f"C impl: operation {f} not handled")

    def op_f_impl(self, f, children):
        """Render one operator node as a Fortran expression over child strings.

        Children are parenthesized defensively since they are raw strings.
        Raises DecodeError for unhandled operators.
        """
        if f == "add":
            j = ")+(".join(children)
            return "(" + j + ")"
        elif f == "mul":
            j = ")*(".join(children)
            return "(" + j + ")"
        elif f == "pow":
            assert len(children) == 2
            return f"({children[0]})**({children[1]})"
        elif f == "ln":
            # Fortran's log() is the natural logarithm.
            assert len(children) == 1
            return f"log({children[0]})"
        else:
            if f in OPERATORS and OPERATORS[f][1] == 1:
                assert len(children) == 1
                return f"{f}({children[0]})"
            else:
                raise DecodeError(f"F impl: operation {f} not handled")

    def full_c_code(self, body):
        """Wrap a myfunc() body in a complete C translation unit with main()."""
        pre = f"#include <stdio.h>\n#include <math.h>\n{self.cdtype} myfunc({self.cdtype} x) {{"
        post = f"}}\nint main() {{ {self.cdtype} x; scanf(\"%{self.cpf}\", &x); printf(\"%{self.cpf}\", myfunc(x)); }}"
        return f"{pre}\n{body}\n{post}"

    def full_f_code(self, body):
        """Wrap a myfunc body in a complete Fortran program (E and pi declared)."""
        pre = "function myfunc(x) result(y)\nimplicit none\n" + \
              f"{self.fdtype}, intent(in) :: x\n{self.fdtype} :: y, E, pi\n"
        post = "end function myfunc\nprogram main\nimplicit none\n" + \
               f"{self.fdtype} :: x\n{self.fdtype} :: myfunc\n" + \
               "read(*, *) x\nprint *, \"y is:\", myfunc(x)\nend program main"
        return f"{pre}\n{body}\n{post}"

    def dag_to_c_impl(self):
        """Emit C by walking the expression DAG bottom-up.

        Each internal node becomes one temporary ``t<i>``; leaves use their
        label directly (pi/E mapped to M_PI/M_E, or float literals for the
        single-precision build).
        NOTE(review): if the expression is a single leaf (no operator at
        all), varname/retname are never assigned and this raises NameError
        — confirm callers always pass at least one operation.
        """
        dag = sympy_to_dag(self.expr, csuf="F" if self.cdtype == "float" else "")
        cstr = ""
        added_pi, added_E = False, False
        for c in self.constants:
            cstr += f"{self.cdtype} {c} = {self.constants[c]};\n"
        varidx = it.count()
        # Reversed topological order: children are emitted before parents.
        for node in reversed(list(nx.topological_sort(dag))):
            label = dag.nodes[node]["label"]
            children = [dag.nodes[n]["var"] for n in dag.adj[node]]
            if len(children) == 0:
                # Leaf node: constant symbol or variable reference.
                if label == "pi":
                    if self.cdtype == "float" and not added_pi:
                        cstr += "const float pi = 3.14159265F;\n"
                        added_pi = True
                    else:
                        label = "M_PI"
                elif label == "E":
                    if self.cdtype == "float" and not added_E:
                        cstr += "const float E = 2.71828183F;\n"
                        added_E = True
                    else:
                        label = "M_E"
                dag.nodes[node]["var"] = label
                continue
            varname = f"t{next(varidx)}"
            cexpr = self.op_c_impl(label, children)
            dag.nodes[node]["var"] = varname
            cstr += f"{self.cdtype} {varname} = {cexpr};\n"
            retname = varname
        # The last temporary emitted corresponds to the DAG root.
        cstr += f"return {retname};\n"
        return self.full_c_code(cstr)

    def dag_to_fortran_impl(self):
        """Emit Fortran by walking the expression DAG bottom-up.

        Mirrors dag_to_c_impl; temporaries are declared in a header block
        since Fortran requires declarations before statements.
        NOTE(review): same single-leaf NameError hazard as dag_to_c_impl.
        """
        csuf = "" if self.fdtype == "real" else "d0"
        dag = sympy_to_dag(self.expr, csuf=csuf)
        varstr = ""
        fstr = "parameter E = 2.71828183\nparameter pi = 3.14159265\n"
        for c in self.constants:
            varstr += f"{self.fdtype} :: {c}\n"
            fstr += f"parameter {c} = {self.constants[c]}{csuf}\n"
        varidx = it.count()
        allvars = []
        for node in reversed(list(nx.topological_sort(dag))):
            label = dag.nodes[node]["label"]
            children = [dag.nodes[n]["var"] for n in dag.adj[node]]
            if len(children) == 0:
                dag.nodes[node]["var"] = label
                continue
            varname = f"t{next(varidx)}"
            fexpr = self.op_f_impl(label, children)
            dag.nodes[node]["var"] = varname
            fstr += f"{varname} = {fexpr}\n"
            retname = varname
            varstr += f"{self.fdtype} :: {varname}\n"
        fstr += f"y = {retname};\n"
        fstr = varstr + "\n" + fstr
        return self.full_f_code(fstr)

    def sympy_cse_c_impl(self):
        """Emit C via sympy's common-subexpression elimination + ccode."""
        if self.cdtype == "float":
            # Force single-precision output and suppress double macros.
            extraargs = {
                "type_aliases": {ast.real: ast.float32},
                "math_macros": {},
            }
        else:
            extraargs = {}
        cstr = ""
        for c in self.constants:
            cstr += f"{self.cdtype} {c} = {self.constants[c]};\n"
        xvars, xpr = sp.cse(self.expr)
        for vname, vxpr in xvars:
            code = sp.ccode(vxpr, assign_to=vname.name, **extraargs)
            cstr += f"{self.cdtype} {vname.name}; {code};\n"
        assert len(xpr) == 1
        code = sp.ccode(xpr[0], assign_to="y", **extraargs)
        cstr += f"{self.cdtype} y; {code}; return y;\n"
        return self.full_c_code(cstr)

    def sympy_cse_fortran_impl(self):
        """Emit Fortran via sympy's common-subexpression elimination + fcode."""
        csuf = "" if self.fdtype == "real" else "d0"
        varstr = ""
        fstr = ""
        for c in self.constants:
            varstr += f"{self.fdtype} :: {c}\n"
            fstr += f"parameter {c} = {self.constants[c]}{csuf}\n"
        xvars, xpr = sp.cse(self.expr)
        for vname, vxpr in xvars:
            varstr += f"{self.fdtype} :: {vname.name}\n"
            fstr += sp.fcode(vxpr, assign_to=vname.name, standard=95, source_format="free") + "\n"
        assert len(xpr) == 1
        fstr += sp.fcode(xpr[0], assign_to="y", standard=95, source_format="free") + "\n"
        fstr = varstr + "\n" + fstr
        if self.fdtype == "real":
            # Hack to fix sympy generation
            # (sympy emits d0 double literals; strip for single precision).
            fstr = fstr.replace("d0", "")
        return self.full_f_code(fstr)
+
193
+
194
+
195
+ # For testing only
196
+ if __name__ == "__main__":
197
+ from .parser import parse_prefix_to_sympy, sympy_to_dag
198
+
199
+ prefs = "add mul div INT+ 1 INT+ 5 x mul div INT+ 1 INT+ 5 mul x tan pow x INT+ 2".split(" ")
200
+ exp = parse_prefix_to_sympy(prefs)
201
+ impl = Implementor(exp, dtype="float")
202
+
203
+ print("DAG C:")
204
+ print(impl.dag_to_c_impl())
205
+ print("DAG Fortran:")
206
+ print(impl.dag_to_fortran_impl())
207
+ print("CSE C:")
208
+ print(impl.sympy_cse_c_impl())
209
+ print("CSE Fortran:")
210
+ print(impl.sympy_cse_fortran_impl())
remend/parser.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sympy as sp
2
+ import networkx as nx
3
+ import itertools as it
4
+ import sys
5
+
6
+ from .util import DecodeError, sympy_expr_ok
7
+
8
# Token vocabulary tables for the prefix-notation expression language.
# OPERATORS maps an op token to (sympy constructor, arity).
OPERATORS = {
    # Elementary functions
    'add': (lambda a,b: a+b, 2),
    'sub': (lambda a,b: a-b, 2),
    'mul': (lambda a,b: a*b, 2),
    'div': (lambda a,b: a/b, 2),
    'pow': (lambda a,b: a**b, 2),
    # 'inv': (lambda a: 1/a, 1),
    # 'pow2': (lambda a: a**2, 1),
    # 'pow3': (lambda a: a**3, 1),
    # 'pow4': (lambda a: a**4, 1),
    # 'pow5': (lambda a: a**5, 1),
    'sqrt': (lambda a: sp.sqrt(a), 1),
    'exp': (lambda a: sp.exp(a), 1),
    'ln': (lambda a: sp.ln(a), 1),
    # 'abs': (lambda a: sp.abs(a), 1),
    # 'sign': (lambda a: sp.sign(a), 1),
    # Trigonometric Functions
    'sin': (lambda a: sp.sin(a), 1),
    'cos': (lambda a: sp.cos(a), 1),
    'tan': (lambda a: sp.tan(a), 1),
    'cot': (lambda a: sp.cot(a), 1),
    'sec': (lambda a: sp.sec(a), 1),
    'csc': (lambda a: sp.csc(a), 1),
    # Trigonometric Inverses
    'asin': (lambda a: sp.asin(a), 1),
    'acos': (lambda a: sp.acos(a), 1),
    'atan': (lambda a: sp.atan(a), 1),
    'acot': (lambda a: sp.acot(a), 1),
    'asec': (lambda a: sp.asec(a), 1),
    'acsc': (lambda a: sp.acsc(a), 1),
    # Hyperbolic
    # 'sinh': (lambda a: sp.sinh(a), 1),
    # 'cosh': (lambda a: sp.cosh(a), 1),
    # 'tanh': (lambda a: sp.tanh(a), 1),
}

# Constant tokens: transcendental constants and single-digit literals.
CONSTANTS = {
    'E': sp.E,
    'pi': sp.pi,
    '0': 0,
    '1': 1,
    '2': 2,
    '3': 3,
    '4': 4,
    '5': 5,
    '6': 6,
    '7': 7,
    '8': 8,
    '9': 9,
}

# Variable tokens: inputs (x*), disassembly constants (c*), and folded
# expression constants (k*) — each mapped to a sympy Symbol.
VARIABLES = {
    'x': sp.Symbol('x'),
    'x0': sp.Symbol('x0'),
    'x1': sp.Symbol('x1'),

    'c0': sp.Symbol('c0'),
    'c1': sp.Symbol('c1'),
    'c2': sp.Symbol('c2'),
    'c3': sp.Symbol('c3'),
    'c4': sp.Symbol('c4'),
    'c5': sp.Symbol('c5'),
    'c6': sp.Symbol('c6'),
    'c7': sp.Symbol('c7'),
    'c8': sp.Symbol('c8'),
    'c9': sp.Symbol('c9'),
    'c10': sp.Symbol('c10'),

    'k0': sp.Symbol('k0'),
    'k1': sp.Symbol('k1'),
    'k2': sp.Symbol('k2'),
    'k3': sp.Symbol('k3'),
    # 'y': sp.Symbol('y'),
    # 'z': sp.Symbol('z')
}

# Reverse mapping: sympy function/class -> op token (used when
# serializing sympy expressions back to prefix notation).
FUNC_TO_OP = {
    sp.Add: 'add',
    sp.Mul: 'mul',
    sp.Pow: 'pow',

    sp.log: 'ln',
    sp.sqrt: 'sqrt',
    sp.exp: 'exp',
    sp.Abs: 'abs',
    # 'abs': (lambda a: sp.abs(a), 1),
    # 'sign': (lambda a: sp.sign(a), 1),
    # Trigonometric Functions
    sp.sin: 'sin',
    sp.cos: 'cos',
    sp.tan: 'tan',
    sp.cot: 'cot',
    sp.sec: 'sec',
    sp.csc: 'csc',
    # Trigonometric Inverses
    sp.asin: 'asin',
    sp.acos: 'acos',
    sp.atan: 'atan',
    sp.acot: 'acot',
    sp.asec: 'asec',
    sp.acsc: 'acsc',
    # Hyperbolic
    # sp.cosh: 'cosh',
    # sp.sinh: 'sinh',
    # sp.tanh: 'tanh'
}
115
+
116
def sympy_func_to_op(f):
    """Map a sympy function/class to its prefix-notation op token.

    Raises:
        DecodeError: if *f* has no entry in FUNC_TO_OP.
    """
    if f in FUNC_TO_OP:
        return FUNC_TO_OP[f]
    # No silent fallback: an unknown function is a decoding failure.
    # (The original `return str(f)` after the raise was unreachable.)
    raise DecodeError(f"Op not found {f}")
122
+
123
def isint(s):
    """Return True when *s* parses as a base-10 integer, else False."""
    try:
        int(s)
    except ValueError:
        return False
    return True
129
+
130
def reverse_iter_prefix(prefs):
    """Iterate over prefix token list *prefs* from last token to first,
    merging digit/sign/exponent tokens into Python int/float values.

    Numeric literals appear in the stream as a marker token
    (INT+/INT-/FLOAT+/FLOAT-) followed by one token per character; since
    we walk backwards, the characters are collected in reverse and joined
    back when their marker is reached.  Non-numeric tokens are yielded
    unchanged; '-' marker variants yield the negated value.
    """
    n = len(prefs) - 1
    currnum = []  # literal characters seen so far, in reverse order
    while n >= 0:
        if isint(prefs[n]) or prefs[n] in ["e", "+", "-", "."]:
            # Part of a numeric literal; accumulate its characters.
            # (+= on a list extends with the token's characters.)
            currnum += prefs[n]
        elif prefs[n][:3] == "INT":
            # Marker terminates the literal; undo the reversal and parse.
            parsedint = int("".join(reversed(currnum)))
            if prefs[n][3] == "+":
                yield parsedint
            else:
                yield -parsedint
            currnum = []
        elif prefs[n][:5] == "FLOAT":
            parsedfloat = float("".join(reversed(currnum)))
            if prefs[n][5] == "+":
                yield parsedfloat
            else:
                yield -parsedfloat
            currnum = []
        else:
            # Operator / variable / constant token: pass through.
            yield prefs[n]
        n -= 1
158
+
159
def parse_prefix_to_sympy(prefs):
    """Build a sympy expression from a prefix-notation token list.

    Tokens are consumed right-to-left (via reverse_iter_prefix) with a
    standard operand stack.

    Raises:
        DecodeError: on unknown tokens, a malformed expression, or a
        complex/infinite result.
    """
    stack = []
    for tok in reverse_iter_prefix(prefs):
        if tok in OPERATORS:
            build, arity = OPERATORS[tok]
            args = [stack.pop() for _ in range(arity)]
            stack.append(build(*args))
        elif tok in CONSTANTS:
            stack.append(CONSTANTS[tok])
        elif tok in VARIABLES:
            stack.append(VARIABLES[tok])
        elif isinstance(tok, (int, float)):
            # Numeric literal already parsed by reverse_iter_prefix.
            stack.append(tok)
        elif tok in ("(", ")"):
            # Brackets carry no meaning in prefix form.
            continue
        else:
            raise DecodeError(f"{tok} invalid")

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
    expr = stack.pop()
    if not sympy_expr_ok(expr):
        raise DecodeError("Complex or infinite expression")
    return expr
186
+
187
def parse_postfix_to_sympy(prefs):
    """Build a sympy expression from a postfix-notation token list.

    Numeric literals are normalized by reverse_iter_prefix (which walks
    backwards), then the stream is re-reversed into forward postfix
    order and evaluated with an operand stack.

    Raises:
        DecodeError: on unknown tokens, a malformed expression, or a
        complex/infinite result.
    """
    stack = []
    for tok in reversed(list(reverse_iter_prefix(prefs))):
        if tok in OPERATORS:
            build, arity = OPERATORS[tok]
            args = [stack.pop() for _ in range(arity)]
            stack.append(build(*args))
        elif tok in CONSTANTS:
            stack.append(CONSTANTS[tok])
        elif tok in VARIABLES:
            stack.append(VARIABLES[tok])
        elif isinstance(tok, (int, float)):
            stack.append(tok)
        elif tok in ("(", ")"):
            # Brackets are ignored.
            continue
        else:
            raise DecodeError(f"{tok} invalid")

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
    expr = stack.pop()
    if not sympy_expr_ok(expr):
        raise DecodeError("Complex or infinite expression")
    return expr
214
+
215
+
216
def parse_prefix_to_tree(prefs):
    """Parse prefix tokens into a networkx DiGraph expression tree.

    Each token becomes a node (id = its position in the reversed token
    stream) labelled with the token.  For the non-commutative binary ops
    (pow/sub/div) two marker nodes labelled "lhs"/"rhs" are inserted
    between the operator and its children to preserve operand order;
    their ids start past the token range.

    Returns:
        (tree, root): the DiGraph and the root node id.

    Raises:
        DecodeError: on unknown tokens or a malformed expression.
        NOTE(review): float literals fall into the invalid branch here —
        only int leaves are accepted; confirm that is intended.
    """
    tree = nx.DiGraph()
    stack = []
    newidx = len(prefs)  # next free id for lhs/rhs marker nodes
    for nidx, val in enumerate(reverse_iter_prefix(prefs)):
        tree.add_node(nidx, label=val)
        if val in OPERATORS:
            _, numops = OPERATORS[val]
            childs = [stack.pop() for i in range(numops)]
            if val in {"pow", "sub", "div"}:
                # Ordered children
                tree.add_node(newidx, label="lhs")
                tree.add_node(newidx+1, label="rhs")
                tree.add_edge(nidx, newidx)
                tree.add_edge(nidx, newidx+1)
                tree.add_edge(newidx, childs[0])
                tree.add_edge(newidx+1, childs[1])
                newidx += 2
            else:
                for c in childs:
                    tree.add_edge(nidx, c)
        elif val in CONSTANTS or val in VARIABLES or type(val) == int:
            # Leaf node: no outgoing edges.
            pass
        else:
            raise DecodeError(f"Val {val} invalid")
        stack.append(nidx)

    if len(stack) != 1:
        raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")

    return tree, stack.pop() # Root node
247
+
248
def sympy_to_dag(expression, csuf=""):
    """Convert a sympy expression to a networkx DiGraph DAG with shared
    sub-expressions.

    Identical sub-expressions map to the same node (the ``seen`` cache),
    turning the tree into a DAG.  Leaves are labelled as code-style
    numeric literals — *csuf* is appended as a literal suffix (e.g. "f"
    for C single precision) — or as the symbol name; inner nodes get the
    prefix op token of their sympy function.
    """
    dag = nx.DiGraph()
    seen = {}          # sub-expression -> node id (dedupe for sharing)
    nitr = it.count()  # monotonically increasing node ids

    def _dfs(node):
        # Post-order: create children first so edges can reference them.
        children = []
        for child in node.args:
            if child in seen:
                cid = seen[child]
            else:
                cid = _dfs(child)
            children.append(cid)

        nid = next(nitr)
        dag.add_node(nid, expr=node)
        seen[node] = nid
        for cid in children:
            dag.add_edge(nid, cid)
        return nid

    _dfs(expression)
    for node in dag.nodes:
        if len(dag.adj[node]) == 0:
            # Leaf: render numbers with an explicit decimal point so the
            # generated C/Fortran treats them as floating point.
            # (Integer is checked before Rational — it is a subclass.)
            e = dag.nodes[node]["expr"]
            if isinstance(e, sp.Integer):
                dag.nodes[node]["label"] = f"{e}.0{csuf}"
            elif isinstance(e, sp.Rational):
                dag.nodes[node]["label"] = f"{e.p}.0{csuf}/{e.q}.0{csuf}"
            elif isinstance(e, sp.Float):
                dag.nodes[node]["label"] = f"{float(e)}{csuf}"
            else:
                dag.nodes[node]["label"] = str(e)
        else:
            dag.nodes[node]["label"] = sympy_func_to_op(dag.nodes[node]["expr"].func)

    return dag
285
+
286
def sympy_to_prefix(expr):
    """Serialize a sympy expression into the prefix token list format.

    Integers become INT+/INT- followed by one token per digit; rationals
    become a div of two ints; n-ary Add/Mul are expanded into binary
    chains, with negated terms folded into 'sub' and reciprocal factors
    into 'div'.  Floats are snapped to a nearby exact value when
    possible.

    Raises:
        DecodeError: for floats with no small exact representation, or
        for sympy functions without an op mapping.
    """
    trav = []

    def _pre(node):
        nonlocal trav
        # NOTE: Integer is a subclass of Rational, so integers also take
        # this first branch (with q == 1, recursing on the int p).
        if isinstance(node, sp.Rational):
            if node.q != 1:
                trav.append("div")
                _pre(node.p)
                _pre(node.q)
            else:
                _pre(node.p)
        elif isinstance(node, sp.Integer) or isinstance(node, int):
            # Sign marker followed by one token per digit.
            v = int(node)
            if v >= 0:
                trav.append("INT+")
                trav.extend(list(str(v)))
            else:
                trav.append("INT-")
                trav.extend(list(str(-v)))
        elif isinstance(node, sp.Symbol):
            trav.append(str(node))
        elif isinstance(node, sp.Mul):
            # Split factors into numerator and denominator: sympy encodes
            # division as multiplication by Pow(x, -1).
            mulargs = []
            divargs = []
            children = node.args
            for child in children:
                if isinstance(child, sp.Pow) and \
                        isinstance(child.args[1], sp.Integer) and child.args[1] == -1:
                    divargs.append(child.args[0])
                else:
                    mulargs.append(child)
            if len(divargs) > 0:
                trav.append("div")
                if len(mulargs) == 0:
                    # Pure reciprocal: synthesize a 1 numerator.
                    trav.append("INT+")
                    trav.append("1")
            # Insert numerator factors as a binary mul chain
            for i, child in enumerate(mulargs):
                if i < len(mulargs) - 1:
                    trav.append("mul")
                _pre(child)
            # Insert denominator factors as a binary mul chain
            for i, child in enumerate(divargs):
                if i < len(divargs) - 1:
                    trav.append("mul")
                _pre(child)
        elif isinstance(node, sp.Add):
            # Split terms into added and subtracted: sympy encodes
            # subtraction as multiplication by -1.
            addargs = []
            subargs = []
            children = node.args
            for child in children:
                if isinstance(child, sp.Mul) and len(child.args) == 2 and \
                        isinstance(child.args[1], sp.Integer) and child.args[1] == -1:
                    subargs.append(child.args[0])
                elif isinstance(child, sp.Mul) and len(child.args) == 2 and \
                        isinstance(child.args[0], sp.Integer) and child.args[0] == -1:
                    subargs.append(child.args[1])
                else:
                    addargs.append(child)
            if len(subargs) > 0:
                trav.append("sub")
                if len(addargs) == 0:
                    # Pure negation: synthesize a 0 left operand.
                    trav.append("INT+")
                    trav.append("0")
            # Insert added terms as a binary add chain
            for i, child in enumerate(addargs):
                if i < len(addargs) - 1:
                    trav.append("add")
                _pre(child)
            # Insert subtracted terms as a binary add chain
            for i, child in enumerate(subargs):
                if i < len(subargs) - 1:
                    trav.append("add")
                _pre(child)
        elif isinstance(node, sp.Float):
            # Snap floats to a nearby exact value; otherwise they cannot
            # be serialized in this vocabulary.
            rep = sp.nsimplify(node, tolerance=1e-7)
            if isinstance(rep, sp.Integer):
                _pre(rep)
            elif isinstance(rep, sp.Rational) and rep.q <= 16:
                _pre(rep)
            else:
                raise DecodeError(f"Float {node} encountered while generating")
        elif node == sp.E or node == sp.pi:
            # Transcendental constants
            trav.append(str(node))
        else:
            op = sympy_func_to_op(node.func)
            children = node.args
            for i, child in enumerate(children):
                # Insert op repeatedly to maintain binary tree
                if i == 0 or i < len(children) - 1:
                    trav.append(op)
                _pre(child)
    _pre(expr)
    return trav
383
+
384
def constant_fold(expr):
    """Fold constant sub-expressions of *expr*.

    BFS over the expression: every constant subtree that simplifies
    (within 1e-7) to an integer, a small rational (denominator <= 16),
    E or pi is substituted by that exact value.  Any other constant value
    is replaced by a fresh symbol k0, k1, ...; a value that matches an
    already-assigned constant (directly, negated, or as a possibly
    negated reciprocal) reuses its symbol.

    Returns:
        (folded_expr, constmap): the substituted expression and a dict
        mapping constant symbol names to their float values.
    """
    q = [expr]
    cidx = 0
    subsmap = {}
    constmap = {}

    def _isconst(e):
        return not any(a.is_symbol for a in e.atoms())

    while q:
        curr_expr = q.pop(0)  # BFS so k-indices follow top-down order
        if not (isinstance(curr_expr, sp.Number) or _isconst(curr_expr)):
            q.extend(curr_expr.args)
            continue
        const_expr = curr_expr.evalf()
        rep = sp.nsimplify(const_expr, [sp.E, sp.pi], tolerance=1e-7)
        if isinstance(rep, sp.Integer) or \
                (isinstance(rep, sp.Rational) and rep.q <= 16) or \
                rep == sp.E or rep == sp.pi:
            # Exact small representation: substitute it directly.
            subsmap[curr_expr] = rep
            continue
        val = float(const_expr)
        found = False
        for name, known in constmap.items():
            # Reuse an existing constant when the value matches it
            # directly, as a reciprocal, or negated (or both).
            if abs(val - known) < 1e-7:
                subsmap[curr_expr] = sp.Symbol(name)
            elif abs(1/val - known) < 1e-7:
                subsmap[curr_expr] = 1/sp.Symbol(name)
            elif abs(-val - known) < 1e-7:
                subsmap[curr_expr] = -sp.Symbol(name)
            elif abs(-1/val - known) < 1e-7:
                subsmap[curr_expr] = -1/sp.Symbol(name)
            else:
                continue
            found = True
            # Fix: stop at the first match — the original kept scanning,
            # letting a later near-match overwrite the substitution.
            break
        if not found:
            subsmap[curr_expr] = sp.Symbol(f"k{cidx}")
            constmap[f"k{cidx}"] = val
            cidx += 1

    return expr.subs(subsmap), constmap
426
+
427
+
428
# For testing only
if __name__ == "__main__":
    # Round-trip a sample prefix expression through sympy, then fold its
    # constant subtrees.
    sample = "add mul INT- 1 x mul pow ln INT+ 4 INT- 1 add x mul INT- 1 pow x INT+ 5"
    expression = sp.simplify(parse_prefix_to_sympy(sample.split(" ")))
    print(expression)
    print(constant_fold(expression))
remend/plot_loss.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from matplotlib import pyplot as plt
2
+ import json
3
+
4
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Plot loss for the training log")
    parser.add_argument("-t", "--trainlog", required=True, help="Training log file")
    parser.add_argument("-l", "--loss", help="Loss plot to save (optional)")
    parser.add_argument("--log-scale", default=False, action="store_true", help="Log scale")
    parser.add_argument("-P", "--no-plot", default=True, action="store_false",
                        help="Don't open matplotlib figure")
    args = parser.parse_args()

    # Three (updates, loss) series: inner per-logging-step training loss,
    # end-of-epoch training loss, and validation loss.
    train_inner_upd, train_inner_loss = [], []
    train_upd, train_loss = [], []
    val_upd, val_loss = [], []

    def _append_monotonic(upds, losses, upd, loss):
        # Keep each series strictly increasing in update count so that
        # restarted runs / duplicate lines don't produce zig-zag plots.
        if not upds or upds[-1] < upd:
            upds.append(upd)
            losses.append(loss)

    with open(args.trainlog, "r") as tl:
        for line in tl:
            # The log interleaves prose with JSON records; keep only JSON
            # object lines.  (startswith also handles empty lines, which
            # the original line[0] test crashed on.)
            if not line.startswith("{"):
                continue
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                # Fix: was a bare except that swallowed everything.
                continue
            if "loss" in data:
                _append_monotonic(train_inner_upd, train_inner_loss,
                                  int(data["num_updates"]), float(data["loss"]))
            if "valid_loss" in data:
                _append_monotonic(val_upd, val_loss,
                                  int(data["valid_num_updates"]), float(data["valid_loss"]))
            if "train_loss" in data:
                _append_monotonic(train_upd, train_loss,
                                  int(data["train_num_updates"]), float(data["train_loss"]))

    plt.figure()
    plt.plot(train_upd, train_loss, "r")
    plt.plot(val_upd, val_loss, "b")
    if len(train_inner_upd) > 0:
        plt.plot(train_inner_upd, train_inner_loss, "r", alpha=0.3)
    plt.legend(["train", "valid"])
    if args.log_scale:
        plt.yscale("log")
    elif (train_loss or val_loss) and min(train_loss + val_loss) < 1:
        # Zoom in on sub-1 losses; guard against empty series, which made
        # the original min(min(...), min(...)) raise ValueError.
        plt.ylim((0, 1))
    plt.xlabel("Updates")
    plt.ylabel("Loss")
    if args.loss:
        plt.savefig(args.loss)
    if args.no_plot:
        plt.show()
remend/preprocess_remaqe.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from tqdm import tqdm
4
+ import itertools as it
5
+ import sympy as sp
6
+
7
+ from .disassemble import DisassemblerARM32
8
+ from .parser import sympy_to_prefix, isint
9
+
10
def match_constants(exprconst, asmconst, constsym, eps=1e-5):
    """Match expression constants to assembly constants by value.

    For every expression constant (skipping near-zero values), the first
    assembly constant whose value is approximately equal, reciprocal, or
    negated wins; the mapping sends the expression symbol to the
    (possibly transformed) assembly symbol from *constsym*.

    Returns:
        (mapping, mapped): substitution dict and the set of matched
        expression-constant names.
    """
    mapping = {}
    mapped = set()

    for ename in exprconst:
        evalue = float(exprconst[ename])
        esym = constsym[ename]
        if abs(evalue) < eps:
            # Near-zero values can't be matched reliably (and the
            # reciprocal test would divide by ~0).
            continue
        for aname in asmconst:
            avalue = asmconst[aname]
            asym = constsym[aname]
            if abs(avalue - evalue) <= eps:
                mapping[esym] = asym
            elif abs(avalue - 1/evalue) <= eps:
                mapping[esym] = 1/asym
            elif abs(avalue + evalue) <= eps:
                mapping[esym] = -asym
            else:
                continue
            mapped.add(ename)
            break
    return mapping, mapped
37
+
38
def replace_naming(pref):
    """Rename tokens to the model vocabulary: 'x0' -> 'x' and constant
    tokens 'c<i>' -> 'k<i>'; all other tokens pass through unchanged."""
    renamed = []
    for tok in pref:
        if tok == "x0":
            renamed.append("x")
        elif tok[0] == "c" and isint(tok[1:]):
            # Constant token: swap the prefix letter.
            renamed.append("k" + tok[1:])
        else:
            renamed.append(tok)
    return renamed
49
+
50
+
51
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
    parser.add_argument("--list", required=True)
    parser.add_argument("--prefix", required=True)
    args = parser.parse_args()

    with open(args.list, "r") as f:
        mdllist = list(f)
    opts = ["O0", "O1", "O2", "O3"]

    basedir = os.path.dirname(args.list)
    # Fix: context managers guarantee the three parallel output files are
    # closed (and flushed) even if disassembly or parsing raises mid-loop.
    with open(args.prefix + ".asm", "w") as asmf, \
         open(args.prefix + ".eqn", "w") as eqnf, \
         open(args.prefix + ".const.jsonl", "w") as constf:
        for mdl in tqdm(mdllist):
            mdl = mdl.strip()
            mdlname = os.path.basename(mdl)
            with open(os.path.join(basedir, mdl, "expressions.json")) as f:
                expressions = json.load(f)
            yexpr = expressions["expressions"]["y"]
            exprconsts = {c: float(expressions["constants"][c]) for c in expressions["constants"]}
            if len(exprconsts) > 4:
                # Skip samples with too many expression constants.
                continue
            yexpr = sp.parse_expr(yexpr)
            exprconstsym = {c: sp.Symbol(c) for c in expressions["constants"]}

            for opt in opts:
                funcname = f"{mdlname}_run"
                binf = os.path.join(basedir, mdl, opt, "c_bin.elf")
                D = DisassemblerARM32(binf)
                diss = D.disassemble(funcname)
                constants = D.constants
                if len(constants) > 3:
                    # Skip binaries with too many embedded constants.
                    continue

                exprconstsym.update({c: sp.Symbol(f"c{c}") for c in constants})
                mapping, mapped = match_constants(exprconsts, constants, exprconstsym)
                # Require as many matched expression constants as there
                # are assembly constants — presumably every asm constant
                # must be accounted for; TODO confirm intent.
                if len(mapped) != len(constants):
                    continue

                exprsubs = yexpr.subs(mapping)
                exprprefix = replace_naming(sympy_to_prefix(exprsubs))

                # Write the three parallel dataset lines together so the
                # files stay aligned sample-for-sample.
                asmf.write(diss + "\n")
                eqnf.write(" ".join(exprprefix) + "\n")
                constf.write(json.dumps(constants) + "\n")
remend/util.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import contextmanager
2
+ import signal
3
+ import sympy as sp
4
+
5
def timeout_handler(signum, frame):
    """SIGALRM handler: abort the timed block with TimeoutError."""
    raise TimeoutError("Block timed out")

@contextmanager
def timeout(duration):
    """Context manager raising TimeoutError if the block runs longer than
    *duration* whole seconds.

    Uses SIGALRM, so it only works on Unix and in the main thread.
    """
    # Fix: remember the previous handler so nesting / outer alarm users
    # are not left with our handler installed afterwards.
    old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, old_handler)
15
+
16
class DecodeError(Exception):
    """Raised when a prefix/postfix token stream or a decoded sympy
    expression cannot be interpreted."""
    pass
18
+
19
def sympy_expr_ok(expr):
    """Return False when *expr* contains an imaginary unit, infinity
    (signed or complex), or NaN atom; True otherwise."""
    bad_atoms = (sp.I, sp.oo, sp.zoo, sp.nan)
    atoms = expr.atoms()
    return not any(bad in atoms for bad in bad_atoms)