import json
import logging
import multiprocessing as mp
import random
import subprocess as sproc
from contextlib import ExitStack
from os import makedirs
from os.path import realpath, dirname, join as pjoin
from time import sleep

import sympy as sp
from tqdm import tqdm

from .implementation import Implementor
from .parser import parse_prefix_to_sympy, sympy_to_prefix, constant_fold
from .disassemble import DisassemblerARM32, DisassemblerAArch64, DisassemblerX64
from .util import DecodeError, timeout, sympy_expr_ok
|
|
|
# Absolute path of the compile_eqn.sh helper, resolved relative to the
# directory that contains this module (symlinks resolved by realpath).
SCRIPT = pjoin(dirname(realpath(__file__)), "compile_eqn.sh")

# Sentinel placed on a work queue to tell EquationCompiler.run that no more
# items will follow and it should stop.
QUEUE_END = "QUEUE_END_SENTINEL"
|
|
|
def compile_c(code, elf, arch="arm32", src="/tmp/myfunc.c", opt=0):
    """Write C source to disk and build it with the helper script.

    Args:
        code: C source text to compile.
        elf: Path handed to the script as the output file.
        arch: Target architecture; ``"-c"`` is appended to form the script's
            first argument.
        src: Path the source text is written to before compiling.
        opt: Optimisation level, passed to the script as ``-O<opt>``.

    Raises:
        DecodeError: If the script exits with a non-zero status.
    """
    with open(src, "w") as source_file:
        source_file.write(code)

    command = ["bash", "-e", SCRIPT, arch + "-c", src, elf, f"-O{opt}"]
    # Output is captured so compiler chatter does not reach the console;
    # only the exit status is inspected.
    result = sproc.run(command, capture_output=True)
    if result.returncode:
        raise DecodeError("compile failed")
|
|
|
def compile_fortran(code, elf, arch="arm32", src="/tmp/myfunc.f95", opt=0):
    """Write Fortran source to disk and build it with the helper script.

    Args:
        code: Fortran source text to compile.
        elf: Path handed to the script as the output file.
        arch: Target architecture; ``"-fortran"`` is appended to form the
            script's first argument.
        src: Path the source text is written to before compiling.
        opt: Optimisation level, passed to the script as ``-O<opt>``.

    Raises:
        DecodeError: If the script exits with a non-zero status.
    """
    with open(src, "w") as source_file:
        source_file.write(code)

    command = ["bash", "-e", SCRIPT, arch + "-fortran", src, elf, f"-O{opt}"]
    # Output is captured so compiler chatter does not reach the console;
    # only the exit status is inspected.
    result = sproc.run(command, capture_output=True)
    if result.returncode:
        raise DecodeError("compile failed")
|
|
|
class EquationCompiler:
    """Queue consumer that implements, compiles and disassembles expressions.

    Each instance handles one (implementation, optimisation level) pair for
    one target architecture. ``run`` reads work items from ``q`` until it
    receives ``QUEUE_END`` and appends one record per successful item to a
    set of parallel output files.
    """

    def __init__(self, q, arch, impl, opt, outdir, prefix, dtype="double"):
        """Select the compiler and disassembler and store the configuration.

        Args:
            q: Queue yielding ``(n, expr, expr_const, pref)`` tuples followed
                by ``QUEUE_END``.
            arch: One of ``"arm32"``, ``"aarch64"`` or ``"x64"``.
            impl: Implementation name; ``compile_fortran`` is used when it
                contains ``"fortran"``, otherwise ``compile_c``.
            opt: Optimisation level forwarded to the compiler.
            outdir: Root directory for the output files.
            prefix: Base name shared by the output and temporary files.
            dtype: Data type forwarded to ``Implementor``.

        Raises:
            DecodeError: If ``arch`` is not supported.
        """
        self.compiler = compile_fortran if "fortran" in impl else compile_c

        disassemblers = {
            "arm32": DisassemblerARM32,
            "aarch64": DisassemblerAArch64,
            "x64": DisassemblerX64,
        }
        if arch not in disassemblers:
            raise DecodeError("arch not supported: " + arch)
        self.disassembler = disassemblers[arch]

        self.q = q
        self.impl = impl
        self.opt = opt
        self.outdir = outdir
        self.prefix = prefix
        self.dtype = dtype
        self.arch = arch

    def run(self):
        """Consume the queue until ``QUEUE_END`` and write the dataset files.

        Output goes to ``<outdir>/O<opt>/<impl>/<prefix>.*``:

        * ``.asm``         one disassembly per line
        * ``.eqn``         the matching prefix expression
        * ``.src``         the generated source, with a header per record
        * ``.const.jsonl`` the expression constants as JSON
        * ``.error``       items that were skipped, and why

        Items that raise ``DecodeError`` or whose constants are not all
        recovered by the disassembler are logged to ``.error`` and skipped,
        so the other four files stay line-aligned.
        """
        outdir = pjoin(self.outdir, f"O{self.opt}", self.impl)
        makedirs(outdir, exist_ok=True)

        is_fortran = "fortran" in self.impl
        # Symbol to disassemble; the Fortran builds use a trailing underscore
        # (presumably the compiler's name mangling -- confirm against the
        # compile script).
        func = "myfunc_" if is_fortran else "myfunc"

        # Both scratch files carry the architecture so that concurrent runs
        # for different targets do not overwrite each other's source file.
        tag = f"{self.arch}_{self.impl}_{self.opt}_{self.prefix}"
        tmpsrc = f"/tmp/myfunc_{tag}" + (".f95" if is_fortran else ".c")
        tmpelf = f"/tmp/myfunc_{tag}.elf"

        suffixes = {
            "asm": ".asm",
            "eqn": ".eqn",
            "src": ".src",
            "const": ".const.jsonl",
            "err": ".error",
        }

        # ExitStack closes every file that was opened, even if a later open()
        # fails or an unexpected exception escapes the processing loop.
        with ExitStack() as stack:
            outfiles = {
                key: stack.enter_context(
                    open(pjoin(outdir, self.prefix + suffix), "w"))
                for key, suffix in suffixes.items()
            }

            line_no = 0  # index of the next record in the aligned outputs
            while True:
                data = self.q.get()
                if data == QUEUE_END:
                    break

                n, expr, expr_const, pref = data
                impl = Implementor(expr, constants=expr_const, dtype=self.dtype)
                try:
                    code = impl.implement(self.impl)
                    self.compiler(code, tmpelf, arch=self.arch, src=tmpsrc,
                                  opt=self.opt)
                    disasm = self.disassembler(tmpelf, expr_constants=expr_const,
                                               match_constants=True)
                    asm = disasm.disassemble(func)
                    if len(disasm.constants) < len(expr_const):
                        print(n, "constants not identified", disasm.constants,
                              expr_const, file=outfiles["err"])
                        continue
                except DecodeError as e:
                    print(n, "impl error", e, expr, expr_const, pref,
                          file=outfiles["err"])
                    continue

                outfiles["asm"].write(asm + "\n")
                outfiles["eqn"].write(pref + "\n")
                outfiles["src"].write(
                    f"==== pick={n} line={line_no} ====\n" + code + "\n")
                outfiles["const"].write(json.dumps(expr_const) + "\n")
                line_no += 1
