REMEND / remend /compile_dataset.py
udiboy1209's picture
Add REMEND python module
7145fd6
raw
history blame
7.31 kB
from tqdm import tqdm
import random
import sympy as sp
import json
import subprocess as sproc
from os.path import realpath, dirname, join as pjoin
from os import makedirs
import multiprocessing as mp
from time import sleep
import logging
from .implementation import Implementor
from .parser import parse_prefix_to_sympy, sympy_to_prefix, constant_fold
from .disassemble import DisassemblerARM32, DisassemblerAArch64, DisassemblerX64
from .util import DecodeError, timeout, sympy_expr_ok
# Directory that holds this module (symlinks resolved); the compile helper
# script is looked up next to it.
_MODULE_DIR = dirname(realpath(__file__))
SCRIPT = pjoin(_MODULE_DIR, "compile_eqn.sh")

# Marker object put on a worker queue to tell the consumer to stop.
QUEUE_END = "QUEUE_END_SENTINEL"
def compile_c(code, elf, arch="arm32", src="/tmp/myfunc.c", opt=0):
    """Write C source to ``src`` and build it through the helper script.

    Args:
        code: Source text to write to ``src``.
        elf: Path handed to the script as its output argument.
        arch: Architecture name; ``"-c"`` is appended to form the script's
            target argument.
        src: Path the source text is written to before the script runs.
        opt: Optimization level, forwarded as ``-O<opt>``.

    Raises:
        DecodeError: If the script exits with a non-zero status.
    """
    with open(src, "w") as source_file:
        source_file.write(code)

    command = ["bash", "-e", SCRIPT, arch + "-c", src, elf, f"-O{opt}"]
    # Output is captured (and discarded) so it does not clutter the console.
    result = sproc.run(command, capture_output=True)
    if result.returncode:
        raise DecodeError("compile failed")
def compile_fortran(code, elf, arch="arm32", src="/tmp/myfunc.f95", opt=0):
    """Write Fortran source to ``src`` and build it through the helper script.

    Args:
        code: Source text to write to ``src``.
        elf: Path handed to the script as its output argument.
        arch: Architecture name; ``"-fortran"`` is appended to form the
            script's target argument.
        src: Path the source text is written to before the script runs.
        opt: Optimization level, forwarded as ``-O<opt>``.

    Raises:
        DecodeError: If the script exits with a non-zero status.
    """
    with open(src, "w") as source_file:
        source_file.write(code)

    command = ["bash", "-e", SCRIPT, arch + "-fortran", src, elf, f"-O{opt}"]
    # Output is captured (and discarded) so it does not clutter the console.
    result = sproc.run(command, capture_output=True)
    if result.returncode:
        raise DecodeError("compile failed")
class EquationCompiler:
    """Queue consumer that compiles equations and records the results.

    One instance handles a single (implementation, optimization level)
    combination for one target architecture. ``run`` is intended to be used
    as the target of a worker process: it pulls work items from ``q`` until
    it sees ``QUEUE_END``.
    """

    def __init__(self, q, arch, impl, opt, outdir, prefix, dtype="double"):
        """Select the compiler and disassembler and store the configuration.

        Args:
            q: Queue yielding ``(n, expr, expr_const, pref)`` tuples and,
                finally, ``QUEUE_END``.
            arch: One of ``"arm32"``, ``"aarch64"`` or ``"x64"``.
            impl: Implementation name; names containing ``"fortran"`` are
                compiled with ``compile_fortran``, all others with
                ``compile_c``.
            opt: Optimization level forwarded to the compiler.
            outdir: Root directory for the output files.
            prefix: Base name of the output files (also used in temp paths).
            dtype: Datatype forwarded to ``Implementor``.

        Raises:
            DecodeError: If ``arch`` is not one of the supported values.
        """
        if "fortran" in impl:
            self.compiler = compile_fortran
        else:
            self.compiler = compile_c

        if arch == "arm32":
            self.disassembler = DisassemblerARM32
        elif arch == "aarch64":
            self.disassembler = DisassemblerAArch64
        elif arch == "x64":
            self.disassembler = DisassemblerX64
        else:
            raise DecodeError("arch not supported: " + arch)

        self.q = q
        self.impl = impl
        self.opt = opt
        self.outdir = outdir
        self.prefix = prefix
        self.dtype = dtype
        self.arch = arch

    def run(self):
        """Consume the queue and write one record per successful equation.

        Output goes to ``<outdir>/O<opt>/<impl>/<prefix>.{asm,eqn,src,
        const.jsonl,error}``. For every item that compiles and disassembles,
        one entry is appended to each of the asm/eqn/src/const files; items
        that fail are reported in the ``.error`` file and skipped.
        """
        outdir = pjoin(self.outdir, f"O{self.opt}", self.impl)
        makedirs(outdir, exist_ok=True)

        # Per-worker scratch paths so that parallel workers do not overwrite
        # each other's source/ELF files.
        tmpsrc = f"/tmp/myfunc_{self.impl}_{self.opt}_{self.prefix}"
        if "fortran" in self.impl:
            tmpsrc += ".f95"
            # NOTE(review): presumably the Fortran toolchain appends an
            # underscore to the symbol name -- confirm against compile_eqn.sh.
            func = "myfunc_"
        else:
            tmpsrc += ".c"
            func = "myfunc"
        tmpelf = f"/tmp/myfunc_{self.arch}_{self.impl}_{self.opt}_{self.prefix}.elf"

        suffixes = (
            ("asm", ".asm"),
            ("eqn", ".eqn"),
            ("src", ".src"),
            ("const", ".const.jsonl"),
            ("err", ".error"),
        )
        outfiles = {}
        try:
            for key, suffix in suffixes:
                outfiles[key] = open(pjoin(outdir, self.prefix + suffix), "w")

            # Index of the next record among the successfully written ones.
            line_no = 0
            while True:
                data = self.q.get()
                if data == QUEUE_END:
                    # Queue is closed, break from inf loop
                    break

                n, expr, expr_const, pref = data
                impl = Implementor(expr, constants=expr_const, dtype=self.dtype)
                try:
                    code = impl.implement(self.impl)
                    self.compiler(code, tmpelf, arch=self.arch, src=tmpsrc, opt=self.opt)
                    disasm = self.disassembler(tmpelf, expr_constants=expr_const,
                                               match_constants=True)
                    asm = disasm.disassemble(func)
                    if len(disasm.constants) < len(expr_const):
                        print(n, "constants not identified", disasm.constants, expr_const,
                              file=outfiles["err"])
                        continue
                except DecodeError as e:
                    print(n, "impl error", e, expr, expr_const, pref, file=outfiles["err"])
                    continue

                outfiles["asm"].write(asm + "\n")
                outfiles["eqn"].write(pref + "\n")
                outfiles["src"].write(f"==== pick={n} line={line_no} ====\n" + code + "\n")
                outfiles["const"].write(json.dumps(expr_const) + "\n")
                line_no += 1
        finally:
            # Always close (and thereby flush) the outputs: an unexpected
            # exception in a worker process may otherwise end the process
            # without writing out buffered records.
            for handle in outfiles.values():
                handle.close()
