"""Compile prefix-notation equations into an assembly -> equation dataset.

Each selected input equation is parsed, simplified with sympy and
constant-folded, then handed to one worker process per
(implementation, optimization level) pair. A worker implements the equation as
C or Fortran source, compiles it for the target architecture, disassembles the
result and writes one entry per successful sample to its output files.
"""
import argparse
import contextlib
import json
import logging
import multiprocessing as mp
import queue
import random
import subprocess as sproc
from os import makedirs
from os.path import dirname, join as pjoin, realpath
from time import sleep  # noqa: F401

import sympy as sp
from tqdm import tqdm

from .implementation import Implementor
from .parser import parse_prefix_to_sympy, sympy_to_prefix, constant_fold
from .disassemble import DisassemblerARM32, DisassemblerAArch64, DisassemblerX64
from .util import DecodeError, timeout, sympy_expr_ok

SCRIPT = pjoin(dirname(realpath(__file__)), "compile_eqn.sh")
QUEUE_END = "QUEUE_END_SENTINEL"

# Maximum number of pending equations per worker queue. The producer blocks
# on a full queue instead of busy-polling ``qsize()``, which raises
# NotImplementedError on platforms such as macOS.
QUEUE_MAXSIZE = 6

# Target architecture name -> disassembler class.
_DISASSEMBLERS = {
    "arm32": DisassemblerARM32,
    "aarch64": DisassemblerAArch64,
    "x64": DisassemblerX64,
}


def _compile(code, elf, arch, lang, src, opt):
    """Write *code* to *src* and build it into *elf* with the compile script.

    The script is invoked as ``bash -e SCRIPT {arch}-{lang} src elf -O{opt}``.

    Raises:
        DecodeError: if the compile script exits with a non-zero status.
    """
    with open(src, "w") as f:
        f.write(code)
    ret = sproc.run(
        ["bash", "-e", SCRIPT, f"{arch}-{lang}", src, elf, f"-O{opt}"],
        capture_output=True,
    )
    if ret.returncode != 0:
        raise DecodeError("compile failed")


def compile_c(code, elf, arch="arm32", src="/tmp/myfunc.c", opt=0):
    """Compile C source *code* into *elf*; raises DecodeError on failure."""
    _compile(code, elf, arch, "c", src, opt)


def compile_fortran(code, elf, arch="arm32", src="/tmp/myfunc.f95", opt=0):
    """Compile Fortran source *code* into *elf*; raises DecodeError on failure."""
    _compile(code, elf, arch, "fortran", src, opt)


class EquationCompiler:
    """Worker that turns queued equations into compiled/disassembled samples.

    Args:
        q: queue of ``(n, expr, expr_const, pref)`` tuples, terminated by
            ``QUEUE_END``.
        arch: target architecture ("arm32", "aarch64" or "x64").
        impl: implementation name passed to ``Implementor.implement``; names
            containing "fortran" are compiled as Fortran, all others as C.
        opt: optimization level, passed to the compiler as ``-O{opt}``.
        outdir: root output directory; files go to ``outdir/O{opt}/{impl}``.
        prefix: file-name prefix of the output files.
        dtype: implementation datatype passed to ``Implementor``.

    Raises:
        DecodeError: if *arch* is not supported.
    """

    def __init__(self, q, arch, impl, opt, outdir, prefix, dtype="double"):
        self.compiler = compile_fortran if "fortran" in impl else compile_c
        if arch not in _DISASSEMBLERS:
            raise DecodeError("arch not supported: " + arch)
        self.disassembler = _DISASSEMBLERS[arch]
        self.q = q
        self.impl = impl
        self.opt = opt
        self.outdir = outdir
        self.prefix = prefix
        self.dtype = dtype
        self.arch = arch

    def run(self):
        """Consume equations from ``self.q`` until ``QUEUE_END`` arrives.

        For each item that implements, compiles and disassembles successfully,
        and whose constants are all identified in the disassembly, one entry
        is written to each of the ``.asm``, ``.eqn``, ``.src`` and
        ``.const.jsonl`` files. Failures are recorded in the ``.error`` file
        and skipped. A failing item never stops the worker, because a dead
        worker would leave the producer with no consumer for this queue.
        """
        outdir = pjoin(self.outdir, f"O{self.opt}", self.impl)
        makedirs(outdir, exist_ok=True)

        tmpsrc = f"/tmp/myfunc_{self.impl}_{self.opt}_{self.prefix}"
        if "fortran" in self.impl:
            tmpsrc += ".f95"
            # Presumably gfortran's trailing-underscore symbol mangling.
            func = "myfunc_"
        else:
            tmpsrc += ".c"
            func = "myfunc"
        tmpelf = f"/tmp/myfunc_{self.arch}_{self.impl}_{self.opt}_{self.prefix}.elf"

        suffixes = {
            "asm": ".asm",
            "eqn": ".eqn",
            "src": ".src",
            "const": ".const.jsonl",
            "err": ".error",
        }
        # ExitStack guarantees every output file is closed (and flushed) even
        # if something unexpected escapes the loop.
        with contextlib.ExitStack() as stack:
            out = {
                key: stack.enter_context(open(pjoin(outdir, self.prefix + suffix), "w"))
                for key, suffix in suffixes.items()
            }
            line_no = 0  # index of the next entry in the output files
            while True:
                data = self.q.get()
                if data == QUEUE_END:
                    # The producer is done; stop the worker.
                    break
                n, expr, expr_const, pref = data
                try:
                    implementor = Implementor(expr, constants=expr_const, dtype=self.dtype)
                    code = implementor.implement(self.impl)
                    self.compiler(code, tmpelf, arch=self.arch, src=tmpsrc, opt=self.opt)
                    disasm = self.disassembler(tmpelf, expr_constants=expr_const,
                                               match_constants=True)
                    asm = disasm.disassemble(func)
                except DecodeError as e:
                    print(n, "impl error", e, expr, expr_const, pref, file=out["err"])
                    continue
                except Exception as e:  # per-item boundary: log and keep working
                    print(n, "unexpected error", repr(e), expr, expr_const, pref,
                          file=out["err"])
                    continue
                if len(disasm.constants) < len(expr_const):
                    print(n, "constants not identified", disasm.constants, expr_const,
                          file=out["err"])
                    continue

                out["asm"].write(asm + "\n")
                out["eqn"].write(pref + "\n")
                out["src"].write(f"==== pick={n} line={line_no} ====\n" + code + "\n")
                out["const"].write(json.dumps(expr_const) + "\n")
                line_no += 1


