# REMEND / remend / preprocess_remaqe.py
# udiboy1209 — "Add REMEND python module" (commit 7145fd6)
# [raw] [history] [blame] 3.16 kB
import os
import json
from tqdm import tqdm
import itertools as it
import sympy as sp
from .disassemble import DisassemblerARM32
from .parser import sympy_to_prefix, isint
def match_constants(exprconst, asmconst, constsym, eps=1e-5):
    """Relate each expression constant to an assembly constant by value.

    For every expression constant with ``abs(value) >= eps``, the assembly
    constants are scanned in iteration order. The first one whose value
    equals the expression constant, its reciprocal, or its negation wins.
    Each relation uses an absolute tolerance ``eps`` and they are tried in
    that order. The expression constant's symbol is then mapped to the
    assembly symbol, ``1/symbol`` or ``-symbol`` respectively.

    Args:
        exprconst: mapping of expression-constant name -> value (anything
            ``float()`` accepts).
        asmconst: mapping of assembly-constant name -> numeric value.
        constsym: mapping from every name in both dicts to its symbol.
        eps: absolute tolerance. Expression constants with ``abs(v) < eps``
            are skipped.

    Returns:
        ``(mapping, mapped)``. ``mapping`` is {expression symbol: expression
        in an assembly symbol}, suitable for ``Expr.subs``. ``mapped`` is the
        set of expression-constant names that found a match.
    """
    def _relate(asm_value, expr_value, asm_symbol):
        # Try the candidate relations in priority order. Return
        # (True, expression-in-asm_symbol) on the first hit, else (False, None).
        if abs(asm_value - expr_value) <= eps:
            return True, asm_symbol
        if abs(asm_value - 1/expr_value) <= eps:
            return True, 1/asm_symbol
        if abs(asm_value - (-expr_value)) <= eps:
            return True, -asm_symbol
        return False, None

    mapping = {}
    matched_names = set()
    for name in exprconst:
        value = float(exprconst[name])
        symbol = constsym[name]
        # Near-zero constants are never matched; for an exact 0 the
        # reciprocal check would divide by zero.
        if abs(value) < eps:
            continue
        for asm_name in asmconst:
            hit, replacement = _relate(asmconst[asm_name], value, constsym[asm_name])
            if hit:
                mapping[symbol] = replacement
                matched_names.add(name)
                break
    return mapping, matched_names
def replace_naming(pref):
    """Rename tokens of a prefix-notation token sequence.

    The variable token ``"x0"`` becomes ``"x"``. A token of the form ``c<N>``
    (``"c"`` followed by a suffix that ``isint`` accepts) becomes ``k<N>``.
    Every other token is passed through unchanged.

    Returns a new list; ``pref`` is only iterated, never modified.
    """
    def _rename(token):
        if token == "x0":
            return "x"
        # Constant token: keep its numeric suffix, switch the prefix letter.
        if token[0] == "c" and isint(token[1:]):
            return "k" + token[1:]
        return token

    return [_rename(token) for token in pref]
def _main():
    """Dump aligned (assembly, equation, constants) samples for a model list.

    Command-line arguments:
        --list:   text file with one model directory per line, relative to
                  the directory containing the list file.
        --prefix: output path prefix. Writes ``<prefix>.asm``,
                  ``<prefix>.eqn`` and ``<prefix>.const.jsonl``.

    Each model is processed at optimization levels O0-O3:

    * The function ``<model>_run`` is disassembled from
      ``<model>/<opt>/c_bin.elf``.
    * Its constants are matched against the model's expression constants
      (see ``match_constants``).
    * When the match passes, one line is appended to each output file: the
      disassembly, the space-separated prefix form of the substituted
      equation, and the assembly constants as JSON.

    Models with more than 4 expression constants are skipped, as are
    binaries with more than 3 assembly constants.
    """
    import argparse

    parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
    parser.add_argument("--list", required=True)
    parser.add_argument("--prefix", required=True)
    args = parser.parse_args()

    with open(args.list, "r", encoding="utf-8") as f:
        # Strip whitespace and drop blank lines. An empty entry would make
        # os.path.join resolve to basedir itself and fail on a missing file.
        mdllist = [line.strip() for line in f if line.strip()]

    opts = ["O0", "O1", "O2", "O3"]
    basedir = os.path.dirname(args.list)

    # `with` guarantees all three outputs are flushed and closed even if a
    # model fails to load or disassemble partway through the run.
    with open(args.prefix + ".asm", "w", encoding="utf-8") as asmf, \
            open(args.prefix + ".eqn", "w", encoding="utf-8") as eqnf, \
            open(args.prefix + ".const.jsonl", "w", encoding="utf-8") as constf:
        for mdl in tqdm(mdllist):
            # normpath drops a trailing separator so basename is not "".
            mdlname = os.path.basename(os.path.normpath(mdl))
            with open(os.path.join(basedir, mdl, "expressions.json"), encoding="utf-8") as f:
                expressions = json.load(f)
            exprconsts = {c: float(v) for c, v in expressions["constants"].items()}
            if len(exprconsts) > 4:
                continue
            # NOTE(review): sympy.parse_expr evaluates its input with eval();
            # only run this script on trusted dataset files.
            yexpr = sp.parse_expr(expressions["expressions"]["y"])
            exprconstsym = {c: sp.Symbol(c) for c in exprconsts}
            funcname = f"{mdlname}_run"
            for opt in opts:
                binf = os.path.join(basedir, mdl, opt, "c_bin.elf")
                D = DisassemblerARM32(binf)
                diss = D.disassemble(funcname)
                # presumably a dict of asm-constant name -> value (it is
                # iterated, subscripted and JSON-dumped below) — TODO confirm
                constants = D.constants
                if len(constants) > 3:
                    continue
                # Fresh symbol table per binary, so asm symbols from another
                # optimization level never leak into this match.
                constsym = dict(exprconstsym)
                constsym.update({c: sp.Symbol(f"c{c}") for c in constants})
                mapping, mapped = match_constants(exprconsts, constants, constsym)
                # Keep the sample only when the count of matched *expression*
                # constants equals the count of asm constants. Presumably this
                # checks that every asm constant is accounted for — TODO confirm.
                if len(mapped) != len(constants):
                    continue
                exprsubs = yexpr.subs(mapping)
                exprprefix = replace_naming(sympy_to_prefix(exprsubs))
                asmf.write(diss + "\n")
                eqnf.write(" ".join(exprprefix) + "\n")
                constf.write(json.dumps(constants) + "\n")


if __name__ == "__main__":
    _main()