Commit
·
7145fd6
1
Parent(s):
e78b7eb
Add REMEND python module
Browse files- pyproject.toml +30 -0
- remend/__init__.py +0 -0
- remend/bpe.py +64 -0
- remend/bpe_apply.py +25 -0
- remend/change_eqn_format.py +79 -0
- remend/check_generated.py +143 -0
- remend/compile_dataset.py +185 -0
- remend/compile_eqn.sh +44 -0
- remend/convert_generated.py +24 -0
- remend/deduplicate_split.py +111 -0
- remend/disassemble.py +553 -0
- remend/edit_model.py +16 -0
- remend/eval_generated.py +100 -0
- remend/experiment.py +75 -0
- remend/find_duplicates.py +69 -0
- remend/implementation.py +210 -0
- remend/parser.py +449 -0
- remend/plot_loss.py +60 -0
- remend/preprocess_remaqe.py +102 -0
- remend/util.py +21 -0
pyproject.toml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["hatchling"]
|
3 |
+
build-backend = "hatchling.build"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "remend"
|
7 |
+
version = "1.0"
|
8 |
+
authors = [{name="Meet Udeshi", email="[email protected]"}]
|
9 |
+
description = "Neural Decompilation for Reverse Engineering Math Equations from Binary Executables"
|
10 |
+
readme = "README.md"
|
11 |
+
classifiers = [
|
12 |
+
"Programming Language :: Python :: 3",
|
13 |
+
"Operating System :: OS Independent",
|
14 |
+
]
|
15 |
+
requires-python = ">=3.9"
|
16 |
+
dependencies = [
|
17 |
+
"networkx",
|
18 |
+
"capstone",
|
19 |
+
"Levenshtein",
|
20 |
+
"tqdm",
|
21 |
+
"numpy",
|
22 |
+
"sympy",
|
23 |
+
"fairseq",
|
24 |
+
"torch",
|
25 |
+
"matplotlib",
|
26 |
+
"tokenizers"
|
27 |
+
]
|
28 |
+
|
29 |
+
[tool.hatch.build.targets.wheel]
|
30 |
+
packages = ["remend"]
|
remend/__init__.py
ADDED
File without changes
|
remend/bpe.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tokenizers import pre_tokenizers, Tokenizer
|
2 |
+
from tokenizers.models import BPE
|
3 |
+
from tokenizers.trainers import BpeTrainer
|
4 |
+
from tokenizers.pre_tokenizers import Whitespace, PreTokenizer
|
5 |
+
import random
|
6 |
+
import os
|
7 |
+
from tqdm import tqdm
|
8 |
+
import itertools as it
|
9 |
+
|
10 |
+
class ImmPreTokenizer:
|
11 |
+
def pre_tokenize(self, pretok):
|
12 |
+
pretok.split(self.hex_imm_split)
|
13 |
+
def hex_imm_split(self, i, norm_str):
|
14 |
+
tok = str(norm_str)
|
15 |
+
if tok[:2] == "0x" or tok.isdigit():
|
16 |
+
return [norm_str[i:i+1] for i in range(len(tok))]
|
17 |
+
else:
|
18 |
+
return [norm_str]
|
19 |
+
|
20 |
+
def get_asm_tok(files, save):
|
21 |
+
asm_tok = Tokenizer(BPE(unk_token="@@UNK@@"))
|
22 |
+
asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
|
23 |
+
asm_train = BpeTrainer(special_tokens=["@@UNK@@"])
|
24 |
+
|
25 |
+
asm_tok.train(files, asm_train)
|
26 |
+
asm_tok.pre_tokenizer = Whitespace() # Hack to save, careful to restore ImmPreTokenizer
|
27 |
+
asm_tok.save(save)
|
28 |
+
asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
|
29 |
+
|
30 |
+
return asm_tok
|
31 |
+
|
32 |
+
def load_asm_tok(load):
|
33 |
+
asm_tok = Tokenizer.from_file(load)
|
34 |
+
asm_tok.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), PreTokenizer.custom(ImmPreTokenizer())])
|
35 |
+
return asm_tok
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
import argparse
|
40 |
+
parser = argparse.ArgumentParser("Train the tokenizer and tokenize the asm")
|
41 |
+
parser.add_argument("-i", "--indir", required=True, help="output directory")
|
42 |
+
parser.add_argument("-o", "--outdir", default="tokenized", help="output directory")
|
43 |
+
args = parser.parse_args()
|
44 |
+
|
45 |
+
os.makedirs(args.outdir, exist_ok=True)
|
46 |
+
injoin = lambda p: os.path.join(args.indir, p)
|
47 |
+
pjoin = lambda p: os.path.join(args.outdir, p)
|
48 |
+
max_asm_toks = 0
|
49 |
+
|
50 |
+
asm_tok = get_asm_tok([injoin("train.asm"), injoin("valid.asm")], pjoin("asm_tokens.json"))
|
51 |
+
for split in ["train", "valid", "test"]:
|
52 |
+
asmfile = split + ".asm"
|
53 |
+
with open(injoin(asmfile), "r") as asmf, open(pjoin(asmfile), "w") as asmtokf:
|
54 |
+
for asm in tqdm(asmf, desc=f"Tokenizing {split}"):
|
55 |
+
asm = asm.strip()
|
56 |
+
asm_enc = asm_tok.encode(asm)
|
57 |
+
max_asm_toks = max(max_asm_toks, len(asm_enc.tokens))
|
58 |
+
asm_seq = " ".join(asm_enc.tokens)
|
59 |
+
asmtokf.write(asm_seq + "\n")
|
60 |
+
|
61 |
+
print("Maximum tokens:", max_asm_toks)
|
62 |
+
|
63 |
+
# After this, run command:
|
64 |
+
# fairseq-preprocess -s asm -t eqn --trainpref {OUTDIR}/train --validpref {OUTDIR}/valid --testpref {OUTDIR}/test --destdir {OUTDIR}
|
remend/bpe_apply.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
|
3 |
+
from .bpe import load_asm_tok
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
import argparse
|
7 |
+
parser = argparse.ArgumentParser("Tokenize using existing tokenizer")
|
8 |
+
parser.add_argument("-t", "--tokenizer", required=True, help="existing tokenizer")
|
9 |
+
parser.add_argument("-i", "--input", required=True, help="input file")
|
10 |
+
parser.add_argument("-o", "--output", required=True, help="output file")
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
max_asm_toks = 0
|
14 |
+
asm_tok = load_asm_tok(args.tokenizer)
|
15 |
+
|
16 |
+
with open(args.input, "r") as asmf, open(args.output, "w") as asmtokf:
|
17 |
+
for asm in tqdm(asmf, desc=f"Tokenizing"):
|
18 |
+
asm = asm.strip()
|
19 |
+
asm_enc = asm_tok.encode(asm)
|
20 |
+
max_asm_toks = max(max_asm_toks, len(asm_enc.tokens))
|
21 |
+
asm_seq = " ".join(asm_enc.tokens)
|
22 |
+
asmtokf.write(asm_seq + "\n")
|
23 |
+
|
24 |
+
print("Maximum tokens:", max_asm_toks)
|
25 |
+
|
remend/change_eqn_format.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .parser import isint, OPERATORS
|
2 |
+
|
3 |
+
def prefix_to_brackets(eqn):
|
4 |
+
stack = []
|
5 |
+
lastop = []
|
6 |
+
intunit = []
|
7 |
+
N = len(eqn)
|
8 |
+
i = 0
|
9 |
+
while i < N:
|
10 |
+
# print("Stack", stack)
|
11 |
+
val = eqn[i]
|
12 |
+
if val.startswith("INT"):
|
13 |
+
intunit.append(val)
|
14 |
+
i += 1
|
15 |
+
while i < N and isint(eqn[i]):
|
16 |
+
intunit.append(eqn[i])
|
17 |
+
i += 1
|
18 |
+
stack.append(" ".join(intunit))
|
19 |
+
intunit = []
|
20 |
+
i -= 1
|
21 |
+
elif val in OPERATORS:
|
22 |
+
_, numops = OPERATORS[val]
|
23 |
+
lastop.append((len(stack), numops))
|
24 |
+
stack.append(val)
|
25 |
+
else:
|
26 |
+
stack.append(val)
|
27 |
+
|
28 |
+
while len(lastop) > 0 and len(stack) > lastop[-1][0] + lastop[-1][1]:
|
29 |
+
# Combine op
|
30 |
+
# print(lastop[-1], stack[lastop[-1][0]:])
|
31 |
+
op = " ".join(stack[lastop[-1][0]:])
|
32 |
+
del stack[lastop[-1][0]:]
|
33 |
+
lastop.pop()
|
34 |
+
stack.append(f"( {op} )")
|
35 |
+
i += 1
|
36 |
+
assert(len(stack) == 1)
|
37 |
+
return stack[0]
|
38 |
+
|
39 |
+
def prefix_to_postfix(eqn):
|
40 |
+
if eqn[0].startswith("INT"):
|
41 |
+
intunit = [eqn[0]]
|
42 |
+
for i, val in enumerate(eqn[1:]):
|
43 |
+
if not isint(val):
|
44 |
+
break
|
45 |
+
intunit.append(val)
|
46 |
+
return intunit, eqn[i+1:]
|
47 |
+
elif eqn[0] in OPERATORS:
|
48 |
+
_, numops = OPERATORS[eqn[0]]
|
49 |
+
remeqn = eqn[1:]
|
50 |
+
ops = []
|
51 |
+
for i in range(numops):
|
52 |
+
op, remeqn = prefix_to_postfix(remeqn)
|
53 |
+
ops.extend(op)
|
54 |
+
ops.append(eqn[0]) # Restructured to postfix
|
55 |
+
return ops, remeqn
|
56 |
+
else:
|
57 |
+
return [eqn[0]], eqn[1:]
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
import argparse
|
61 |
+
parser = argparse.ArgumentParser("Change equation format from prefix to other")
|
62 |
+
parser.add_argument("--eqn", required=True)
|
63 |
+
parser.add_argument("--out", required=True)
|
64 |
+
args = parser.parse_args()
|
65 |
+
|
66 |
+
with open(args.eqn, "r") as inf, open(args.out, "w") as outf:
|
67 |
+
for eqn in inf:
|
68 |
+
postfix, _ = prefix_to_postfix(eqn.strip().split(" "))
|
69 |
+
outf.write(" ".join(postfix) + "\n")
|
70 |
+
|
71 |
+
# eqn = "div mul x add INT+ 5 add mul INT+ 3 x mul pow x INT+ 2 add INT- 5 add mul INT- 3 x mul x mul add INT+ 1 mul k0 pow x INT+ 3 add INT+ 4 x add INT+ 5 mul INT+ 3 x"
|
72 |
+
# eqn = "div add mul INT+ 3 x pow x INT- 4 mul sub x k0 add mul INT+ 5 x k1"
|
73 |
+
# print(" ".join(prefix_to_postfix(eqn.split(" "))[0]))
|
74 |
+
# postfix = "x INT+ 3 mul INT+ 5 add x INT+ 4 add INT+ 3 x pow k0 mul INT+ 1 add mul x mul x INT- 3 mul add INT- 5 add INT+ 2 x pow mul x INT+ 3 mul add INT+ 5 add x mul div"
|
75 |
+
# print(prefix_to_brackets(eqn.split(" ")))
|
76 |
+
|
77 |
+
# (div (mul x (add INT+ 5 (add (mul INT+ 3 x) (mul (pow x INT+ 2) (add INT- 5 (add (mul INT- 3 x) (mul x (mul (add INT+ 1 (mul k0 (pow x INT+ 3))) (add INT+ 4 x))))))))) (add INT+ 5 (mul INT+ 3 x)))
|
78 |
+
# ( div ( mul x ( add INT+ 5 ( add ( mul INT+ 3 x ) ( mul ( pow x INT+ 2 ) ( add INT- 5 ( add ( mul INT- 3 x ) ( mul x ( mul ( add INT+ 1 ( mul k0 ( pow x INT+ 3 ) ) ) ( add INT+ 4 x ) ) ) ) ) ) ) ) ) ( add INT+ 5 ( mul INT+ 3 x ) ) )
|
79 |
+
|
remend/check_generated.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sympy as sp
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
from tqdm import tqdm
|
5 |
+
from Levenshtein import distance
|
6 |
+
import networkx as nx
|
7 |
+
from networkx import graph_edit_distance
|
8 |
+
|
9 |
+
from .parser import parse_prefix_to_sympy, parse_postfix_to_sympy, isint
|
10 |
+
|
11 |
+
def percent(a, n):
|
12 |
+
return f"{a/n*100:0.1f}%"
|
13 |
+
|
14 |
+
def do_simplify_match(orig_expr, gen_expr):
|
15 |
+
orig_simp = sp.simplify(orig_expr)
|
16 |
+
gen_simp = sp.simplify(gen_expr)
|
17 |
+
if orig_simp == gen_simp:
|
18 |
+
return True
|
19 |
+
return False
|
20 |
+
|
21 |
+
def do_structure_match(orig_toks, gen_toks):
|
22 |
+
def _isconst(t):
|
23 |
+
return re.match(r"c[0-9]+", t)
|
24 |
+
def _isvar(t):
|
25 |
+
return re.match(r"x[0-9]+", t)
|
26 |
+
if len(orig_toks) != len(gen_toks):
|
27 |
+
return False
|
28 |
+
for orig, gen in zip(orig_toks, gen_toks):
|
29 |
+
if (_isconst(orig) and _isconst(gen)) \
|
30 |
+
or (_isvar(orig) and _isvar(gen)) \
|
31 |
+
or (isint(orig) and isint(gen)) \
|
32 |
+
or (orig.startswith("INT") and gen.startswith("INT")) \
|
33 |
+
or (orig == gen):
|
34 |
+
continue
|
35 |
+
# Mismatched
|
36 |
+
return False
|
37 |
+
return True
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
import argparse
|
41 |
+
parser = argparse.ArgumentParser("Check generated expressions")
|
42 |
+
parser.add_argument("-g", required=True, help="Generated expressions file")
|
43 |
+
parser.add_argument("-r", required=True, help="Results file")
|
44 |
+
parser.add_argument("--simplify", action="store_true", default=False)
|
45 |
+
parser.add_argument("--postfix", action="store_true", default=False)
|
46 |
+
args = parser.parse_args()
|
47 |
+
|
48 |
+
|
49 |
+
orig_list = []
|
50 |
+
gen_list = []
|
51 |
+
with open(args.g, 'r') as f:
|
52 |
+
for line in tqdm(f, desc="Reading file"):
|
53 |
+
comps = line.strip().split("\t")
|
54 |
+
if line[0] == 'T':
|
55 |
+
num = int(comps[0][2:])
|
56 |
+
tokens = comps[1].split(" ")
|
57 |
+
orig_list.append((num, tokens))
|
58 |
+
elif line[0] == 'H':
|
59 |
+
num = int(comps[0][2:])
|
60 |
+
tokens = comps[2].split(" ")
|
61 |
+
gen_list.append((num, tokens))
|
62 |
+
|
63 |
+
N = len(orig_list)
|
64 |
+
gen_errors = []
|
65 |
+
parsed = []
|
66 |
+
exact_match = []
|
67 |
+
structure_match = []
|
68 |
+
simplify_match = []
|
69 |
+
|
70 |
+
orig_exprs = {}
|
71 |
+
gen_exprs = {}
|
72 |
+
|
73 |
+
all_aed = []
|
74 |
+
# all_ged = []
|
75 |
+
|
76 |
+
results = []
|
77 |
+
|
78 |
+
for (orig_num, orig_toks), (gen_num, gen_toks) in tqdm(zip(orig_list, gen_list), desc="Parsing expressions", total=N):
|
79 |
+
assert orig_num == gen_num
|
80 |
+
aed = distance(orig_toks, gen_toks) / (len(orig_toks) + len(gen_toks))
|
81 |
+
all_aed.append(aed)
|
82 |
+
res = {"id": gen_num, "aed": aed, "matched": False, "parsed": False}
|
83 |
+
|
84 |
+
if aed == 0:
|
85 |
+
parsed.append(orig_num)
|
86 |
+
exact_match.append(orig_num)
|
87 |
+
structure_match.append(orig_num)
|
88 |
+
res["parsed"] = True
|
89 |
+
res["matched"] = "Exact"
|
90 |
+
results.append(res)
|
91 |
+
continue
|
92 |
+
|
93 |
+
if do_structure_match(orig_toks, gen_toks):
|
94 |
+
structure_match.append(orig_num)
|
95 |
+
res["matched"] = "Structure"
|
96 |
+
|
97 |
+
if "<<unk>>" in orig_toks:
|
98 |
+
# Why this happened?
|
99 |
+
res["parsed"] = False
|
100 |
+
res["matched"] = False
|
101 |
+
results.append(res)
|
102 |
+
continue
|
103 |
+
|
104 |
+
if args.postfix:
|
105 |
+
orig_expr = parse_postfix_to_sympy(orig_toks)
|
106 |
+
else:
|
107 |
+
orig_expr = parse_prefix_to_sympy(orig_toks)
|
108 |
+
try:
|
109 |
+
if args.postfix:
|
110 |
+
gen_expr = parse_postfix_to_sympy(gen_toks)
|
111 |
+
else:
|
112 |
+
gen_expr = parse_prefix_to_sympy(gen_toks)
|
113 |
+
res["parsed"] = True
|
114 |
+
except: # Exception as e:
|
115 |
+
gen_errors.append(gen_num)
|
116 |
+
results.append(res)
|
117 |
+
continue
|
118 |
+
|
119 |
+
parsed.append(gen_num)
|
120 |
+
orig_exprs[gen_num] = orig_expr
|
121 |
+
gen_exprs[gen_num] = gen_expr
|
122 |
+
|
123 |
+
if orig_expr == gen_expr:
|
124 |
+
exact_match.append(gen_num)
|
125 |
+
res["matched"] = "Exact"
|
126 |
+
elif args.simplify and do_simplify_match(orig_expr, gen_expr):
|
127 |
+
simplify_match.append(gen_num)
|
128 |
+
res["matched"] = "Simplify"
|
129 |
+
results.append(res)
|
130 |
+
|
131 |
+
with open(args.r, "w") as resf:
|
132 |
+
for res in results:
|
133 |
+
resf.write("{id} {aed} {parsed} {matched}\n".format(**res))
|
134 |
+
resf.write("\n")
|
135 |
+
print("Total", N, file=resf)
|
136 |
+
print("Parse error", len(gen_errors), percent(len(gen_errors), N), file=resf)
|
137 |
+
print("Exact match", len(exact_match), percent(len(exact_match), N), file=resf)
|
138 |
+
print("Structure match", len(structure_match), percent(len(structure_match), N), file=resf)
|
139 |
+
if args.simplify:
|
140 |
+
print("Simplify match", len(simplify_match), percent(len(simplify_match), N), file=resf)
|
141 |
+
print("Avg SED", sum(all_aed) / len(all_aed), max(all_aed), file=resf)
|
142 |
+
# print("Avg GED", sum(all_ged) / len(all_ged), max(all_ged), file=resf)
|
143 |
+
|
remend/compile_dataset.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
import random
|
3 |
+
import sympy as sp
|
4 |
+
import json
|
5 |
+
import subprocess as sproc
|
6 |
+
from os.path import realpath, dirname, join as pjoin
|
7 |
+
from os import makedirs
|
8 |
+
import multiprocessing as mp
|
9 |
+
from time import sleep
|
10 |
+
import logging
|
11 |
+
|
12 |
+
from .implementation import Implementor
|
13 |
+
from .parser import parse_prefix_to_sympy, sympy_to_prefix, constant_fold
|
14 |
+
from .disassemble import DisassemblerARM32, DisassemblerAArch64, DisassemblerX64
|
15 |
+
from .util import DecodeError, timeout, sympy_expr_ok
|
16 |
+
|
17 |
+
SCRIPT = pjoin(dirname(realpath(__file__)), "compile_eqn.sh")
|
18 |
+
|
19 |
+
QUEUE_END = "QUEUE_END_SENTINEL"
|
20 |
+
|
21 |
+
def compile_c(code, elf, arch="arm32", src="/tmp/myfunc.c", opt=0):
|
22 |
+
with open(src, "w") as f:
|
23 |
+
f.write(code)
|
24 |
+
ret = sproc.run(["bash", "-e", SCRIPT, arch+"-c", src, elf, f"-O{opt}"], capture_output=True)
|
25 |
+
if ret.returncode != 0:
|
26 |
+
raise DecodeError("compile failed")
|
27 |
+
|
28 |
+
def compile_fortran(code, elf, arch="arm32", src="/tmp/myfunc.f95", opt=0):
|
29 |
+
with open(src, "w") as f:
|
30 |
+
f.write(code)
|
31 |
+
ret = sproc.run(["bash", "-e", SCRIPT, arch+"-fortran", src, elf, f"-O{opt}"], capture_output=True)
|
32 |
+
if ret.returncode != 0:
|
33 |
+
raise DecodeError("compile failed")
|
34 |
+
|
35 |
+
class EquationCompiler:
|
36 |
+
def __init__(self, q, arch, impl, opt, outdir, prefix, dtype="double"):
|
37 |
+
if "fortran" in impl:
|
38 |
+
self.compiler = compile_fortran
|
39 |
+
else:
|
40 |
+
self.compiler = compile_c
|
41 |
+
|
42 |
+
if arch == "arm32":
|
43 |
+
self.disassembler = DisassemblerARM32
|
44 |
+
elif arch == "aarch64":
|
45 |
+
self.disassembler = DisassemblerAArch64
|
46 |
+
elif arch == "x64":
|
47 |
+
self.disassembler = DisassemblerX64
|
48 |
+
else:
|
49 |
+
raise DecodeError("arch not supported: " + arch)
|
50 |
+
|
51 |
+
self.q = q
|
52 |
+
self.impl = impl
|
53 |
+
self.opt = opt
|
54 |
+
self.outdir = outdir
|
55 |
+
self.prefix = prefix
|
56 |
+
self.dtype = dtype
|
57 |
+
self.arch = arch
|
58 |
+
|
59 |
+
def run(self):
|
60 |
+
outdir = pjoin(self.outdir, f"O{self.opt}", self.impl)
|
61 |
+
makedirs(outdir, exist_ok=True)
|
62 |
+
outfiles = {
|
63 |
+
"asm": open(pjoin(outdir, self.prefix + ".asm"), "w"),
|
64 |
+
"eqn": open(pjoin(outdir, self.prefix + ".eqn"), "w"),
|
65 |
+
"src": open(pjoin(outdir, self.prefix + ".src"), "w"),
|
66 |
+
"const": open(pjoin(outdir, self.prefix + ".const.jsonl"), "w"),
|
67 |
+
"err": open(pjoin(outdir, self.prefix + ".error"), "w")
|
68 |
+
}
|
69 |
+
l = 0
|
70 |
+
tmpsrc = f"/tmp/myfunc_{self.impl}_{self.opt}_{self.prefix}"
|
71 |
+
if "fortran" in self.impl:
|
72 |
+
tmpsrc += ".f95"
|
73 |
+
func = "myfunc_"
|
74 |
+
else:
|
75 |
+
tmpsrc += ".c"
|
76 |
+
func = "myfunc"
|
77 |
+
tmpelf = f"/tmp/myfunc_{self.arch}_{self.impl}_{self.opt}_{self.prefix}.elf"
|
78 |
+
|
79 |
+
while True:
|
80 |
+
data = self.q.get()
|
81 |
+
if data == QUEUE_END:
|
82 |
+
# Queue is closed, break from inf loop
|
83 |
+
break
|
84 |
+
n, expr, expr_const, pref = data
|
85 |
+
impl = Implementor(expr, constants=expr_const, dtype=self.dtype)
|
86 |
+
try:
|
87 |
+
code = impl.implement(self.impl)
|
88 |
+
self.compiler(code, tmpelf, arch=self.arch, src=tmpsrc, opt=self.opt)
|
89 |
+
disasm = self.disassembler(tmpelf, expr_constants=expr_const,
|
90 |
+
match_constants=True)
|
91 |
+
asm = disasm.disassemble(func)
|
92 |
+
if len(disasm.constants) < len(expr_const):
|
93 |
+
print(n, "constants not identified", disasm.constants, expr_const,
|
94 |
+
file=outfiles["err"])
|
95 |
+
continue
|
96 |
+
except DecodeError as e:
|
97 |
+
print(n, "impl error", e, expr, expr_const, pref, file=outfiles["err"])
|
98 |
+
continue
|
99 |
+
|
100 |
+
outfiles["asm"].write(asm + "\n")
|
101 |
+
outfiles["eqn"].write(pref + "\n")
|
102 |
+
outfiles["src"].write(f"==== pick={n} line={l} ====\n" + code + "\n")
|
103 |
+
outfiles["const"].write(json.dumps(expr_const) + "\n")
|
104 |
+
l += 1
|
105 |
+
|
106 |
+
for f in outfiles:
|
107 |
+
outfiles[f].close()
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
import argparse
|
112 |
+
parser = argparse.ArgumentParser("Compile prefix to asm->eqn dataset")
|
113 |
+
parser.add_argument("-f", "--file", required=True, help="Input file")
|
114 |
+
parser.add_argument("--outdir", required=True, help="Output directory")
|
115 |
+
parser.add_argument("--prefix", required=True, help="File prefix")
|
116 |
+
parser.add_argument("--impl", nargs="+", required=True,
|
117 |
+
choices=["dag_c", "cse_c", "dag_fortran", "cse_fortran"])
|
118 |
+
parser.add_argument("--pick", type=float, required=True,
|
119 |
+
help="Ratio of samples to pick (0 to 1)")
|
120 |
+
parser.add_argument("--start", type=int, default=0, help="Start from index")
|
121 |
+
parser.add_argument("--count", type=int, default=0, help="Process only these many")
|
122 |
+
parser.add_argument("--seed", type=int, default=1225)
|
123 |
+
parser.add_argument("--min-tokens", help="Minimum tokens in equations", type=int, default=5)
|
124 |
+
parser.add_argument("--min-ops", help="Minimum ops in equations", type=int, default=5)
|
125 |
+
parser.add_argument("--dtype", help="Implementation datatype", type=str,
|
126 |
+
choices=["double", "float"], default="double")
|
127 |
+
parser.add_argument("--arch", help="Target architecture", type=str,
|
128 |
+
choices=["arm32", "aarch64", "x64"], default="arm32")
|
129 |
+
parser.add_argument("-O", "--opt", nargs="+", type=int, choices=[0, 1, 2, 3], default=[0],
|
130 |
+
help="Optimization level (s)")
|
131 |
+
|
132 |
+
# Dont show warnings
|
133 |
+
logging.getLogger("cle").setLevel(logging.ERROR)
|
134 |
+
|
135 |
+
args = parser.parse_args()
|
136 |
+
random.seed(args.seed)
|
137 |
+
|
138 |
+
eqcompilers = [EquationCompiler(mp.Queue(), args.arch, impl, opt, args.outdir, args.prefix, dtype=args.dtype)
|
139 |
+
for impl in args.impl
|
140 |
+
for opt in args.opt]
|
141 |
+
pool = [mp.Process(target=eqc.run, args=()) for eqc in eqcompilers]
|
142 |
+
for proc in pool:
|
143 |
+
proc.start()
|
144 |
+
|
145 |
+
count = 0
|
146 |
+
prefixf = open(args.file, "r")
|
147 |
+
for n, line in tqdm(enumerate(prefixf), desc="Parsing file"):
|
148 |
+
# Skip for start lines and with some probability
|
149 |
+
if n < args.start or random.random() > args.pick:
|
150 |
+
continue
|
151 |
+
comps = line.strip().split("\t")
|
152 |
+
pref = comps[0][comps[0].find("Y'")+3:]
|
153 |
+
prefl = pref.split(" ")
|
154 |
+
# pref = comps[1].split(" ")
|
155 |
+
if len(prefl) < args.min_tokens:
|
156 |
+
continue
|
157 |
+
try:
|
158 |
+
expr = parse_prefix_to_sympy(prefl)
|
159 |
+
with timeout(10):
|
160 |
+
expr = sp.simplify(expr)
|
161 |
+
if not sympy_expr_ok(expr):
|
162 |
+
# Simplified is bad
|
163 |
+
continue
|
164 |
+
expr, expr_const = constant_fold(expr)
|
165 |
+
pref = " ".join(sympy_to_prefix(expr))
|
166 |
+
except:
|
167 |
+
continue
|
168 |
+
|
169 |
+
if sp.count_ops(expr) < args.min_ops:
|
170 |
+
continue
|
171 |
+
|
172 |
+
for eqc in eqcompilers:
|
173 |
+
# Poll on this queue to get empty
|
174 |
+
while eqc.q.qsize() > 5:
|
175 |
+
sleep(1)
|
176 |
+
eqc.q.put((n, expr, expr_const, pref))
|
177 |
+
count += 1
|
178 |
+
if args.count > 0 and count >= args.count:
|
179 |
+
break
|
180 |
+
|
181 |
+
# Close queues
|
182 |
+
for eqc in eqcompilers:
|
183 |
+
eqc.q.put(QUEUE_END)
|
184 |
+
for proc in pool:
|
185 |
+
proc.join()
|
remend/compile_eqn.sh
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!bin/bash
|
2 |
+
|
3 |
+
MODE=$1
|
4 |
+
SRC=$2
|
5 |
+
ELF=$3
|
6 |
+
OPT=$4
|
7 |
+
|
8 |
+
if [ ! -f "$SRC" ]
|
9 |
+
then
|
10 |
+
echo "Please provide source file"
|
11 |
+
exit 1
|
12 |
+
fi
|
13 |
+
|
14 |
+
if [ "$ELF" == "" ]
|
15 |
+
then
|
16 |
+
echo "Please provide elf file path"
|
17 |
+
exit 1
|
18 |
+
fi
|
19 |
+
|
20 |
+
|
21 |
+
if [ "$MODE" == "arm32-c" ]
|
22 |
+
then
|
23 |
+
arm-linux-gnueabihf-gcc $OPT $SRC -lm -o $ELF
|
24 |
+
elif [ "$MODE" == "arm32-fortran" ]
|
25 |
+
then
|
26 |
+
arm-linux-gnueabihf-gfortran -std=gnu $OPT $SRC -o $ELF
|
27 |
+
elif [ "$MODE" == "aarch64-c" ]
|
28 |
+
then
|
29 |
+
aarch64-linux-gnu-gcc $OPT $SRC -lm -o $ELF
|
30 |
+
elif [ "$MODE" == "aarch64-fortran" ]
|
31 |
+
then
|
32 |
+
aarch64-linux-gnu-gfortran -std=gnu $OPT $SRC -o $ELF
|
33 |
+
elif [ "$MODE" == "x64-c" ]
|
34 |
+
then
|
35 |
+
gcc $OPT $SRC -lm -o $ELF
|
36 |
+
elif [ "$MODE" == "x64-fortran" ]
|
37 |
+
then
|
38 |
+
gfortran -std=gnu $OPT $SRC -o $ELF
|
39 |
+
else
|
40 |
+
echo "Incorrect mode: $MODE. Choose from: {arm32,aarch64,x64}-{c,fortran}"
|
41 |
+
exit 1
|
42 |
+
fi
|
43 |
+
|
44 |
+
# arm-linux-gnueabihf-objdump --no-show-raw-insn --no-addresses -d $1.elf | sed -n -e 's/\s;\s.*$//' -e "/myfunc>:$/,/^$/p" | sed '1d;$d' | tr '\n' ' '
|
remend/convert_generated.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .parser import parse_prefix_to_sympy
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
import argparse
|
5 |
+
parser = argparse.ArgumentParser("Parse result prefix to equation")
|
6 |
+
parser.add_argument("--input", required=True, help="Input result file")
|
7 |
+
args = parser.parse_args()
|
8 |
+
|
9 |
+
res_list = []
|
10 |
+
|
11 |
+
with open(args.input, 'r') as f:
|
12 |
+
for line in f:
|
13 |
+
comps = line.strip().split("\t")
|
14 |
+
if line[0] == 'H':
|
15 |
+
num = int(comps[0][2:])
|
16 |
+
tokens = comps[2].split(" ")
|
17 |
+
res_list.append((num, tokens))
|
18 |
+
|
19 |
+
for n, toks in res_list:
|
20 |
+
try:
|
21 |
+
ex = parse_prefix_to_sympy(toks)
|
22 |
+
print(n, ex)
|
23 |
+
except Exception as e:
|
24 |
+
print(n, "could not parse:", str(e))
|
remend/deduplicate_split.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import random
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
def filter_poly(asm, eqn):
|
8 |
+
rejects = {"ln", "exp", "sin", "cos", "sqrt", "tan", "asin", "acos", "atan", "E", "pi", "cot"}
|
9 |
+
return any(t in rejects for t in asm.strip().split(" ")) \
|
10 |
+
or any(t in rejects for t in eqn.strip().split(" "))
|
11 |
+
|
12 |
+
def filter_bigint(asm, eqn):
|
13 |
+
if re.search(r"CONST=[0-9]{4,}", asm):
|
14 |
+
return True
|
15 |
+
return False
|
16 |
+
|
17 |
+
if __name__ == "__main__":
|
18 |
+
import argparse
|
19 |
+
parser = argparse.ArgumentParser("Deduplicate ASM and split files into train/test/valid")
|
20 |
+
parser.add_argument("--inprefix", required=True, help="Prefix of input files")
|
21 |
+
parser.add_argument("--outdir", required=True)
|
22 |
+
parser.add_argument("--split", type=float, default=0.05)
|
23 |
+
parser.add_argument("--seed", type=int, default=1225)
|
24 |
+
parser.add_argument("--filter", choices=["poly", "bigint"], default=None)
|
25 |
+
parser.add_argument("--no-separate-eqn", action="store_true")
|
26 |
+
|
27 |
+
args = parser.parse_args()
|
28 |
+
|
29 |
+
eq_mapped = {}
|
30 |
+
combined_ds = []
|
31 |
+
asm_hash = set()
|
32 |
+
removed = 0
|
33 |
+
|
34 |
+
with open(args.inprefix + ".asm", "r") as asmf, \
|
35 |
+
open(args.inprefix + ".eqn", "r") as eqnf, \
|
36 |
+
open(args.inprefix + ".const.jsonl", "r") as constf:
|
37 |
+
for i, (asm, eqn, const) in tqdm(enumerate(zip(asmf, eqnf, constf)),
|
38 |
+
desc="Read files", leave=False):
|
39 |
+
h = hash(asm)
|
40 |
+
if h in asm_hash:
|
41 |
+
# Skip this repeated line
|
42 |
+
removed += 1
|
43 |
+
continue
|
44 |
+
|
45 |
+
if re.search(r"[0-9]\.[0-9]", eqn):
|
46 |
+
# Float not represented, remove
|
47 |
+
removed += 1
|
48 |
+
continue
|
49 |
+
|
50 |
+
if args.filter == "poly" and filter_poly(asm, eqn):
|
51 |
+
removed += 1
|
52 |
+
continue
|
53 |
+
if args.filter == "bigint" and filter_bigint(asm, eqn):
|
54 |
+
removed += 1
|
55 |
+
continue
|
56 |
+
|
57 |
+
asm_hash.add(h)
|
58 |
+
if args.no_separate_eqn:
|
59 |
+
combined_ds.append((i, asm, eqn, const))
|
60 |
+
else:
|
61 |
+
if eqn not in eq_mapped:
|
62 |
+
eq_mapped[eqn] = []
|
63 |
+
eq_mapped[eqn].append((i, asm, const))
|
64 |
+
|
65 |
+
print("Removed", removed)
|
66 |
+
|
67 |
+
if args.no_separate_eqn:
|
68 |
+
dataset = combined_ds
|
69 |
+
else:
|
70 |
+
dataset = list(eq_mapped.keys())
|
71 |
+
|
72 |
+
random.seed(args.seed)
|
73 |
+
random.shuffle(dataset)
|
74 |
+
|
75 |
+
N = len(dataset)
|
76 |
+
Ntest = int(N * args.split)
|
77 |
+
|
78 |
+
splits = {
|
79 |
+
"train": dataset[:N-2*Ntest],
|
80 |
+
"valid": dataset[N-2*Ntest:N-Ntest],
|
81 |
+
"test": dataset[N-Ntest:]
|
82 |
+
}
|
83 |
+
splitidxs = {s: [] for s in splits}
|
84 |
+
|
85 |
+
idxf = open(os.path.join(args.outdir, "splits.txt"), "w")
|
86 |
+
for s in splits:
|
87 |
+
asmfn = os.path.join(args.outdir, f"{s}.asm")
|
88 |
+
eqnfn = os.path.join(args.outdir, f"{s}.eqn")
|
89 |
+
constfn = os.path.join(args.outdir, f"{s}.const.jsonl")
|
90 |
+
with open(asmfn, "w") as asmf, open(eqnfn, "w") as eqnf, \
|
91 |
+
open(constfn, "w") as constf:
|
92 |
+
if args.no_separate_eqn:
|
93 |
+
for i, asm, eqn, const in splits[s]:
|
94 |
+
asmf.write(asm)
|
95 |
+
eqnf.write(eqn)
|
96 |
+
constf.write(const)
|
97 |
+
splitidxs[s].append(i)
|
98 |
+
else:
|
99 |
+
for eqn in splits[s]:
|
100 |
+
for i, asm, const in eq_mapped[eqn]:
|
101 |
+
asmf.write(asm)
|
102 |
+
eqnf.write(eqn)
|
103 |
+
constf.write(const)
|
104 |
+
splitidxs[s].append(i)
|
105 |
+
print("Split", s, len(splitidxs[s]))
|
106 |
+
idxf.write(f"==== {s} ====\n")
|
107 |
+
for j, i in enumerate(splitidxs[s]):
|
108 |
+
idxf.write(f"{j}: {i}\n")
|
109 |
+
idxf.write("\n")
|
110 |
+
idxf.close()
|
111 |
+
|
remend/disassemble.py
ADDED
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from capstone import *
|
2 |
+
from capstone.arm import *
|
3 |
+
from capstone.arm64 import *
|
4 |
+
from capstone.x86 import *
|
5 |
+
import cle
|
6 |
+
import struct
|
7 |
+
from math import e as CONST_E, pi as CONST_PI
|
8 |
+
import sympy as sp
|
9 |
+
|
10 |
+
from .util import DecodeError
|
11 |
+
|
12 |
+
def int2fp32(v):
|
13 |
+
if type(v) == int:
|
14 |
+
v = struct.unpack("<f", v.to_bytes(4, "little"))
|
15 |
+
v = v[0]
|
16 |
+
return v
|
17 |
+
def int2fp64(v):
|
18 |
+
if type(v) == int:
|
19 |
+
v = struct.unpack("<d", v.to_bytes(8, "little"))
|
20 |
+
v = v[0]
|
21 |
+
return v
|
22 |
+
|
23 |
+
def align4(v):
|
24 |
+
return v & (0xFFFFFFFC)
|
25 |
+
|
26 |
+
class DisassemblerBase:
|
27 |
+
def __init__(self, expr_constants={}, match_constants=False):
|
28 |
+
self.loader = None # Load in child class
|
29 |
+
self.reg_values = {}
|
30 |
+
self.constidx = 0
|
31 |
+
self.constants = {}
|
32 |
+
self.constaddrs = set()
|
33 |
+
self.expr_constants = expr_constants
|
34 |
+
self.match_constants = match_constants
|
35 |
+
|
36 |
+
def get_function_bytes(self, funcname):
|
37 |
+
func = self.loader.find_symbol(funcname)
|
38 |
+
if not func:
|
39 |
+
raise DecodeError(f"Function {funcname} not found in binary")
|
40 |
+
faddr = func.rebased_addr
|
41 |
+
if (not isinstance(self, DisassemblerX64)) and faddr % 2 == 1:
|
42 |
+
# Unaligned address, aligning
|
43 |
+
faddr = faddr - 1
|
44 |
+
fbytes = self.loader.memory.load(faddr, func.size)
|
45 |
+
self.funcrange = faddr, faddr + func.size
|
46 |
+
return faddr, fbytes
|
47 |
+
|
48 |
+
def find_constant(self, constants, value):
|
49 |
+
for ec in constants:
|
50 |
+
if abs(value - constants[ec]) < 1e-5:
|
51 |
+
return ec, ""
|
52 |
+
elif abs(1/value - constants[ec]) < 1e-5:
|
53 |
+
return ec, "1/"
|
54 |
+
elif abs(-value - constants[ec]) < 1e-5:
|
55 |
+
return ec, "-"
|
56 |
+
elif abs(-1/value - constants[ec]) < 1e-5:
|
57 |
+
return ec, "-1/"
|
58 |
+
return False
|
59 |
+
|
60 |
+
def add_constant(self, value, addr=0, size=0):
|
61 |
+
# Don't map known constants like e, pi, 0
|
62 |
+
if value == 0:
|
63 |
+
cname = "CONST=0"
|
64 |
+
elif abs(value - CONST_E) < 1e-7:
|
65 |
+
cname = "CONST=E"
|
66 |
+
elif abs(value - CONST_PI) < 1e-7:
|
67 |
+
cname = "CONST=pi"
|
68 |
+
elif self.match_constants and \
|
69 |
+
(ecmatch := self.find_constant(self.expr_constants, value)):
|
70 |
+
# Gives the name and expression of the matched constant
|
71 |
+
ecname, ecxpr = ecmatch
|
72 |
+
# print(value, ecname, ecxpr, self.expr_constants[ecname])
|
73 |
+
cname = f"{ecxpr}CSYM{ecname[1:]}"
|
74 |
+
self.constants[ecname] = value
|
75 |
+
elif size > 0 and addr in self.constaddrs and \
|
76 |
+
(smatch := self.find_constant(self.constants, value)):
|
77 |
+
sname, sxpr = smatch
|
78 |
+
cname = f"{sxpr}CSYM{sname}"
|
79 |
+
else:
|
80 |
+
rep = sp.nsimplify(value, [sp.E, sp.pi], tolerance=1e-7)
|
81 |
+
if isinstance(rep, sp.Integer) or \
|
82 |
+
(isinstance(rep, sp.Rational) and rep.q <= 16):
|
83 |
+
cname = f"CONST={rep}"
|
84 |
+
elif not self.match_constants:
|
85 |
+
cname = f"CSYM{self.constidx}"
|
86 |
+
self.constants[self.constidx] = value
|
87 |
+
self.constidx += 1
|
88 |
+
else:
|
89 |
+
raise DecodeError(f"Cannot represent unmatched float {value}")
|
90 |
+
|
91 |
+
if size > 0:
|
92 |
+
self.constaddrs |= {addr+i for i in range(size)}
|
93 |
+
return cname
|
94 |
+
|
95 |
+
def disassemble(self, function):
|
96 |
+
raise NotImplementedError("Call disassemble on child classes, not base")
|
97 |
+
|
98 |
+
|
99 |
+
class DisassemblerARM32(DisassemblerBase):
|
100 |
+
def __init__(self, binpath, expr_constants={}, match_constants=False):
|
101 |
+
super().__init__(expr_constants=expr_constants, match_constants=match_constants)
|
102 |
+
self.md = Cs(CS_ARCH_ARM, CS_MODE_THUMB)
|
103 |
+
self.md.detail = True
|
104 |
+
self.loader = cle.Loader(binpath)
|
105 |
+
|
106 |
+
def check_mov_imm(self, insn):
|
107 |
+
if insn.id not in {ARM_INS_MOV, ARM_INS_MOVW,
|
108 |
+
ARM_INS_MOVT, ARM_INS_ADR}:
|
109 |
+
return False
|
110 |
+
ops = list(insn.operands)
|
111 |
+
if len(ops) != 2:
|
112 |
+
return False
|
113 |
+
if ops[0].type != ARM_OP_REG or ops[1].type != ARM_OP_IMM:
|
114 |
+
return False
|
115 |
+
imm = ops[1].value.imm
|
116 |
+
if imm < 0:
|
117 |
+
imm = 2**32 + imm # 2's complement
|
118 |
+
if insn.id == ARM_INS_ADR:
|
119 |
+
# Add PC value
|
120 |
+
imm += insn.address + 4
|
121 |
+
return ops[0].value.reg, imm
|
122 |
+
|
123 |
+
def check_float_store(self, insn):
|
124 |
+
if insn.id not in {ARM_INS_STR, ARM_INS_STRD}:
|
125 |
+
return False
|
126 |
+
ops = list(insn.operands)
|
127 |
+
if insn.id == ARM_INS_STRD:
|
128 |
+
dest = ops[0].value.reg
|
129 |
+
dest2 = ops[1].value.reg
|
130 |
+
if dest not in self.reg_values or dest2 not in self.reg_values:
|
131 |
+
return False
|
132 |
+
fval = int2fp64((self.reg_values[dest2]<<32) + self.reg_values[dest])
|
133 |
+
else:
|
134 |
+
dest = ops[0].value.reg
|
135 |
+
if dest not in self.reg_values:
|
136 |
+
return False
|
137 |
+
fval = int2fp32(self.reg_values[dest])
|
138 |
+
if abs(fval) < 1e-3 or abs(fval) > 100:
|
139 |
+
return False
|
140 |
+
return fval
|
141 |
+
|
142 |
+
def check_ldrd(self, insn):
|
143 |
+
if insn.id != ARM_INS_LDRD:
|
144 |
+
return False
|
145 |
+
ops = insn.op_str.split(", ")
|
146 |
+
if len(ops) != 3:
|
147 |
+
return False
|
148 |
+
mem = ops[2] # format: [<reg> + #<offset>]
|
149 |
+
if mem[0] != "[" or mem[-1] != "]":
|
150 |
+
return False
|
151 |
+
memcomps = mem[1:-1].split(" ")
|
152 |
+
if memcomps[0] == "pc":
|
153 |
+
base = align4(insn.address + 4)
|
154 |
+
else:
|
155 |
+
basereg = ARM_REG_R0 + int(memcomps[0][1:]) # Shitty hack, may malfunction
|
156 |
+
if basereg not in self.reg_values:
|
157 |
+
return False
|
158 |
+
base = align4(self.reg_values[basereg])
|
159 |
+
if len(memcomps) == 3:
|
160 |
+
offset = int(memcomps[2][1:])
|
161 |
+
else:
|
162 |
+
offset = 0
|
163 |
+
addr = base + offset
|
164 |
+
fhex = self.loader.memory.load(addr, 8)
|
165 |
+
fval = struct.unpack("d", fhex)[0]
|
166 |
+
return fval, addr, 8
|
167 |
+
|
168 |
+
def check_vldr(self, insn):
|
169 |
+
if insn.id != ARM_INS_VLDR:
|
170 |
+
return False
|
171 |
+
ops = list(insn.operands)
|
172 |
+
dest = ops[0]
|
173 |
+
if ops[1].type != ARM_OP_MEM:
|
174 |
+
return False
|
175 |
+
mem = ops[1].value.mem
|
176 |
+
if mem.base == ARM_REG_PC:
|
177 |
+
# Align4(PC) + Imm
|
178 |
+
# For whatever reason, in Thumb PC=addr+4
|
179 |
+
addr = align4(insn.address + 4) + mem.disp
|
180 |
+
elif mem.base in self.reg_values:
|
181 |
+
addr = align4(self.reg_values[mem.base]) + mem.disp
|
182 |
+
else:
|
183 |
+
return False
|
184 |
+
if addr < self.loader.min_addr or addr + 8 > self.loader.max_addr:
|
185 |
+
# Out of bounds
|
186 |
+
return False
|
187 |
+
if dest.value.reg >= ARM_REG_D0 and dest.value.reg <= ARM_REG_D31:
|
188 |
+
size = 8
|
189 |
+
fhex = self.loader.memory.load(addr, 8)
|
190 |
+
fval = struct.unpack("d", fhex)[0]
|
191 |
+
else:
|
192 |
+
size = 4
|
193 |
+
fhex = self.loader.memory.load(addr, 4)
|
194 |
+
fval = struct.unpack("f", fhex)[0]
|
195 |
+
return fval, addr, size
|
196 |
+
|
197 |
+
def check_vmov(self, insn):
|
198 |
+
# fconsts/d == vmov.f32/f64 (old/new names)
|
199 |
+
if insn.id not in {ARM_INS_FCONSTS, ARM_INS_FCONSTD}:
|
200 |
+
return False
|
201 |
+
ops = list(insn.operands)
|
202 |
+
if len(ops) != 2 or ops[1].type != ARM_OP_FP:
|
203 |
+
return False
|
204 |
+
fval = ops[1].value.fp
|
205 |
+
destname = insn.reg_name(ops[0].value.reg)
|
206 |
+
asm = f"{insn.mnemonic} {destname}, {fval}"
|
207 |
+
return asm, fval
|
208 |
+
|
209 |
+
def check_branch_symbol(self, insn):
|
210 |
+
if insn.id not in {ARM_INS_B, ARM_INS_BL, ARM_INS_BLX}:
|
211 |
+
return False
|
212 |
+
ops = list(insn.operands)
|
213 |
+
if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
|
214 |
+
return False
|
215 |
+
addr = ops[0].value.imm
|
216 |
+
if addr > self.funcrange[0] and addr < self.funcrange[1]:
|
217 |
+
# Self-branch
|
218 |
+
func = f"SELF+{hex(addr - self.funcrange[0])}"
|
219 |
+
else:
|
220 |
+
func = self.loader.find_plt_stub_name(addr)
|
221 |
+
if func is None:
|
222 |
+
# Some tail call optimized PLT stubs have extra instructions
|
223 |
+
# that are not identified by CLE, so check with offset of 4 also.
|
224 |
+
func = self.loader.find_plt_stub_name(addr + 4)
|
225 |
+
if func is None:
|
226 |
+
return False
|
227 |
+
asm = f"{insn.mnemonic} <{func}>"
|
228 |
+
return asm
|
229 |
+
|
230 |
+
def get_function_bytes(self, funcname):
|
231 |
+
func = self.loader.find_symbol(funcname)
|
232 |
+
if not func:
|
233 |
+
raise DecodeError(f"Function {funcname} not found in binary")
|
234 |
+
faddr = func.rebased_addr
|
235 |
+
if faddr % 2 == 1:
|
236 |
+
# Unaligned address, aligning
|
237 |
+
faddr = faddr - 1
|
238 |
+
fbytes = self.loader.memory.load(faddr, func.size)
|
239 |
+
self.funcrange = faddr, faddr + func.size
|
240 |
+
return faddr, fbytes
|
241 |
+
|
242 |
+
def disassemble(self, funcname):
|
243 |
+
funcaddr, funcbytes = self.get_function_bytes(funcname)
|
244 |
+
disassm = []
|
245 |
+
|
246 |
+
for insn in self.md.disasm(funcbytes, funcaddr):
|
247 |
+
if insn.address in self.constaddrs:
|
248 |
+
# Skip if this is a constant value and not instruction
|
249 |
+
continue
|
250 |
+
|
251 |
+
cname = None
|
252 |
+
asm = None
|
253 |
+
|
254 |
+
if vldr := self.check_vldr(insn):
|
255 |
+
fval, faddr, fsize = vldr
|
256 |
+
cname = self.add_constant(fval, faddr, fsize)
|
257 |
+
elif ldrd := self.check_ldrd(insn):
|
258 |
+
fval, faddr, fsize = ldrd
|
259 |
+
cname = self.add_constant(fval, faddr, fsize)
|
260 |
+
elif strfloat := self.check_float_store(insn):
|
261 |
+
fval = strfloat
|
262 |
+
cname = self.add_constant(fval)
|
263 |
+
elif vmovfloat := self.check_vmov(insn):
|
264 |
+
asm, fval = vmovfloat
|
265 |
+
cname = self.add_constant(fval)
|
266 |
+
elif branch := self.check_branch_symbol(insn):
|
267 |
+
asm = branch
|
268 |
+
|
269 |
+
# Maintain values of immediate moves.
|
270 |
+
# Needs to be done after processing current instruction.
|
271 |
+
if movimm := self.check_mov_imm(insn):
|
272 |
+
reg, imm = movimm
|
273 |
+
if insn.id == ARM_INS_MOVT:
|
274 |
+
if reg not in self.reg_values:
|
275 |
+
self.reg_values[reg] = 0
|
276 |
+
self.reg_values[reg] += imm << 16
|
277 |
+
else:
|
278 |
+
self.reg_values[reg] = imm
|
279 |
+
else:
|
280 |
+
reads, writes = insn.regs_access()
|
281 |
+
for r in writes:
|
282 |
+
# Remove this reg if written to
|
283 |
+
if r in self.reg_values:
|
284 |
+
del self.reg_values[r]
|
285 |
+
|
286 |
+
if not asm:
|
287 |
+
asm = f"{insn.mnemonic} {insn.op_str}"
|
288 |
+
if cname:
|
289 |
+
asm += f", {cname}"
|
290 |
+
disassm.append(asm)
|
291 |
+
|
292 |
+
fulldiss = "; ".join(disassm)
|
293 |
+
return fulldiss
|
294 |
+
|
295 |
+
class DisassemblerAArch64(DisassemblerBase):
|
296 |
+
def __init__(self, binpath, expr_constants={}, match_constants=False):
|
297 |
+
super().__init__(expr_constants=expr_constants, match_constants=match_constants)
|
298 |
+
self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM)
|
299 |
+
self.md.detail = True
|
300 |
+
self.loader = cle.Loader(binpath)
|
301 |
+
|
302 |
+
def reg_size_type(self, reg):
|
303 |
+
# Bit width and datatype of register
|
304 |
+
if reg >= ARM64_REG_W0 and reg <= ARM64_REG_W30:
|
305 |
+
return 32, int
|
306 |
+
elif reg >= ARM64_REG_X0 and reg <= ARM64_REG_X30:
|
307 |
+
return 64, int
|
308 |
+
elif reg >= ARM64_REG_S0 and reg <= ARM64_REG_S31:
|
309 |
+
return 32, float
|
310 |
+
elif reg >= ARM64_REG_D0 and reg <= ARM64_REG_D31:
|
311 |
+
return 64, float
|
312 |
+
return 0, None
|
313 |
+
|
314 |
+
def check_mov_imm(self, insn):
|
315 |
+
if insn.id not in {ARM64_INS_ADRP, ARM64_INS_ADR, ARM64_INS_MOV, ARM64_INS_MOVK}:
|
316 |
+
return False
|
317 |
+
|
318 |
+
ops = insn.operands
|
319 |
+
if len(ops) != 2:
|
320 |
+
return False
|
321 |
+
if ops[0].type != ARM64_OP_REG or ops[1].type != ARM64_OP_IMM:
|
322 |
+
return False
|
323 |
+
|
324 |
+
imm = ops[1].value.imm
|
325 |
+
if ops[1].shift.type == 1: # LSL
|
326 |
+
imm <<= ops[1].shift.value
|
327 |
+
mask = 0xFFFF << ops[1].shift.value
|
328 |
+
|
329 |
+
if insn.id == ARM64_INS_ADRP:
|
330 |
+
# imm -= 0x400000 # Subtract global offset for some reason
|
331 |
+
# imm = ((insn.address + 4) & (~4095)) + imm
|
332 |
+
# Really confused about this, maybe I can use the imm directly
|
333 |
+
pass
|
334 |
+
elif insn.id == ARM64_INS_ADR:
|
335 |
+
imm -= 0x400000 # Subtract global offset for some reason
|
336 |
+
imm += insn.address + 4
|
337 |
+
elif insn.id == ARM64_INS_MOVK:
|
338 |
+
# load previous reg value
|
339 |
+
if ops[0].value.reg in self.reg_values:
|
340 |
+
curr = self.reg_values[ops[0].value.reg]
|
341 |
+
imm = (imm & mask) | (curr & (~mask))
|
342 |
+
|
343 |
+
return ops[0].value.reg, imm
|
344 |
+
|
345 |
+
def check_fmov(self, insn):
|
346 |
+
if insn.id != ARM64_INS_FMOV:
|
347 |
+
return False
|
348 |
+
ops = insn.operands
|
349 |
+
if len(ops) != 2: # or ops[1].type != ARM64_OP_FP:
|
350 |
+
return False
|
351 |
+
|
352 |
+
destsize, _ = self.reg_size_type(ops[0].value.reg)
|
353 |
+
destname = insn.reg_name(ops[0].value.reg)
|
354 |
+
if ops[1].type == ARM64_OP_FP:
|
355 |
+
fval = ops[1].value.fp
|
356 |
+
asm = f"{insn.mnemonic} {destname}, {fval}"
|
357 |
+
elif ops[1].type == ARM64_OP_REG:
|
358 |
+
reg = ops[1].value.reg
|
359 |
+
if reg not in self.reg_values:
|
360 |
+
return False
|
361 |
+
# TODO datatype
|
362 |
+
fhex = self.reg_values[reg]
|
363 |
+
if destsize == 64:
|
364 |
+
if fhex < 0:
|
365 |
+
fhex += 2**64
|
366 |
+
fval = int2fp64(fhex)
|
367 |
+
elif destsize == 32:
|
368 |
+
if fhex < 0:
|
369 |
+
fhex += 2**32
|
370 |
+
fval = int2fp32(fhex)
|
371 |
+
else:
|
372 |
+
return False
|
373 |
+
|
374 |
+
if abs(fval) < 1e-5 or abs(fval) > 1e5:
|
375 |
+
return False
|
376 |
+
asm = f"{insn.mnemonic} {insn.op_str}"
|
377 |
+
return asm, fval
|
378 |
+
|
379 |
+
def check_ldr(self, insn):
|
380 |
+
if insn.id != ARM64_INS_LDR:
|
381 |
+
return False
|
382 |
+
ops = insn.op_str[:-1].split(", ")
|
383 |
+
destsize, desttype = self.reg_size_type(insn.operands[0].value.reg)
|
384 |
+
if len(ops) < 2 or desttype != float:
|
385 |
+
return False
|
386 |
+
reg = ops[1]
|
387 |
+
if reg[0] != "[" or "sp" in reg:
|
388 |
+
return False
|
389 |
+
basereg = ARM64_REG_X0 + int(reg[2:]) # Shitty hack, may malfunction
|
390 |
+
if basereg not in self.reg_values:
|
391 |
+
return False
|
392 |
+
base = align4(self.reg_values[basereg])
|
393 |
+
if len(ops) == 3:
|
394 |
+
offset = ops[2][1:]
|
395 |
+
if offset.startswith("0x"):
|
396 |
+
offset = int(offset[2:], base=16)
|
397 |
+
else:
|
398 |
+
offset = int(offset)
|
399 |
+
else:
|
400 |
+
offset = 0
|
401 |
+
addr = base + offset
|
402 |
+
if destsize == 64:
|
403 |
+
fhex = self.loader.memory.load(addr, 8)
|
404 |
+
fval = struct.unpack("d", fhex)[0]
|
405 |
+
return fval, addr, 8
|
406 |
+
elif destsize == 32:
|
407 |
+
fhex = self.loader.memory.load(addr, 4)
|
408 |
+
fval = struct.unpack("f", fhex)[0]
|
409 |
+
return fval, addr, 4
|
410 |
+
else:
|
411 |
+
return False
|
412 |
+
|
413 |
+
|
414 |
+
def check_branch_symbol(self, insn):
|
415 |
+
if insn.id not in {ARM64_INS_BL, ARM64_INS_B}:
|
416 |
+
return False
|
417 |
+
ops = insn.operands
|
418 |
+
if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
|
419 |
+
return False
|
420 |
+
addr = ops[0].value.imm
|
421 |
+
if addr > self.funcrange[0] and addr < self.funcrange[1]:
|
422 |
+
# Self-branch
|
423 |
+
func = f"SELF+{hex(addr - self.funcrange[0])}"
|
424 |
+
else:
|
425 |
+
func = self.loader.find_plt_stub_name(addr)
|
426 |
+
if func is None:
|
427 |
+
# Some tail call optimized PLT stubs have extra instructions
|
428 |
+
# that are not identified by CLE, so check with offset of 4 also.
|
429 |
+
func = self.loader.find_plt_stub_name(addr + 4)
|
430 |
+
if func is None:
|
431 |
+
return False
|
432 |
+
asm = f"{insn.mnemonic} <{func}>"
|
433 |
+
return asm
|
434 |
+
|
435 |
+
def disassemble(self, funcname):
|
436 |
+
funcaddr, funcbytes = self.get_function_bytes(funcname)
|
437 |
+
disassm = []
|
438 |
+
|
439 |
+
for insn in self.md.disasm(funcbytes, funcaddr):
|
440 |
+
if insn.address in self.constaddrs:
|
441 |
+
# Skip if this is a constant value and not instruction
|
442 |
+
continue
|
443 |
+
|
444 |
+
cname = None
|
445 |
+
asm = None
|
446 |
+
# Maintain values of immediate moves
|
447 |
+
if movimm := self.check_mov_imm(insn):
|
448 |
+
reg, imm = movimm
|
449 |
+
self.reg_values[reg] = imm
|
450 |
+
else:
|
451 |
+
reads, writes = insn.regs_access()
|
452 |
+
for r in writes:
|
453 |
+
# Remove this reg if written to
|
454 |
+
if r in self.reg_values:
|
455 |
+
del self.reg_values[r]
|
456 |
+
|
457 |
+
if fmov := self.check_fmov(insn):
|
458 |
+
asm, fval = fmov
|
459 |
+
cname = self.add_constant(fval)
|
460 |
+
elif ldr := self.check_ldr(insn):
|
461 |
+
fval, faddr, fsize = ldr
|
462 |
+
cname = self.add_constant(fval, faddr, fsize)
|
463 |
+
elif branch := self.check_branch_symbol(insn):
|
464 |
+
asm = branch
|
465 |
+
|
466 |
+
if not asm:
|
467 |
+
asm = f"{insn.mnemonic} {insn.op_str}"
|
468 |
+
if cname:
|
469 |
+
asm += f", {cname}"
|
470 |
+
disassm.append(asm)
|
471 |
+
|
472 |
+
fulldiss = "; ".join(disassm)
|
473 |
+
return fulldiss
|
474 |
+
|
475 |
+
class DisassemblerX64(DisassemblerBase):
|
476 |
+
def __init__(self, binpath, expr_constants={}, match_constants=False):
|
477 |
+
super().__init__(expr_constants=expr_constants, match_constants=match_constants)
|
478 |
+
self.md = Cs(CS_ARCH_X86, CS_MODE_64)
|
479 |
+
self.md.detail = True
|
480 |
+
self.loader = cle.Loader(binpath)
|
481 |
+
|
482 |
+
def check_call_symbol(self, insn):
|
483 |
+
if insn.id != X86_INS_CALL:
|
484 |
+
return False
|
485 |
+
ops = insn.operands
|
486 |
+
# TODO check this ARM_OP
|
487 |
+
if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
|
488 |
+
return False
|
489 |
+
addr = ops[0].value.imm
|
490 |
+
func = self.loader.find_plt_stub_name(addr)
|
491 |
+
if func is None:
|
492 |
+
return False
|
493 |
+
asm = f"{insn.mnemonic} <{func}>"
|
494 |
+
return asm
|
495 |
+
|
496 |
+
def check_fload(self, insn):
|
497 |
+
# Cannot rely on ID because any instruction
|
498 |
+
# can access memory.
|
499 |
+
ops = insn.operands
|
500 |
+
memops = [op for op in ops
|
501 |
+
if (op.type == X86_OP_MEM and
|
502 |
+
op.value.mem.base == X86_REG_RIP)]
|
503 |
+
if len(memops) != 1:
|
504 |
+
return False
|
505 |
+
mem, size = memops[0].value.mem, memops[0].size
|
506 |
+
if size > 8:
|
507 |
+
return False
|
508 |
+
addr = insn.address + insn.size + mem.disp
|
509 |
+
fhex = self.loader.memory.load(addr, size)
|
510 |
+
fval = struct.unpack("f" if size == 4 else "d", fhex)[0]
|
511 |
+
return fval, addr, size
|
512 |
+
|
513 |
+
def disassemble(self, funcname):
|
514 |
+
funcaddr, funcbytes = self.get_function_bytes(funcname)
|
515 |
+
disassm = []
|
516 |
+
|
517 |
+
for insn in self.md.disasm(funcbytes, funcaddr):
|
518 |
+
asm = None
|
519 |
+
cname = None
|
520 |
+
if fload := self.check_fload(insn):
|
521 |
+
fval, faddr, fsize = fload
|
522 |
+
cname = self.add_constant(fval, faddr, fsize)
|
523 |
+
elif call := self.check_call_symbol(insn):
|
524 |
+
asm = call
|
525 |
+
|
526 |
+
if not asm:
|
527 |
+
asm = f"{insn.mnemonic} {insn.op_str}"
|
528 |
+
if cname:
|
529 |
+
asm += f", {cname}"
|
530 |
+
disassm.append(asm)
|
531 |
+
|
532 |
+
fulldiss = "; ".join(disassm)
|
533 |
+
return fulldiss
|
534 |
+
|
535 |
+
|
536 |
+
# Regular
|
537 |
+
if __name__ == "__main__":
|
538 |
+
import argparse
|
539 |
+
parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
|
540 |
+
parser.add_argument("--bin", required=True)
|
541 |
+
parser.add_argument("--func", required=True)
|
542 |
+
parser.add_argument("--arch", required=True)
|
543 |
+
args = parser.parse_args()
|
544 |
+
|
545 |
+
if args.arch == "arm32":
|
546 |
+
D = DisassemblerARM32(args.bin)
|
547 |
+
elif args.arch == "aarch64":
|
548 |
+
D = DisassemblerAArch64(args.bin)
|
549 |
+
elif args.arch == "x64":
|
550 |
+
D = DisassemblerX64(args.bin)
|
551 |
+
diss = D.disassemble(args.func)
|
552 |
+
print(diss)
|
553 |
+
print(D.constants)
|
remend/edit_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
removed = ["encoder.layers.0.in_proj_weight", "encoder.layers.0.in_proj_bias", "encoder.layers.0.out_proj_weight", "encoder.layers.0.out_proj_bias", "encoder.layers.0.fc1_weight", "encoder.layers.0.fc1_bias", "encoder.layers.0.fc2_weight", "encoder.layers.0.fc2_bias", "encoder.layers.1.in_proj_weight", "encoder.layers.1.in_proj_bias", "encoder.layers.1.out_proj_weight", "encoder.layers.1.out_proj_bias", "encoder.layers.1.fc1_weight", "encoder.layers.1.fc1_bias", "encoder.layers.1.fc2_weight", "encoder.layers.1.fc2_bias", "encoder.layers.2.in_proj_weight", "encoder.layers.2.in_proj_bias", "encoder.layers.2.out_proj_weight", "encoder.layers.2.out_proj_bias", "encoder.layers.2.fc1_weight", "encoder.layers.2.fc1_bias", "encoder.layers.2.fc2_weight", "encoder.layers.2.fc2_bias", "encoder.layers.3.in_proj_weight", "encoder.layers.3.in_proj_bias", "encoder.layers.3.out_proj_weight", "encoder.layers.3.out_proj_bias", "encoder.layers.3.fc1_weight", "encoder.layers.3.fc1_bias", "encoder.layers.3.fc2_weight", "encoder.layers.3.fc2_bias", "encoder.layers.4.in_proj_weight", "encoder.layers.4.in_proj_bias", "encoder.layers.4.out_proj_weight", "encoder.layers.4.out_proj_bias", "encoder.layers.4.fc1_weight", "encoder.layers.4.fc1_bias", "encoder.layers.4.fc2_weight", "encoder.layers.4.fc2_bias", "encoder.layers.5.in_proj_weight", "encoder.layers.5.in_proj_bias", "encoder.layers.5.out_proj_weight", "encoder.layers.5.out_proj_bias", "encoder.layers.5.fc1_weight", "encoder.layers.5.fc1_bias", "encoder.layers.5.fc2_weight", "encoder.layers.5.fc2_bias"]
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
import argparse
|
5 |
+
import torch
|
6 |
+
|
7 |
+
parser = argparse.ArgumentParser("Edit the checkpoint to remove extra dict weights")
|
8 |
+
parser.add_argument("-c", "--checkpoint", required=True, help="Input checkpoint")
|
9 |
+
parser.add_argument("-e", "--edited", required=True, help="Edited checkpoint")
|
10 |
+
args = parser.parse_args()
|
11 |
+
|
12 |
+
sd = torch.load(args.checkpoint, weights_only=False)
|
13 |
+
for k in removed:
|
14 |
+
if k in sd['model']:
|
15 |
+
del sd['model'][k]
|
16 |
+
torch.save(sd, args.edited)
|
remend/eval_generated.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sympy as sp
|
2 |
+
import numpy as np
|
3 |
+
import warnings
|
4 |
+
from sympy.abc import x
|
5 |
+
import sys
|
6 |
+
import json
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
from .parser import parse_prefix_to_sympy, isint
|
10 |
+
|
11 |
+
# Ignore sympy lambda warnings.
|
12 |
+
warnings.simplefilter("ignore")
|
13 |
+
|
14 |
+
def percent(a, n):
|
15 |
+
return f"{a/n*100:0.1f}%"
|
16 |
+
|
17 |
+
def do_eval_match(orig_expr, gen_expr):
|
18 |
+
try:
|
19 |
+
origl = sp.lambdify(x, orig_expr)
|
20 |
+
genl = sp.lambdify(x, gen_expr)
|
21 |
+
count = 0
|
22 |
+
|
23 |
+
for v in np.arange(0.2, 1, 0.01):
|
24 |
+
o = origl(v)
|
25 |
+
g = genl(v)
|
26 |
+
if o == float('nan') or o == float('inf'):
|
27 |
+
continue
|
28 |
+
if g == float('nan') or g == float('inf'):
|
29 |
+
continue
|
30 |
+
# if type(o) != np.float64 or type(g) != np.float64:
|
31 |
+
# print(orig_expr, o, gen_expr, g)
|
32 |
+
# return False
|
33 |
+
if abs((o-g)/o) > 1e-5:
|
34 |
+
return False
|
35 |
+
count += 1
|
36 |
+
except:
|
37 |
+
return False
|
38 |
+
return count >= 5
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
import argparse
|
42 |
+
parser = argparse.ArgumentParser("Check generated expressions")
|
43 |
+
parser.add_argument("-g", required=True, help="Generated expressions file")
|
44 |
+
parser.add_argument("-c", required=True, help="Constants file")
|
45 |
+
parser.add_argument("-e", required=True, help="Equations file")
|
46 |
+
parser.add_argument("-r", required=True, help="Results file")
|
47 |
+
args = parser.parse_args()
|
48 |
+
|
49 |
+
gens = []
|
50 |
+
with open(args.g, 'r') as genf, open(args.c) as constf, open(args.e) as eqnf:
|
51 |
+
for line in tqdm(genf, desc="Reading file"):
|
52 |
+
comps = line.strip().split("\t")
|
53 |
+
if line[0] == 'H':
|
54 |
+
num = int(comps[0][2:])
|
55 |
+
tokens = comps[2].split(" ")
|
56 |
+
eqn = next(eqnf)
|
57 |
+
const = next(constf)
|
58 |
+
const = json.loads(const.strip())
|
59 |
+
gens.append((num, tokens, eqn.strip(), const))
|
60 |
+
|
61 |
+
parsed = []
|
62 |
+
matched = []
|
63 |
+
results = []
|
64 |
+
|
65 |
+
for n, toks, eqn, const in tqdm(gens, desc="Evaluating expressions"):
|
66 |
+
res = {"id": n, "parsed": False, "matched": False, "orig": "", "gen": ""}
|
67 |
+
if "<<unk>>" in toks:
|
68 |
+
# Not parsed
|
69 |
+
results.append(res)
|
70 |
+
continue
|
71 |
+
try:
|
72 |
+
gen_expr = parse_prefix_to_sympy(toks)
|
73 |
+
except Exception as e:
|
74 |
+
# Not parsed
|
75 |
+
results.append(res)
|
76 |
+
continue
|
77 |
+
|
78 |
+
res["parsed"] = True
|
79 |
+
parsed.append(n)
|
80 |
+
|
81 |
+
gen_expr = gen_expr.subs([(sp.Symbol("k"+c), const[c]) for c in const])
|
82 |
+
orig_expr = sp.parse_expr(eqn, local_dict={"x0":x})
|
83 |
+
res["orig"] = str(orig_expr)
|
84 |
+
res["gen"] = str(gen_expr)
|
85 |
+
|
86 |
+
if not do_eval_match(orig_expr, gen_expr):
|
87 |
+
results.append(res)
|
88 |
+
continue
|
89 |
+
res["matched"] = True
|
90 |
+
matched.append(n)
|
91 |
+
results.append(res)
|
92 |
+
|
93 |
+
with open(args.r, "w") as resf:
|
94 |
+
for res in results:
|
95 |
+
resf.write("{id} {parsed} {matched} \"{orig}\" \"{gen}\"\n".format(**res))
|
96 |
+
resf.write("\n")
|
97 |
+
N = len(gens)
|
98 |
+
print("Total", N, file=resf)
|
99 |
+
print("Parsed", len(parsed), percent(len(parsed), N), file=resf)
|
100 |
+
print("Matched", len(matched), percent(len(matched), N), file=resf)
|
remend/experiment.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sympy as sp
|
2 |
+
import random
|
3 |
+
|
4 |
+
from .parser import parse_prefix_to_sympy
|
5 |
+
|
6 |
+
isconst = lambda e: not any(c.is_symbol for c in e.atoms())
|
7 |
+
def constfold(expr):
|
8 |
+
q = [expr]
|
9 |
+
cidx = 0
|
10 |
+
subsmap = {}
|
11 |
+
constmap = {}
|
12 |
+
|
13 |
+
while len(q) > 0:
|
14 |
+
curr_expr = q.pop(0)
|
15 |
+
if isinstance(curr_expr, sp.Number) or isconst(e):
|
16 |
+
const_expr = curr_expr.evalf()
|
17 |
+
rep = sp.nsimplify(const_expr, [sp.E, sp.pi])
|
18 |
+
if isinstance(rep, sp.Integer) or \
|
19 |
+
(isinstance(rep, sp.Rational) and rep.q <= 16):
|
20 |
+
subsmap[curr_expr] = rep
|
21 |
+
else:
|
22 |
+
subsmap[curr_expr] = sp.Symbol(f"k{cidx}")
|
23 |
+
constmap[f"k{cidx}"] = float(const_expr)
|
24 |
+
cidx += 1
|
25 |
+
else:
|
26 |
+
for child in curr_expr.args:
|
27 |
+
q.append(child)
|
28 |
+
|
29 |
+
return expr.subs(subsmap), constmap
|
30 |
+
|
31 |
+
def replace_const(expr):
|
32 |
+
cidx = 0
|
33 |
+
subsmap = {}
|
34 |
+
constmap = {}
|
35 |
+
|
36 |
+
for c in sp.preorder_traversal(expr):
|
37 |
+
if isinstance(c, sp.Float):
|
38 |
+
rep = sp.nsimplify(c)
|
39 |
+
if isinstance(rep, sp.Integer) or \
|
40 |
+
(isinstance(rep, sp.Rational) and rep.q <= 16):
|
41 |
+
subsmap[c] = rep
|
42 |
+
else:
|
43 |
+
subsmap[c] = sp.Symbol(f"c{cidx}")
|
44 |
+
constmap[f"c{cidx}"] = float(c)
|
45 |
+
cidx += 1
|
46 |
+
return expr.subs(subsmap), constmap
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == "__main__":
|
50 |
+
import argparse
|
51 |
+
parser = argparse.ArgumentParser("Random experiments")
|
52 |
+
parser.add_argument("-f", required=True)
|
53 |
+
parser.add_argument("-p", type=float, default=0.1)
|
54 |
+
parser.add_argument("-n", type=int, default=20)
|
55 |
+
args = parser.parse_args()
|
56 |
+
|
57 |
+
random.seed(1225)
|
58 |
+
|
59 |
+
count = 0
|
60 |
+
with open(args.f, "r") as f:
|
61 |
+
for line in f:
|
62 |
+
if random.random() > args.p:
|
63 |
+
continue
|
64 |
+
|
65 |
+
prefl = line.strip().split(" ")
|
66 |
+
|
67 |
+
orig = parse_prefix_to_sympy(prefl)
|
68 |
+
# simp = sp.simplify(expr)
|
69 |
+
expr = constfold(orig)
|
70 |
+
expr, consts = replace_const(expr)
|
71 |
+
print(orig, expr, consts)
|
72 |
+
|
73 |
+
count += 1
|
74 |
+
if count == args.n:
|
75 |
+
break
|
remend/find_duplicates.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from tqdm import tqdm
|
3 |
+
from Levenshtein import distance
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
import argparse
|
7 |
+
parser = argparse.ArgumentParser("Find duplicates in the dataset ASM")
|
8 |
+
parser.add_argument("--train", required=True)
|
9 |
+
# parser.add_argument("--valid", required=True)
|
10 |
+
parser.add_argument("--test", required=True)
|
11 |
+
parser.add_argument("--result", required=False)
|
12 |
+
parser.add_argument("--distance", action="store_true", default=False)
|
13 |
+
args = parser.parse_args()
|
14 |
+
|
15 |
+
train = []
|
16 |
+
train_hash = {}
|
17 |
+
# valid = []
|
18 |
+
test = []
|
19 |
+
with open(args.train, "r") as tf:
|
20 |
+
for idx, line in tqdm(enumerate(tf), desc="Read train", leave=False):
|
21 |
+
train_hash[hash(line)] = idx
|
22 |
+
comps = line.strip().split(" ")
|
23 |
+
train.append(comps)
|
24 |
+
# with open(args.valid, "r") as tf:
|
25 |
+
# for line in tqdm(tf, desc="Read valid", leave=False):
|
26 |
+
# valid.append(line.strip().split(" "))
|
27 |
+
with open(args.test, "r") as tf:
|
28 |
+
for line in tqdm(tf, desc="Read test", leave=False):
|
29 |
+
test.append(line)
|
30 |
+
|
31 |
+
selfcheck = args.train == args.test
|
32 |
+
if args.result:
|
33 |
+
rf = open(args.result, "w")
|
34 |
+
searchdist = args.distance
|
35 |
+
else:
|
36 |
+
searchdist = False # Dont compute if no result file
|
37 |
+
rf = None
|
38 |
+
|
39 |
+
def reswrite(s):
|
40 |
+
if rf:
|
41 |
+
rf.write(s)
|
42 |
+
|
43 |
+
exact = 0
|
44 |
+
for i, testline in tqdm(enumerate(test), desc="Test", total=len(test)):
|
45 |
+
testl = testline.strip().split(" ")
|
46 |
+
htest = hash(testline)
|
47 |
+
if htest in train_hash:
|
48 |
+
# Found exact match
|
49 |
+
j = train_hash[htest]
|
50 |
+
if not selfcheck or j != i:
|
51 |
+
exact += 1
|
52 |
+
reswrite(f"{i} {j} 0 0.0\n")
|
53 |
+
continue
|
54 |
+
|
55 |
+
# If not, then search
|
56 |
+
if searchdist:
|
57 |
+
minavgdist, mindist, minj = 100, 100, -1
|
58 |
+
for j, trainl in enumerate(train):
|
59 |
+
if abs(len(trainl) - len(testl)) > 10:
|
60 |
+
dist = abs(len(trainl) - len(testl)) * 2 # HACK to speed it up
|
61 |
+
else:
|
62 |
+
dist = distance(trainl, testl)
|
63 |
+
avgdist = dist / (len(trainl) + len(testl))
|
64 |
+
if mindist > dist:
|
65 |
+
minavgdist, mindist, minj = avgdist, dist, j
|
66 |
+
|
67 |
+
reswrite(f"{i} {minj} {mindist} {minavgdist}\n")
|
68 |
+
|
69 |
+
print("Exact duplicates:", exact)
|
remend/implementation.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sympy as sp
|
2 |
+
from sympy.codegen import ast
|
3 |
+
import itertools as it
|
4 |
+
import networkx as nx
|
5 |
+
|
6 |
+
from .parser import OPERATORS, sympy_to_dag
|
7 |
+
from .util import DecodeError
|
8 |
+
|
9 |
+
def isnum(s):
|
10 |
+
try:
|
11 |
+
float(s)
|
12 |
+
return True
|
13 |
+
except ValueError:
|
14 |
+
return False
|
15 |
+
|
16 |
+
class Implementor:
|
17 |
+
def __init__(self, expr, constants={}, dtype="double"):
|
18 |
+
self.expr = expr
|
19 |
+
self.constants = constants
|
20 |
+
self.cdtype = dtype
|
21 |
+
self.cpf = "lf" if dtype == "double" else "f"
|
22 |
+
self.fdtype = "double precision" if dtype == "double" else "real"
|
23 |
+
|
24 |
+
def implement(self, impl):
|
25 |
+
if impl == "dag_c":
|
26 |
+
return self.dag_to_c_impl()
|
27 |
+
elif impl == "cse_c":
|
28 |
+
return self.sympy_cse_c_impl()
|
29 |
+
elif impl == "dag_fortran":
|
30 |
+
return self.dag_to_fortran_impl()
|
31 |
+
elif impl == "cse_fortran":
|
32 |
+
return self.sympy_cse_fortran_impl()
|
33 |
+
|
34 |
+
|
35 |
+
def op_c_impl(self, f, children):
|
36 |
+
if f == "add":
|
37 |
+
return " + ".join(children);
|
38 |
+
elif f == "mul":
|
39 |
+
return " * ".join(children);
|
40 |
+
elif f == "pow":
|
41 |
+
assert len(children) == 2
|
42 |
+
if self.cdtype == "double":
|
43 |
+
return f"pow({children[0]}, {children[1]})"
|
44 |
+
else:
|
45 |
+
return f"powf({children[0]}, {children[1]})"
|
46 |
+
elif f == "ln":
|
47 |
+
assert len(children) == 1
|
48 |
+
if self.cdtype == "double":
|
49 |
+
return f"log({children[0]})"
|
50 |
+
else:
|
51 |
+
return f"logf({children[0]})"
|
52 |
+
else:
|
53 |
+
if f in OPERATORS and OPERATORS[f][1] == 1:
|
54 |
+
assert len(children) == 1
|
55 |
+
if self.cdtype == "double":
|
56 |
+
return f"{f}({children[0]})"
|
57 |
+
else:
|
58 |
+
return f"{f}f({children[0]})"
|
59 |
+
else:
|
60 |
+
raise DecodeError(f"C impl: operation {f} not handled")
|
61 |
+
|
62 |
+
def op_f_impl(self, f, children):
|
63 |
+
if f == "add":
|
64 |
+
j = ")+(".join(children)
|
65 |
+
return "(" + j + ")"
|
66 |
+
elif f == "mul":
|
67 |
+
j = ")*(".join(children)
|
68 |
+
return "(" + j + ")"
|
69 |
+
elif f == "pow":
|
70 |
+
assert len(children) == 2
|
71 |
+
return f"({children[0]})**({children[1]})"
|
72 |
+
elif f == "ln":
|
73 |
+
assert len(children) == 1
|
74 |
+
return f"log({children[0]})"
|
75 |
+
else:
|
76 |
+
if f in OPERATORS and OPERATORS[f][1] == 1:
|
77 |
+
assert len(children) == 1
|
78 |
+
return f"{f}({children[0]})"
|
79 |
+
else:
|
80 |
+
raise DecodeError(f"F impl: operation {f} not handled")
|
81 |
+
|
82 |
+
def full_c_code(self, body):
|
83 |
+
pre = f"#include <stdio.h>\n#include <math.h>\n{self.cdtype} myfunc({self.cdtype} x) {{"
|
84 |
+
post = f"}}\nint main() {{ {self.cdtype} x; scanf(\"%{self.cpf}\", &x); printf(\"%{self.cpf}\", myfunc(x)); }}"
|
85 |
+
return f"{pre}\n{body}\n{post}"
|
86 |
+
|
87 |
+
def full_f_code(self, body):
|
88 |
+
pre = "function myfunc(x) result(y)\nimplicit none\n" + \
|
89 |
+
f"{self.fdtype}, intent(in) :: x\n{self.fdtype} :: y, E, pi\n"
|
90 |
+
post = "end function myfunc\nprogram main\nimplicit none\n" + \
|
91 |
+
f"{self.fdtype} :: x\n{self.fdtype} :: myfunc\n" + \
|
92 |
+
"read(*, *) x\nprint *, \"y is:\", myfunc(x)\nend program main"
|
93 |
+
return f"{pre}\n{body}\n{post}"
|
94 |
+
|
95 |
+
def dag_to_c_impl(self):
|
96 |
+
dag = sympy_to_dag(self.expr, csuf="F" if self.cdtype == "float" else "")
|
97 |
+
cstr = ""
|
98 |
+
added_pi, added_E = False, False
|
99 |
+
for c in self.constants:
|
100 |
+
cstr += f"{self.cdtype} {c} = {self.constants[c]};\n"
|
101 |
+
varidx = it.count()
|
102 |
+
for node in reversed(list(nx.topological_sort(dag))):
|
103 |
+
label = dag.nodes[node]["label"]
|
104 |
+
children = [dag.nodes[n]["var"] for n in dag.adj[node]]
|
105 |
+
if len(children) == 0:
|
106 |
+
if label == "pi":
|
107 |
+
if self.cdtype == "float" and not added_pi:
|
108 |
+
cstr += "const float pi = 3.14159265F;\n"
|
109 |
+
added_pi = True
|
110 |
+
else:
|
111 |
+
label = "M_PI"
|
112 |
+
elif label == "E":
|
113 |
+
if self.cdtype == "float" and not added_E:
|
114 |
+
cstr += "const float E = 2.71828183F;\n"
|
115 |
+
added_E = True
|
116 |
+
else:
|
117 |
+
label = "M_E"
|
118 |
+
dag.nodes[node]["var"] = label
|
119 |
+
continue
|
120 |
+
varname = f"t{next(varidx)}"
|
121 |
+
cexpr = self.op_c_impl(label, children)
|
122 |
+
dag.nodes[node]["var"] = varname
|
123 |
+
cstr += f"{self.cdtype} {varname} = {cexpr};\n"
|
124 |
+
retname = varname
|
125 |
+
cstr += f"return {retname};\n"
|
126 |
+
return self.full_c_code(cstr)
|
127 |
+
|
128 |
+
def dag_to_fortran_impl(self):
|
129 |
+
csuf = "" if self.fdtype == "real" else "d0"
|
130 |
+
dag = sympy_to_dag(self.expr, csuf=csuf)
|
131 |
+
varstr = ""
|
132 |
+
fstr = "parameter E = 2.71828183\nparameter pi = 3.14159265\n"
|
133 |
+
for c in self.constants:
|
134 |
+
varstr += f"{self.fdtype} :: {c}\n"
|
135 |
+
fstr += f"parameter {c} = {self.constants[c]}{csuf}\n"
|
136 |
+
varidx = it.count()
|
137 |
+
allvars = []
|
138 |
+
for node in reversed(list(nx.topological_sort(dag))):
|
139 |
+
label = dag.nodes[node]["label"]
|
140 |
+
children = [dag.nodes[n]["var"] for n in dag.adj[node]]
|
141 |
+
if len(children) == 0:
|
142 |
+
dag.nodes[node]["var"] = label
|
143 |
+
continue
|
144 |
+
varname = f"t{next(varidx)}"
|
145 |
+
fexpr = self.op_f_impl(label, children)
|
146 |
+
dag.nodes[node]["var"] = varname
|
147 |
+
fstr += f"{varname} = {fexpr}\n"
|
148 |
+
retname = varname
|
149 |
+
varstr += f"{self.fdtype} :: {varname}\n"
|
150 |
+
fstr += f"y = {retname};\n"
|
151 |
+
fstr = varstr + "\n" + fstr
|
152 |
+
return self.full_f_code(fstr)
|
153 |
+
|
154 |
+
def sympy_cse_c_impl(self):
|
155 |
+
if self.cdtype == "float":
|
156 |
+
extraargs = {
|
157 |
+
"type_aliases": {ast.real: ast.float32},
|
158 |
+
"math_macros": {},
|
159 |
+
}
|
160 |
+
else:
|
161 |
+
extraargs = {}
|
162 |
+
cstr = ""
|
163 |
+
for c in self.constants:
|
164 |
+
cstr += f"{self.cdtype} {c} = {self.constants[c]};\n"
|
165 |
+
xvars, xpr = sp.cse(self.expr)
|
166 |
+
for vname, vxpr in xvars:
|
167 |
+
code = sp.ccode(vxpr, assign_to=vname.name, **extraargs)
|
168 |
+
cstr += f"{self.cdtype} {vname.name}; {code};\n"
|
169 |
+
assert len(xpr) == 1
|
170 |
+
code = sp.ccode(xpr[0], assign_to="y", **extraargs)
|
171 |
+
cstr += f"{self.cdtype} y; {code}; return y;\n"
|
172 |
+
return self.full_c_code(cstr)
|
173 |
+
|
174 |
+
def sympy_cse_fortran_impl(self):
|
175 |
+
csuf = "" if self.fdtype == "real" else "d0"
|
176 |
+
varstr = ""
|
177 |
+
fstr = ""
|
178 |
+
for c in self.constants:
|
179 |
+
varstr += f"{self.fdtype} :: {c}\n"
|
180 |
+
fstr += f"parameter {c} = {self.constants[c]}{csuf}\n"
|
181 |
+
xvars, xpr = sp.cse(self.expr)
|
182 |
+
for vname, vxpr in xvars:
|
183 |
+
varstr += f"{self.fdtype} :: {vname.name}\n"
|
184 |
+
fstr += sp.fcode(vxpr, assign_to=vname.name, standard=95, source_format="free") + "\n"
|
185 |
+
assert len(xpr) == 1
|
186 |
+
fstr += sp.fcode(xpr[0], assign_to="y", standard=95, source_format="free") + "\n"
|
187 |
+
fstr = varstr + "\n" + fstr
|
188 |
+
if self.fdtype == "real":
|
189 |
+
# Hack to fix sympy generation
|
190 |
+
fstr = fstr.replace("d0", "")
|
191 |
+
return self.full_f_code(fstr)
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
# For testing only
|
196 |
+
if __name__ == "__main__":
|
197 |
+
from .parser import parse_prefix_to_sympy, sympy_to_dag
|
198 |
+
|
199 |
+
prefs = "add mul div INT+ 1 INT+ 5 x mul div INT+ 1 INT+ 5 mul x tan pow x INT+ 2".split(" ")
|
200 |
+
exp = parse_prefix_to_sympy(prefs)
|
201 |
+
impl = Implementor(exp, dtype="float")
|
202 |
+
|
203 |
+
print("DAG C:")
|
204 |
+
print(impl.dag_to_c_impl())
|
205 |
+
print("DAG Fortran:")
|
206 |
+
print(impl.dag_to_fortran_impl())
|
207 |
+
print("CSE C:")
|
208 |
+
print(impl.sympy_cse_c_impl())
|
209 |
+
print("CSE Fortran:")
|
210 |
+
print(impl.sympy_cse_fortran_impl())
|
remend/parser.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sympy as sp
|
2 |
+
import networkx as nx
|
3 |
+
import itertools as it
|
4 |
+
import sys
|
5 |
+
|
6 |
+
from .util import DecodeError, sympy_expr_ok
|
7 |
+
|
8 |
+
OPERATORS = {
|
9 |
+
# Elementary functions
|
10 |
+
'add': (lambda a,b: a+b, 2),
|
11 |
+
'sub': (lambda a,b: a-b, 2),
|
12 |
+
'mul': (lambda a,b: a*b, 2),
|
13 |
+
'div': (lambda a,b: a/b, 2),
|
14 |
+
'pow': (lambda a,b: a**b, 2),
|
15 |
+
# 'inv': (lambda a: 1/a, 1),
|
16 |
+
# 'pow2': (lambda a: a**2, 1),
|
17 |
+
# 'pow3': (lambda a: a**3, 1),
|
18 |
+
# 'pow4': (lambda a: a**4, 1),
|
19 |
+
# 'pow5': (lambda a: a**5, 1),
|
20 |
+
'sqrt': (lambda a: sp.sqrt(a), 1),
|
21 |
+
'exp': (lambda a: sp.exp(a), 1),
|
22 |
+
'ln': (lambda a: sp.ln(a), 1),
|
23 |
+
# 'abs': (lambda a: sp.abs(a), 1),
|
24 |
+
# 'sign': (lambda a: sp.sign(a), 1),
|
25 |
+
# Trigonometric Functions
|
26 |
+
'sin': (lambda a: sp.sin(a), 1),
|
27 |
+
'cos': (lambda a: sp.cos(a), 1),
|
28 |
+
'tan': (lambda a: sp.tan(a), 1),
|
29 |
+
'cot': (lambda a: sp.cot(a), 1),
|
30 |
+
'sec': (lambda a: sp.sec(a), 1),
|
31 |
+
'csc': (lambda a: sp.csc(a), 1),
|
32 |
+
# Trigonometric Inverses
|
33 |
+
'asin': (lambda a: sp.asin(a), 1),
|
34 |
+
'acos': (lambda a: sp.acos(a), 1),
|
35 |
+
'atan': (lambda a: sp.atan(a), 1),
|
36 |
+
'acot': (lambda a: sp.acot(a), 1),
|
37 |
+
'asec': (lambda a: sp.asec(a), 1),
|
38 |
+
'acsc': (lambda a: sp.acsc(a), 1),
|
39 |
+
# Hyperbolic
|
40 |
+
# 'sinh': (lambda a: sp.sinh(a), 1),
|
41 |
+
# 'cosh': (lambda a: sp.cosh(a), 1),
|
42 |
+
# 'tanh': (lambda a: sp.tanh(a), 1),
|
43 |
+
}
|
44 |
+
|
45 |
+
CONSTANTS = {
|
46 |
+
'E': sp.E,
|
47 |
+
'pi': sp.pi,
|
48 |
+
'0': 0,
|
49 |
+
'1': 1,
|
50 |
+
'2': 2,
|
51 |
+
'3': 3,
|
52 |
+
'4': 4,
|
53 |
+
'5': 5,
|
54 |
+
'6': 6,
|
55 |
+
'7': 7,
|
56 |
+
'8': 8,
|
57 |
+
'9': 9,
|
58 |
+
}
|
59 |
+
|
60 |
+
VARIABLES = {
|
61 |
+
'x': sp.Symbol('x'),
|
62 |
+
'x0': sp.Symbol('x0'),
|
63 |
+
'x1': sp.Symbol('x1'),
|
64 |
+
|
65 |
+
'c0': sp.Symbol('c0'),
|
66 |
+
'c1': sp.Symbol('c1'),
|
67 |
+
'c2': sp.Symbol('c2'),
|
68 |
+
'c3': sp.Symbol('c3'),
|
69 |
+
'c4': sp.Symbol('c4'),
|
70 |
+
'c5': sp.Symbol('c5'),
|
71 |
+
'c6': sp.Symbol('c6'),
|
72 |
+
'c7': sp.Symbol('c7'),
|
73 |
+
'c8': sp.Symbol('c8'),
|
74 |
+
'c9': sp.Symbol('c9'),
|
75 |
+
'c10': sp.Symbol('c10'),
|
76 |
+
|
77 |
+
'k0': sp.Symbol('k0'),
|
78 |
+
'k1': sp.Symbol('k1'),
|
79 |
+
'k2': sp.Symbol('k2'),
|
80 |
+
'k3': sp.Symbol('k3'),
|
81 |
+
# 'y': sp.Symbol('y'),
|
82 |
+
# 'z': sp.Symbol('z')
|
83 |
+
}
|
84 |
+
|
85 |
+
FUNC_TO_OP = {
|
86 |
+
sp.Add: 'add',
|
87 |
+
sp.Mul: 'mul',
|
88 |
+
sp.Pow: 'pow',
|
89 |
+
|
90 |
+
sp.log: 'ln',
|
91 |
+
sp.sqrt: 'sqrt',
|
92 |
+
sp.exp: 'exp',
|
93 |
+
sp.Abs: 'abs',
|
94 |
+
# 'abs': (lambda a: sp.abs(a), 1),
|
95 |
+
# 'sign': (lambda a: sp.sign(a), 1),
|
96 |
+
# Trigonometric Functions
|
97 |
+
sp.sin: 'sin',
|
98 |
+
sp.cos: 'cos',
|
99 |
+
sp.tan: 'tan',
|
100 |
+
sp.cot: 'cot',
|
101 |
+
sp.sec: 'sec',
|
102 |
+
sp.csc: 'csc',
|
103 |
+
# Trigonometric Inverses
|
104 |
+
sp.asin: 'asin',
|
105 |
+
sp.acos: 'acos',
|
106 |
+
sp.atan: 'atan',
|
107 |
+
sp.acot: 'acot',
|
108 |
+
sp.asec: 'asec',
|
109 |
+
sp.acsc: 'acsc',
|
110 |
+
# Hyperbolic
|
111 |
+
# sp.cosh: 'cosh',
|
112 |
+
# sp.sinh: 'sinh',
|
113 |
+
# sp.tanh: 'tanh'
|
114 |
+
}
|
115 |
+
|
116 |
+
def sympy_func_to_op(f):
|
117 |
+
if f in FUNC_TO_OP:
|
118 |
+
return FUNC_TO_OP[f]
|
119 |
+
else:
|
120 |
+
raise DecodeError(f"Op not found {f}")
|
121 |
+
return str(f)
|
122 |
+
|
123 |
+
def isint(s):
|
124 |
+
try:
|
125 |
+
int(s)
|
126 |
+
return True
|
127 |
+
except ValueError:
|
128 |
+
return False
|
129 |
+
|
130 |
+
def reverse_iter_prefix(prefs):
|
131 |
+
n = len(prefs) - 1
|
132 |
+
# currnum = 0
|
133 |
+
# currpow = 1
|
134 |
+
currnum = []
|
135 |
+
while n >= 0:
|
136 |
+
if isint(prefs[n]) or prefs[n] in ["e", "+", "-", "."]:
|
137 |
+
currnum += prefs[n]
|
138 |
+
# currnum += currpow * int(prefs[n])
|
139 |
+
# currpow *= 10
|
140 |
+
elif prefs[n][:3] == "INT":
|
141 |
+
parsedint = int("".join(reversed(currnum)))
|
142 |
+
if prefs[n][3] == "+":
|
143 |
+
yield parsedint
|
144 |
+
else:
|
145 |
+
yield -parsedint
|
146 |
+
currnum = []
|
147 |
+
# currpow = 1
|
148 |
+
elif prefs[n][:5] == "FLOAT":
|
149 |
+
parsedfloat = float("".join(reversed(currnum)))
|
150 |
+
if prefs[n][5] == "+":
|
151 |
+
yield parsedfloat
|
152 |
+
else:
|
153 |
+
yield -parsedfloat
|
154 |
+
currnum = []
|
155 |
+
else:
|
156 |
+
yield prefs[n]
|
157 |
+
n -= 1
|
158 |
+
|
159 |
+
def parse_prefix_to_sympy(prefs):
|
160 |
+
stack = []
|
161 |
+
for val in reverse_iter_prefix(prefs):
|
162 |
+
# print(stack, val)
|
163 |
+
if val in OPERATORS:
|
164 |
+
spop, numops = OPERATORS[val]
|
165 |
+
operands = [stack.pop() for i in range(numops)]
|
166 |
+
expr = spop(*operands)
|
167 |
+
stack.append(expr)
|
168 |
+
elif val in CONSTANTS:
|
169 |
+
stack.append(CONSTANTS[val])
|
170 |
+
elif val in VARIABLES:
|
171 |
+
stack.append(VARIABLES[val])
|
172 |
+
elif type(val) == int or type(val) == float:
|
173 |
+
stack.append(val)
|
174 |
+
elif val == "(" or val == ")":
|
175 |
+
# Simply ignore brackets
|
176 |
+
continue
|
177 |
+
else:
|
178 |
+
raise DecodeError(f"{val} invalid")
|
179 |
+
|
180 |
+
if len(stack) != 1:
|
181 |
+
raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
|
182 |
+
expr = stack.pop()
|
183 |
+
if not sympy_expr_ok(expr):
|
184 |
+
raise DecodeError("Complex or infinite expression")
|
185 |
+
return expr
|
186 |
+
|
187 |
+
def parse_postfix_to_sympy(prefs):
|
188 |
+
stack = []
|
189 |
+
postfix = reversed(list(reverse_iter_prefix(prefs)))
|
190 |
+
for val in postfix:
|
191 |
+
if val in OPERATORS:
|
192 |
+
spop, numops = OPERATORS[val]
|
193 |
+
operands = [stack.pop() for i in range(numops)]
|
194 |
+
expr = spop(*operands)
|
195 |
+
stack.append(expr)
|
196 |
+
elif val in CONSTANTS:
|
197 |
+
stack.append(CONSTANTS[val])
|
198 |
+
elif val in VARIABLES:
|
199 |
+
stack.append(VARIABLES[val])
|
200 |
+
elif type(val) == int or type(val) == float:
|
201 |
+
stack.append(val)
|
202 |
+
elif val == "(" or val == ")":
|
203 |
+
# Simply ignore brackets
|
204 |
+
continue
|
205 |
+
else:
|
206 |
+
raise DecodeError(f"{val} invalid")
|
207 |
+
|
208 |
+
if len(stack) != 1:
|
209 |
+
raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
|
210 |
+
expr = stack.pop()
|
211 |
+
if not sympy_expr_ok(expr):
|
212 |
+
raise DecodeError("Complex or infinite expression")
|
213 |
+
return expr
|
214 |
+
|
215 |
+
|
216 |
+
def parse_prefix_to_tree(prefs):
|
217 |
+
tree = nx.DiGraph()
|
218 |
+
stack = []
|
219 |
+
newidx = len(prefs)
|
220 |
+
for nidx, val in enumerate(reverse_iter_prefix(prefs)):
|
221 |
+
tree.add_node(nidx, label=val)
|
222 |
+
if val in OPERATORS:
|
223 |
+
_, numops = OPERATORS[val]
|
224 |
+
childs = [stack.pop() for i in range(numops)]
|
225 |
+
if val in {"pow", "sub", "div"}:
|
226 |
+
# Ordered children
|
227 |
+
tree.add_node(newidx, label="lhs")
|
228 |
+
tree.add_node(newidx+1, label="rhs")
|
229 |
+
tree.add_edge(nidx, newidx)
|
230 |
+
tree.add_edge(nidx, newidx+1)
|
231 |
+
tree.add_edge(newidx, childs[0])
|
232 |
+
tree.add_edge(newidx+1, childs[1])
|
233 |
+
newidx += 2
|
234 |
+
else:
|
235 |
+
for c in childs:
|
236 |
+
tree.add_edge(nidx, c)
|
237 |
+
elif val in CONSTANTS or val in VARIABLES or type(val) == int:
|
238 |
+
pass
|
239 |
+
else:
|
240 |
+
raise DecodeError(f"Val {val} invalid")
|
241 |
+
stack.append(nidx)
|
242 |
+
|
243 |
+
if len(stack) != 1:
|
244 |
+
raise DecodeError(f"Stack not empty, invalid expression: {prefs} || {stack}")
|
245 |
+
|
246 |
+
return tree, stack.pop() # Root node
|
247 |
+
|
248 |
+
def sympy_to_dag(expression, csuf=""):
|
249 |
+
dag = nx.DiGraph()
|
250 |
+
seen = {}
|
251 |
+
nitr = it.count()
|
252 |
+
|
253 |
+
def _dfs(node):
|
254 |
+
children = []
|
255 |
+
for child in node.args:
|
256 |
+
if child in seen:
|
257 |
+
cid = seen[child]
|
258 |
+
else:
|
259 |
+
cid = _dfs(child)
|
260 |
+
children.append(cid)
|
261 |
+
|
262 |
+
nid = next(nitr)
|
263 |
+
dag.add_node(nid, expr=node)
|
264 |
+
seen[node] = nid
|
265 |
+
for cid in children:
|
266 |
+
dag.add_edge(nid, cid)
|
267 |
+
return nid
|
268 |
+
|
269 |
+
_dfs(expression)
|
270 |
+
for node in dag.nodes:
|
271 |
+
if len(dag.adj[node]) == 0:
|
272 |
+
e = dag.nodes[node]["expr"]
|
273 |
+
if isinstance(e, sp.Integer):
|
274 |
+
dag.nodes[node]["label"] = f"{e}.0{csuf}"
|
275 |
+
elif isinstance(e, sp.Rational):
|
276 |
+
dag.nodes[node]["label"] = f"{e.p}.0{csuf}/{e.q}.0{csuf}"
|
277 |
+
elif isinstance(e, sp.Float):
|
278 |
+
dag.nodes[node]["label"] = f"{float(e)}{csuf}"
|
279 |
+
else:
|
280 |
+
dag.nodes[node]["label"] = str(e)
|
281 |
+
else:
|
282 |
+
dag.nodes[node]["label"] = sympy_func_to_op(dag.nodes[node]["expr"].func)
|
283 |
+
|
284 |
+
return dag
|
285 |
+
|
286 |
+
def sympy_to_prefix(expr):
|
287 |
+
trav = []
|
288 |
+
|
289 |
+
def _pre(node):
|
290 |
+
nonlocal trav
|
291 |
+
if isinstance(node, sp.Rational):
|
292 |
+
if node.q != 1:
|
293 |
+
trav.append("div")
|
294 |
+
_pre(node.p)
|
295 |
+
_pre(node.q)
|
296 |
+
else:
|
297 |
+
_pre(node.p)
|
298 |
+
elif isinstance(node, sp.Integer) or isinstance(node, int):
|
299 |
+
v = int(node)
|
300 |
+
if v >= 0:
|
301 |
+
trav.append("INT+")
|
302 |
+
trav.extend(list(str(v)))
|
303 |
+
else:
|
304 |
+
trav.append("INT-")
|
305 |
+
trav.extend(list(str(-v)))
|
306 |
+
elif isinstance(node, sp.Symbol):
|
307 |
+
trav.append(str(node))
|
308 |
+
elif isinstance(node, sp.Mul):
|
309 |
+
mulargs = []
|
310 |
+
divargs = []
|
311 |
+
children = node.args
|
312 |
+
for child in children:
|
313 |
+
if isinstance(child, sp.Pow) and \
|
314 |
+
isinstance(child.args[1], sp.Integer) and child.args[1] == -1:
|
315 |
+
divargs.append(child.args[0])
|
316 |
+
else:
|
317 |
+
mulargs.append(child)
|
318 |
+
if len(divargs) > 0:
|
319 |
+
trav.append("div")
|
320 |
+
if len(mulargs) == 0:
|
321 |
+
trav.append("INT+")
|
322 |
+
trav.append("1")
|
323 |
+
# Insert numerator
|
324 |
+
for i, child in enumerate(mulargs):
|
325 |
+
if i < len(mulargs) - 1:
|
326 |
+
trav.append("mul")
|
327 |
+
_pre(child)
|
328 |
+
# Insert denominator
|
329 |
+
for i, child in enumerate(divargs):
|
330 |
+
if i < len(divargs) - 1:
|
331 |
+
trav.append("mul")
|
332 |
+
_pre(child)
|
333 |
+
elif isinstance(node, sp.Add):
|
334 |
+
addargs = []
|
335 |
+
subargs = []
|
336 |
+
children = node.args
|
337 |
+
for child in children:
|
338 |
+
if isinstance(child, sp.Mul) and len(child.args) == 2 and \
|
339 |
+
isinstance(child.args[1], sp.Integer) and child.args[1] == -1:
|
340 |
+
subargs.append(child.args[0])
|
341 |
+
elif isinstance(child, sp.Mul) and len(child.args) == 2 and \
|
342 |
+
isinstance(child.args[0], sp.Integer) and child.args[0] == -1:
|
343 |
+
subargs.append(child.args[1])
|
344 |
+
else:
|
345 |
+
addargs.append(child)
|
346 |
+
if len(subargs) > 0:
|
347 |
+
trav.append("sub")
|
348 |
+
if len(addargs) == 0:
|
349 |
+
trav.append("INT+")
|
350 |
+
trav.append("0")
|
351 |
+
# Insert numerator
|
352 |
+
for i, child in enumerate(addargs):
|
353 |
+
if i < len(addargs) - 1:
|
354 |
+
trav.append("add")
|
355 |
+
_pre(child)
|
356 |
+
# Insert denominator
|
357 |
+
for i, child in enumerate(subargs):
|
358 |
+
if i < len(subargs) - 1:
|
359 |
+
trav.append("add")
|
360 |
+
_pre(child)
|
361 |
+
elif isinstance(node, sp.Float):
|
362 |
+
rep = sp.nsimplify(node, tolerance=1e-7)
|
363 |
+
if isinstance(rep, sp.Integer):
|
364 |
+
_pre(rep)
|
365 |
+
elif isinstance(rep, sp.Rational) and rep.q <= 16:
|
366 |
+
_pre(rep)
|
367 |
+
else:
|
368 |
+
raise DecodeError(f"Float {node} encountered while generating")
|
369 |
+
# trav.append(str(node))
|
370 |
+
elif node == sp.E or node == sp.pi:
|
371 |
+
# Transcendental constants
|
372 |
+
trav.append(str(node))
|
373 |
+
else:
|
374 |
+
op = sympy_func_to_op(node.func)
|
375 |
+
children = node.args
|
376 |
+
for i, child in enumerate(children):
|
377 |
+
# Insert op repeatedly to maintain binary tree
|
378 |
+
if i == 0 or i < len(children) - 1:
|
379 |
+
trav.append(op)
|
380 |
+
_pre(child)
|
381 |
+
_pre(expr)
|
382 |
+
return trav
|
383 |
+
|
384 |
+
def constant_fold(expr):
|
385 |
+
q = [expr]
|
386 |
+
cidx = 0
|
387 |
+
subsmap = {}
|
388 |
+
constmap = {}
|
389 |
+
|
390 |
+
isconst = lambda e: not any(c.is_symbol for c in e.atoms())
|
391 |
+
|
392 |
+
while len(q) > 0:
|
393 |
+
curr_expr = q.pop(0)
|
394 |
+
if isinstance(curr_expr, sp.Number) or isconst(curr_expr):
|
395 |
+
const_expr = curr_expr.evalf()
|
396 |
+
rep = sp.nsimplify(const_expr, [sp.E, sp.pi], tolerance=1e-7)
|
397 |
+
if isinstance(rep, sp.Integer) or \
|
398 |
+
(isinstance(rep, sp.Rational) and rep.q <= 16) or \
|
399 |
+
rep == sp.E or rep == sp.pi:
|
400 |
+
subsmap[curr_expr] = rep
|
401 |
+
else:
|
402 |
+
val = float(const_expr)
|
403 |
+
found = False
|
404 |
+
for c in constmap:
|
405 |
+
if abs(val - constmap[c]) < 1e-7:
|
406 |
+
subsmap[curr_expr] = sp.Symbol(c)
|
407 |
+
found = True
|
408 |
+
elif abs(1/val - constmap[c]) < 1e-7:
|
409 |
+
subsmap[curr_expr] = 1/sp.Symbol(c)
|
410 |
+
found = True
|
411 |
+
elif abs(-val - constmap[c]) < 1e-7:
|
412 |
+
subsmap[curr_expr] = -sp.Symbol(c)
|
413 |
+
found = True
|
414 |
+
elif abs(-1/val - constmap[c]) < 1e-7:
|
415 |
+
subsmap[curr_expr] = -1/sp.Symbol(c)
|
416 |
+
found = True
|
417 |
+
if not found:
|
418 |
+
subsmap[curr_expr] = sp.Symbol(f"k{cidx}")
|
419 |
+
constmap[f"k{cidx}"] = val
|
420 |
+
cidx += 1
|
421 |
+
else:
|
422 |
+
for child in curr_expr.args:
|
423 |
+
q.append(child)
|
424 |
+
|
425 |
+
return expr.subs(subsmap), constmap
|
426 |
+
|
427 |
+
|
428 |
+
# For testing only
|
429 |
+
if __name__ == "__main__":
|
430 |
+
prefs = "add mul INT- 1 x mul pow ln INT+ 4 INT- 1 add x mul INT- 1 pow x INT+ 5".split(" ")
|
431 |
+
exp = parse_prefix_to_sympy(prefs)
|
432 |
+
exp = sp.simplify(exp)
|
433 |
+
print(exp)
|
434 |
+
print(constant_fold(exp))
|
435 |
+
|
436 |
+
# prefs = "mul x mul pow cos INT+ 4 INT- 3 pow ln INT+ 3 INT- 6".split(" ")
|
437 |
+
# exp = parse_prefix_to_sympy(prefs)
|
438 |
+
# print(exp)
|
439 |
+
# dag = sympy_to_dag(exp)
|
440 |
+
|
441 |
+
# exp = sp.parse_expr("(((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - (-((x0) + (x0)))) / (-((x0) + (x0)))) * ((-((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - ((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))))) * ((((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0))))) * k0) - ((x0) + ((x0) - ((-((x0) + (x0))) / ((x0) + (x0)))))))", evaluate=False)
|
442 |
+
# # print(sympy_to_prefix(exp))
|
443 |
+
|
444 |
+
# simp = sp.simplify(exp)
|
445 |
+
# pre = sympy_to_prefix(simp)
|
446 |
+
# print(pre)
|
447 |
+
# repars = parse_prefix_to_sympy(pre)
|
448 |
+
# print(simp)
|
449 |
+
# print(repars)
|
remend/plot_loss.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from matplotlib import pyplot as plt
|
2 |
+
import json
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
import argparse
|
6 |
+
parser = argparse.ArgumentParser("Plot loss for the training log")
|
7 |
+
parser.add_argument("-t", "--trainlog", required=True, help="Training log file")
|
8 |
+
parser.add_argument("-l", "--loss", help="Loss plot to save (optional)")
|
9 |
+
parser.add_argument("--log-scale", default=False, action="store_true", help="Log scale")
|
10 |
+
parser.add_argument("-P", "--no-plot", default=True, action="store_false", help="Don't open matplotlib figure")
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
train_inner_upd, train_inner_loss = [], []
|
14 |
+
train_upd, train_loss = [], []
|
15 |
+
val_upd, val_loss = [], []
|
16 |
+
|
17 |
+
with open(args.trainlog, "r") as tl:
|
18 |
+
for line in tl:
|
19 |
+
# Filter out json
|
20 |
+
if line[0] != "{":
|
21 |
+
continue
|
22 |
+
try:
|
23 |
+
data = json.loads(line.strip())
|
24 |
+
except:
|
25 |
+
continue
|
26 |
+
if "loss" in data:
|
27 |
+
loss = float(data["loss"])
|
28 |
+
upd = int(data["num_updates"])
|
29 |
+
if len(train_inner_upd) == 0 or train_inner_upd[-1] < upd:
|
30 |
+
train_inner_upd.append(upd)
|
31 |
+
train_inner_loss.append(loss)
|
32 |
+
if "valid_loss" in data:
|
33 |
+
loss = float(data["valid_loss"])
|
34 |
+
upd = int(data["valid_num_updates"])
|
35 |
+
if len(val_upd) == 0 or val_upd[-1] < upd:
|
36 |
+
val_upd.append(upd)
|
37 |
+
val_loss.append(loss)
|
38 |
+
if "train_loss" in data:
|
39 |
+
loss = float(data["train_loss"])
|
40 |
+
upd = int(data["train_num_updates"])
|
41 |
+
if len(train_upd) == 0 or train_upd[-1] < upd:
|
42 |
+
train_upd.append(upd)
|
43 |
+
train_loss.append(loss)
|
44 |
+
|
45 |
+
plt.figure()
|
46 |
+
plt.plot(train_upd, train_loss, "r")
|
47 |
+
plt.plot(val_upd, val_loss, "b")
|
48 |
+
if len(train_inner_upd) > 0:
|
49 |
+
plt.plot(train_inner_upd, train_inner_loss, "r", alpha=0.3)
|
50 |
+
plt.legend(["train", "valid"])
|
51 |
+
if args.log_scale:
|
52 |
+
plt.yscale("log")
|
53 |
+
elif min(min(train_loss), min(val_loss)) < 1:
|
54 |
+
plt.ylim((0, 1))
|
55 |
+
plt.xlabel("Updates")
|
56 |
+
plt.ylabel("Loss")
|
57 |
+
if args.loss:
|
58 |
+
plt.savefig(args.loss)
|
59 |
+
if args.no_plot:
|
60 |
+
plt.show()
|
remend/preprocess_remaqe.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from tqdm import tqdm
|
4 |
+
import itertools as it
|
5 |
+
import sympy as sp
|
6 |
+
|
7 |
+
from .disassemble import DisassemblerARM32
|
8 |
+
from .parser import sympy_to_prefix, isint
|
9 |
+
|
10 |
+
def match_constants(exprconst, asmconst, constsym, eps=1e-5):
|
11 |
+
def _close(a, b):
|
12 |
+
return abs(a - b) <= eps
|
13 |
+
mapping = {}
|
14 |
+
mapped = set()
|
15 |
+
|
16 |
+
for ec in exprconst:
|
17 |
+
ecf = float(exprconst[ec])
|
18 |
+
ecsym = constsym[ec]
|
19 |
+
if abs(ecf) < eps:
|
20 |
+
continue
|
21 |
+
for ac in asmconst:
|
22 |
+
acf = asmconst[ac]
|
23 |
+
acsym = constsym[ac]
|
24 |
+
if _close(acf, ecf):
|
25 |
+
mapping[ecsym] = acsym
|
26 |
+
mapped.add(ec)
|
27 |
+
break
|
28 |
+
if _close(acf, 1/ecf):
|
29 |
+
mapping[ecsym] = 1/acsym
|
30 |
+
mapped.add(ec)
|
31 |
+
break
|
32 |
+
if _close(acf, -ecf):
|
33 |
+
mapping[ecsym] = -acsym
|
34 |
+
mapped.add(ec)
|
35 |
+
break
|
36 |
+
return mapping, mapped
|
37 |
+
|
38 |
+
def replace_naming(pref):
|
39 |
+
ret = []
|
40 |
+
for p in pref:
|
41 |
+
if p == "x0":
|
42 |
+
ret.append("x")
|
43 |
+
elif p[0] == "c" and isint(p[1:]):
|
44 |
+
# Constant
|
45 |
+
ret.append("k"+p[1:])
|
46 |
+
else:
|
47 |
+
ret.append(p)
|
48 |
+
return ret
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
import argparse
|
53 |
+
parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
|
54 |
+
parser.add_argument("--list", required=True)
|
55 |
+
parser.add_argument("--prefix", required=True)
|
56 |
+
args = parser.parse_args()
|
57 |
+
|
58 |
+
with open(args.list, "r") as f:
|
59 |
+
mdllist = list(f)
|
60 |
+
opts = ["O0", "O1", "O2", "O3"]
|
61 |
+
|
62 |
+
asmf = open(args.prefix + ".asm", "w")
|
63 |
+
eqnf = open(args.prefix + ".eqn", "w")
|
64 |
+
constf = open(args.prefix + ".const.jsonl", "w")
|
65 |
+
|
66 |
+
basedir = os.path.dirname(args.list)
|
67 |
+
for mdl in tqdm(mdllist):
|
68 |
+
mdl = mdl.strip()
|
69 |
+
mdlname = os.path.basename(mdl)
|
70 |
+
with open(os.path.join(basedir, mdl, "expressions.json")) as f:
|
71 |
+
expressions = json.load(f)
|
72 |
+
yexpr = expressions["expressions"]["y"]
|
73 |
+
exprconsts = {c: float(expressions["constants"][c]) for c in expressions["constants"]}
|
74 |
+
if len(exprconsts) > 4:
|
75 |
+
continue
|
76 |
+
yexpr = sp.parse_expr(yexpr)
|
77 |
+
exprconstsym = {c: sp.Symbol(c) for c in expressions["constants"]}
|
78 |
+
|
79 |
+
for opt in opts:
|
80 |
+
funcname = f"{mdlname}_run"
|
81 |
+
binf = os.path.join(basedir, mdl, opt, f"c_bin.elf")
|
82 |
+
D = DisassemblerARM32(binf)
|
83 |
+
diss = D.disassemble(funcname)
|
84 |
+
constants = D.constants
|
85 |
+
if len(constants) > 3:
|
86 |
+
continue
|
87 |
+
|
88 |
+
exprconstsym.update({c: sp.Symbol(f"c{c}") for c in constants})
|
89 |
+
mapping, mapped = match_constants(exprconsts, constants, exprconstsym)
|
90 |
+
if len(mapped) != len(constants):
|
91 |
+
continue
|
92 |
+
|
93 |
+
exprsubs = yexpr.subs(mapping)
|
94 |
+
exprprefix = replace_naming(sympy_to_prefix(exprsubs))
|
95 |
+
|
96 |
+
asmf.write(diss + "\n")
|
97 |
+
eqnf.write(" ".join(exprprefix) + "\n")
|
98 |
+
constf.write(json.dumps(constants) + "\n")
|
99 |
+
|
100 |
+
asmf.close()
|
101 |
+
eqnf.close()
|
102 |
+
constf.close()
|
remend/util.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from contextlib import contextmanager
|
2 |
+
import signal
|
3 |
+
import sympy as sp
|
4 |
+
|
5 |
+
def timeout_handler(signum, frame):
|
6 |
+
raise TimeoutError("Block timed out")
|
7 |
+
@contextmanager
|
8 |
+
def timeout(duration):
|
9 |
+
signal.signal(signal.SIGALRM, timeout_handler)
|
10 |
+
signal.alarm(duration)
|
11 |
+
try:
|
12 |
+
yield
|
13 |
+
finally:
|
14 |
+
signal.alarm(0)
|
15 |
+
|
16 |
+
class DecodeError(Exception):
|
17 |
+
pass
|
18 |
+
|
19 |
+
def sympy_expr_ok(expr):
|
20 |
+
atoms = expr.atoms()
|
21 |
+
return not (sp.I in atoms or sp.oo in atoms or sp.zoo in atoms or sp.nan in atoms)
|