def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser("Compile prefix to asm->eqn dataset")
    parser.add_argument("-f", "--file", required=True, help="Input file")
    parser.add_argument("--outdir", required=True, help="Output directory")
    parser.add_argument("--prefix", required=True, help="File prefix")
    parser.add_argument("--impl", nargs="+", required=True,
                        choices=["dag_c", "cse_c", "dag_fortran", "cse_fortran"])
    parser.add_argument("--pick", type=float, required=True,
                        help="Ratio of samples to pick (0 to 1)")
    parser.add_argument("--start", type=int, default=0, help="Start from index")
    parser.add_argument("--count", type=int, default=0, help="Process only these many")
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--min-tokens", help="Minimum tokens in equations", type=int, default=5)
    parser.add_argument("--min-ops", help="Minimum ops in equations", type=int, default=5)
    parser.add_argument("--dtype", help="Implementation datatype", type=str,
                        choices=["double", "float"], default="double")
    parser.add_argument("--arch", help="Target architecture", type=str,
                        choices=["arm32", "aarch64", "x64"], default="arm32")
    parser.add_argument("-O", "--opt", nargs="+", type=int, choices=[0, 1, 2, 3],
                        default=[0], help="Optimization level (s)")
    return parser.parse_args()


def _prepare_equation(line, min_tokens, min_ops):
    """Parse, simplify and constant-fold one tab-separated input line.

    The prefix expression is taken from the first tab-separated field,
    starting 3 characters after the ``Y'`` marker.

    Returns:
        ``(expr, expr_const, pref)``, or ``None`` when the line should be
        skipped: no marker, too few tokens, a parse/simplify failure, a
        simplified expression rejected by ``sympy_expr_ok``, or too few ops.
    """
    comps = line.strip().split("\t")
    marker = comps[0].find("Y'")
    if marker < 0:
        # Without the marker the slice below would silently produce garbage.
        return None
    pref = comps[0][marker + 3:]
    prefl = pref.split(" ")
    if len(prefl) < min_tokens:
        return None
    try:
        expr = parse_prefix_to_sympy(prefl)
        with timeout(10):
            expr = sp.simplify(expr)
        if not sympy_expr_ok(expr):
            # The simplified expression is unusable.
            return None
        expr, expr_const = constant_fold(expr)
        pref = " ".join(sympy_to_prefix(expr))
    except Exception:
        # Best-effort: skip unparsable or unsimplifiable lines.
        # NOTE(review): assumes util.timeout raises an Exception subclass.
        return None
    if sp.count_ops(expr) < min_ops:
        return None
    return expr, expr_const, pref


def _put_while_alive(q, proc, item, poll_interval=1.0):
    """Put *item* on *q*, waiting while the queue is full.

    Returns True once the item is enqueued. Returns False if *proc* (the
    queue's consumer) has exited, since then the queue would never drain and
    waiting would hang forever.
    """
    while proc.is_alive():
        try:
            q.put(item, timeout=poll_interval)
            return True
        except queue.Full:
            pass
    return False


def main():
    """Read the input file, filter equations and feed every worker."""
    args = _parse_args()
    # Silence non-error log output from the "cle" logger.
    logging.getLogger("cle").setLevel(logging.ERROR)
    random.seed(args.seed)

    eqcompilers = [
        EquationCompiler(mp.Queue(maxsize=QUEUE_MAXSIZE), args.arch, impl, opt,
                         args.outdir, args.prefix, dtype=args.dtype)
        for impl in args.impl
        for opt in args.opt
    ]
    pool = [mp.Process(target=eqc.run, args=()) for eqc in eqcompilers]
    for proc in pool:
        proc.start()

    count = 0
    with open(args.file, "r") as prefixf:
        for n, line in tqdm(enumerate(prefixf), desc="Parsing file"):
            # Skip lines before --start. Keep each remaining line with
            # probability --pick (random() is only drawn for n >= start).
            if n < args.start or random.random() > args.pick:
                continue
            prepared = _prepare_equation(line, args.min_tokens, args.min_ops)
            if prepared is None:
                continue
            item = (n,) + prepared
            for eqc, proc in zip(eqcompilers, pool):
                _put_while_alive(eqc.q, proc, item)
            count += 1
            if args.count > 0 and count >= args.count:
                break

    # Tell every still-running worker to finish, then wait for all of them.
    for eqc, proc in zip(eqcompilers, pool):
        _put_while_alive(eqc.q, proc, QUEUE_END)
    for proc in pool:
        proc.join()


if __name__ == "__main__":
    main()