if __name__ == "__main__":
    # Script entry point: read prefix-notation equations from a file, sample
    # and simplify them, and fan each accepted equation out to one worker
    # process per (implementation, optimization level) combination.
    import argparse

    parser = argparse.ArgumentParser("Compile prefix to asm->eqn dataset")
    parser.add_argument("-f", "--file", required=True, help="Input file")
    parser.add_argument("--outdir", required=True, help="Output directory")
    parser.add_argument("--prefix", required=True, help="File prefix")
    parser.add_argument("--impl", nargs="+", required=True,
                        choices=["dag_c", "cse_c", "dag_fortran", "cse_fortran"])
    parser.add_argument("--pick", type=float, required=True,
                        help="Ratio of samples to pick (0 to 1)")
    parser.add_argument("--start", type=int, default=0, help="Start from index")
    parser.add_argument("--count", type=int, default=0, help="Process only these many")
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--min-tokens", help="Minimum tokens in equations", type=int, default=5)
    parser.add_argument("--min-ops", help="Minimum ops in equations", type=int, default=5)
    parser.add_argument("--dtype", help="Implementation datatype", type=str,
                        choices=["double", "float"], default="double")
    parser.add_argument("--arch", help="Target architecture", type=str,
                        choices=["arm32", "aarch64", "x64"], default="arm32")
    parser.add_argument("-O", "--opt", nargs="+", type=int, choices=[0, 1, 2, 3], default=[0],
                        help="Optimization level (s)")

    # Don't show warnings
    logging.getLogger("cle").setLevel(logging.ERROR)

    args = parser.parse_args()
    random.seed(args.seed)

    # One compiler (with its own queue) and one process per impl/opt pair.
    eqcompilers = [EquationCompiler(mp.Queue(), args.arch, impl, opt, args.outdir, args.prefix, dtype=args.dtype)
                   for impl in args.impl
                   for opt in args.opt]
    pool = [mp.Process(target=eqc.run, args=()) for eqc in eqcompilers]
    for proc in pool:
        proc.start()

    count = 0
    try:
        with open(args.file, "r") as prefixf:
            for n, line in tqdm(enumerate(prefixf), desc="Parsing file"):
                # Skip for start lines and with some probability
                if n < args.start or random.random() > args.pick:
                    continue

                comps = line.strip().split("\t")
                # Keep the text that follows the "Y'" marker (and one more
                # character) in the first tab-separated column. If the marker
                # is absent, find() returns -1 and the slice starts at index 2.
                pref = comps[0][comps[0].find("Y'")+3:]
                prefl = pref.split(" ")
                if len(prefl) < args.min_tokens:
                    continue

                try:
                    expr = parse_prefix_to_sympy(prefl)
                    with timeout(10):
                        expr = sp.simplify(expr)
                    if not sympy_expr_ok(expr):
                        # Simplified is bad
                        continue
                    expr, expr_const = constant_fold(expr)
                    pref = " ".join(sympy_to_prefix(expr))
                except Exception:
                    # Best effort: equations that cannot be parsed, simplified
                    # or folded are dropped. KeyboardInterrupt/SystemExit are
                    # deliberately not caught so the run can be aborted.
                    # NOTE(review): assumes `timeout` raises an Exception
                    # subclass on expiry -- confirm in .util.
                    continue

                if sp.count_ops(expr) < args.min_ops:
                    continue

                for eqc, proc in zip(eqcompilers, pool):
                    # Poll on this queue to get empty
                    while eqc.q.qsize() > 5:
                        if not proc.is_alive():
                            # A dead consumer never drains its queue; fail
                            # instead of polling forever.
                            raise RuntimeError(
                                f"worker for impl={eqc.impl} opt={eqc.opt} "
                                "exited with items still queued")
                        sleep(1)
                    eqc.q.put((n, expr, expr_const, pref))

                count += 1
                if args.count > 0 and count >= args.count:
                    break
    finally:
        # Close queues -- also on error, so that workers blocked in q.get()
        # terminate and close their output files.
        for eqc in eqcompilers:
            eqc.q.put(QUEUE_END)
        for proc in pool:
            proc.join()