|
|
|
|
|
if __name__ == "__main__":
    # Command-line driver: sample lines from a prefix-expression file, parse
    # and simplify them, and feed each surviving expression to one worker
    # process per (implementation, optimisation level) combination.
    import argparse

    parser = argparse.ArgumentParser("Compile prefix to asm->eqn dataset")
    parser.add_argument("-f", "--file", required=True, help="Input file")
    parser.add_argument("--outdir", required=True, help="Output directory")
    parser.add_argument("--prefix", required=True, help="File prefix")
    parser.add_argument("--impl", nargs="+", required=True,
                        choices=["dag_c", "cse_c", "dag_fortran", "cse_fortran"])
    parser.add_argument("--pick", type=float, required=True,
                        help="Ratio of samples to pick (0 to 1)")
    parser.add_argument("--start", type=int, default=0, help="Start from index")
    parser.add_argument("--count", type=int, default=0, help="Process only these many")
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--min-tokens", help="Minimum tokens in equations", type=int, default=5)
    parser.add_argument("--min-ops", help="Minimum ops in equations", type=int, default=5)
    parser.add_argument("--dtype", help="Implementation datatype", type=str,
                        choices=["double", "float"], default="double")
    parser.add_argument("--arch", help="Target architecture", type=str,
                        choices=["arm32", "aarch64", "x64"], default="arm32")
    parser.add_argument("-O", "--opt", nargs="+", type=int, choices=[0, 1, 2, 3], default=[0],
                        help="Optimization level (s)")

    # Only errors from the "cle" logger are shown.
    logging.getLogger("cle").setLevel(logging.ERROR)

    args = parser.parse_args()
    random.seed(args.seed)

    # One worker, with its own queue, per implementation / optimisation level.
    eqcompilers = [EquationCompiler(mp.Queue(), args.arch, impl, opt, args.outdir,
                                    args.prefix, dtype=args.dtype)
                   for impl in args.impl
                   for opt in args.opt]
    pool = [mp.Process(target=eqc.run, args=()) for eqc in eqcompilers]
    for proc in pool:
        proc.start()

    count = 0
    try:
        with open(args.file, "r") as prefixf:
            for n, line in tqdm(enumerate(prefixf), desc="Parsing file"):
                # Skip lines before --start, then keep each remaining line
                # with probability --pick. The short-circuit means no random
                # number is drawn for the skipped head of the file.
                if n < args.start or random.random() > args.pick:
                    continue

                # The expression lives in the first tab-separated field and
                # starts three characters after the "Y'" marker.
                # NOTE(review): when the marker is absent find() returns -1
                # and the slice starts at index 2; such lines are presumably
                # rejected by the parser below -- confirm.
                comps = line.strip().split("\t")
                pref = comps[0][comps[0].find("Y'") + 3:]
                prefl = pref.split(" ")

                if len(prefl) < args.min_tokens:
                    continue

                try:
                    expr = parse_prefix_to_sympy(prefl)
                    # Bound the time spent in simplification.
                    with timeout(10):
                        expr = sp.simplify(expr)
                    if not sympy_expr_ok(expr):
                        continue
                    expr, expr_const = constant_fold(expr)
                    pref = " ".join(sympy_to_prefix(expr))
                except Exception:
                    # Best effort: any expression that cannot be parsed,
                    # simplified or folded is dropped. Catching Exception
                    # rather than everything lets Ctrl-C and SystemExit
                    # propagate.
                    # NOTE(review): assumes timeout() signals expiry with an
                    # Exception subclass -- confirm in .util.
                    continue

                if sp.count_ops(expr) < args.min_ops:
                    continue

                for eqc in eqcompilers:
                    # Back-pressure: wait while this worker already has more
                    # than five items pending.
                    while eqc.q.qsize() > 5:
                        sleep(1)
                    eqc.q.put((n, expr, expr_const, pref))

                count += 1
                if args.count > 0 and count >= args.count:
                    break
    finally:
        # Always release the workers, even if feeding failed part-way;
        # otherwise they block on their queues forever and the interpreter
        # cannot exit.
        for eqc in eqcompilers:
            eqc.q.put(QUEUE_END)
        for proc in pool:
            proc.join()
|
|