REMEND / remend /preprocess_remaqe.py
udiboy1209's picture
Add REMEND python module
7145fd6
import os
import json
from tqdm import tqdm
import itertools as it
import sympy as sp
from .disassemble import DisassemblerARM32
from .parser import sympy_to_prefix, isint
def match_constants(exprconst, asmconst, constsym, eps=1e-5):
    """Pair expression constants with constants found in the assembly.

    An expression constant is matched to an assembly constant whose value
    equals it, its reciprocal, or its negation, within absolute tolerance
    ``eps``. Relations are tried in that order of preference across *all*
    assembly constants. An exact-value match anywhere therefore beats a
    reciprocal match, which beats a negated match. Within one relation, the
    first matching assembly constant in iteration order wins.

    Expression constants with ``abs(value) < eps`` are skipped. Several
    expression constants may map to the same assembly constant; no
    exclusivity is enforced.

    Args:
        exprconst: dict of expression-constant key -> value (any value that
            ``float()`` accepts).
        asmconst: dict of assembly-constant key -> numeric value.
        constsym: dict mapping every key of ``exprconst`` and ``asmconst``
            to the symbol that stands for it.
        eps: absolute tolerance for the equality tests and the "zero"
            threshold.

    Returns:
        ``(mapping, mapped)``. ``mapping`` maps each matched expression
        symbol to ``asm_sym``, ``1/asm_sym`` or ``-asm_sym``. ``mapped`` is
        the set of ``exprconst`` keys that were matched.
    """
    def _close(a, b):
        return abs(a - b) <= eps

    no_match = object()  # sentinel: asm keys may be any hashable, even None
    mapping = {}
    mapped = set()
    for ec in exprconst:
        ecf = float(exprconst[ec])
        ecsym = constsym[ec]
        if abs(ecf) < eps:
            # ~0 has no usable reciprocal and cannot be told apart reliably.
            continue
        # (value to look for in the asm, how to express ec via the asm symbol),
        # in order of preference.
        relations = (
            (ecf, lambda sym: sym),
            (1 / ecf, lambda sym: 1 / sym),
            (-ecf, lambda sym: -sym),
        )
        for target, express in relations:
            ac = next(
                (k for k in asmconst if _close(asmconst[k], target)),
                no_match,
            )
            if ac is not no_match:
                mapping[ecsym] = express(constsym[ac])
                mapped.add(ec)
                break
    return mapping, mapped
def replace_naming(pref):
    """Return a copy of the prefix token list with variables/constants renamed.

    ``"x0"`` becomes ``"x"``. A token of the form ``"c<digits>"`` (as
    judged by ``isint`` on the part after the ``c``) becomes
    ``"k<digits>"``. Every other token is passed through unchanged.
    """
    def _rename(tok):
        if tok == "x0":
            return "x"
        is_const = tok[0] == "c" and isint(tok[1:])
        # Constant placeholder: swap the "c" prefix for "k".
        return "k" + tok[1:] if is_const else tok

    return [_rename(tok) for tok in pref]
if __name__ == "__main__":
    import argparse

    # Samples with more constants than these limits are skipped entirely.
    MAX_EXPR_CONSTS = 4
    MAX_ASM_CONSTS = 3

    parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
    parser.add_argument("--list", required=True)
    parser.add_argument("--prefix", required=True)
    args = parser.parse_args()

    # One model directory per line, relative to the list file's directory.
    # Blank lines are ignored rather than resolving to the base directory.
    with open(args.list, "r") as f:
        mdllist = [m for m in (line.strip() for line in f) if m]

    opts = ["O0", "O1", "O2", "O3"]
    basedir = os.path.dirname(args.list)

    # `with` guarantees the three parallel output files are flushed and closed
    # even if processing a model raises part-way through.
    with open(args.prefix + ".asm", "w") as asmf, \
            open(args.prefix + ".eqn", "w") as eqnf, \
            open(args.prefix + ".const.jsonl", "w") as constf:
        for mdl in tqdm(mdllist):
            # normpath drops a trailing separator so basename is never "".
            mdlname = os.path.basename(os.path.normpath(mdl))
            funcname = f"{mdlname}_run"
            with open(os.path.join(basedir, mdl, "expressions.json")) as f:
                expressions = json.load(f)
            ystr = expressions["expressions"]["y"]
            exprconsts = {c: float(v) for c, v in expressions["constants"].items()}
            if len(exprconsts) > MAX_EXPR_CONSTS:
                continue
            # NOTE(review): sp.parse_expr evaluates its input via eval(); the
            # dataset's expressions.json is assumed to be trusted.
            yexpr = sp.parse_expr(ystr)
            exprconstsym = {c: sp.Symbol(c) for c in exprconsts}
            for opt in opts:
                binf = os.path.join(basedir, mdl, opt, "c_bin.elf")
                D = DisassemblerARM32(binf)
                diss = D.disassemble(funcname)
                constants = D.constants
                if len(constants) > MAX_ASM_CONSTS:
                    continue
                # Build a fresh symbol table per optimisation level so the
                # asm-constant symbols of one level never leak into the next
                # (or overwrite an expression constant's symbol for later levels).
                constsym = dict(exprconstsym)
                constsym.update({c: sp.Symbol(f"c{c}") for c in constants})
                mapping, mapped = match_constants(exprconsts, constants, constsym)
                # Keep the sample only if as many expression constants were
                # matched as there are constants in the assembly.
                if len(mapped) != len(constants):
                    continue
                exprprefix = replace_naming(sympy_to_prefix(yexpr.subs(mapping)))
                asmf.write(diss + "\n")
                eqnf.write(" ".join(exprprefix) + "\n")
                constf.write(json.dumps(constants) + "\n")