File size: 7,305 Bytes
7145fd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
from tqdm import tqdm
import random
import sympy as sp
import json
import subprocess as sproc
from os.path import realpath, dirname, join as pjoin
from os import makedirs
import multiprocessing as mp
from time import sleep
import logging
from .implementation import Implementor
from .parser import parse_prefix_to_sympy, sympy_to_prefix, constant_fold
from .disassemble import DisassemblerARM32, DisassemblerAArch64, DisassemblerX64
from .util import DecodeError, timeout, sympy_expr_ok
# Shell helper that lives next to this module; the compile_* functions run it
# through bash to build a source file into an ELF.
SCRIPT = pjoin(dirname(realpath(__file__)), "compile_eqn.sh")
# Sentinel placed on a worker queue to tell EquationCompiler.run() to stop.
QUEUE_END = "QUEUE_END_SENTINEL"
def compile_c(code, elf, arch="arm32", src="/tmp/myfunc.c", opt=0):
    """Write C source text to ``src`` and build it into ``elf``.

    The build is delegated to the ``SCRIPT`` shell helper, invoked through
    ``bash -e`` with the target ``<arch>-c``, the source path, the output
    path and an ``-O<opt>`` flag.

    Args:
        code: Source text to write to ``src``.
        elf: Path passed to the helper as the output file.
        arch: Architecture name used to form the helper's target argument.
        src: Path the source text is written to (overwritten).
        opt: Optimization level appended to ``-O``.

    Raises:
        DecodeError: If the helper script exits with a non-zero status.
    """
    with open(src, "w") as source_file:
        source_file.write(code)

    command = ["bash", "-e", SCRIPT, arch + "-c", src, elf, f"-O{opt}"]
    # Output is captured so the helper's chatter does not reach the console.
    result = sproc.run(command, capture_output=True)
    if result.returncode == 0:
        return
    raise DecodeError("compile failed")
def compile_fortran(code, elf, arch="arm32", src="/tmp/myfunc.f95", opt=0):
    """Write Fortran source text to ``src`` and build it into ``elf``.

    The build is delegated to the ``SCRIPT`` shell helper, invoked through
    ``bash -e`` with the target ``<arch>-fortran``, the source path, the
    output path and an ``-O<opt>`` flag.

    Args:
        code: Source text to write to ``src``.
        elf: Path passed to the helper as the output file.
        arch: Architecture name used to form the helper's target argument.
        src: Path the source text is written to (overwritten).
        opt: Optimization level appended to ``-O``.

    Raises:
        DecodeError: If the helper script exits with a non-zero status.
    """
    with open(src, "w") as source_file:
        source_file.write(code)

    command = ["bash", "-e", SCRIPT, arch + "-fortran", src, elf, f"-O{opt}"]
    # Output is captured so the helper's chatter does not reach the console.
    result = sproc.run(command, capture_output=True)
    if result.returncode == 0:
        return
    raise DecodeError("compile failed")
class EquationCompiler:
    """Queue consumer that compiles equations and records their disassembly.

    Each instance handles one (implementation, optimization level) pair. Its
    ``run`` method is meant to be used as a process target: it pulls
    ``(n, expr, expr_const, pref)`` tuples from ``q`` until it receives
    ``QUEUE_END``, and writes one record per successful sample to a set of
    output files under ``<outdir>/O<opt>/<impl>/``.
    """

    def __init__(self, q, arch, impl, opt, outdir, prefix, dtype="double"):
        """Select the compiler and disassembler and store the configuration.

        Args:
            q: Queue from which ``run`` reads work items.
            arch: Target architecture: ``"arm32"``, ``"aarch64"`` or ``"x64"``.
            impl: Implementation name; names containing ``"fortran"`` are
                compiled with ``compile_fortran``, all others with
                ``compile_c``.
            opt: Optimization level forwarded to the compiler.
            outdir: Root directory for the output files.
            prefix: Base name shared by all output files.
            dtype: Datatype forwarded to ``Implementor``.

        Raises:
            DecodeError: If ``arch`` is not one of the supported names.
        """
        self.compiler = compile_fortran if "fortran" in impl else compile_c

        disassemblers = {
            "arm32": DisassemblerARM32,
            "aarch64": DisassemblerAArch64,
            "x64": DisassemblerX64,
        }
        if arch not in disassemblers:
            raise DecodeError("arch not supported: " + arch)
        self.disassembler = disassemblers[arch]

        self.q = q
        self.impl = impl
        self.opt = opt
        self.outdir = outdir
        self.prefix = prefix
        self.dtype = dtype
        self.arch = arch

    def run(self):
        """Consume the queue until ``QUEUE_END`` and write the output files.

        Creates ``<outdir>/O<opt>/<impl>/`` if needed and opens five files
        named ``<prefix>.asm``, ``.eqn``, ``.src``, ``.const.jsonl`` and
        ``.error`` there. The files are closed even if processing fails with
        an unexpected exception or one of them cannot be opened.
        """
        outdir = pjoin(self.outdir, f"O{self.opt}", self.impl)
        makedirs(outdir, exist_ok=True)

        # Scratch paths include impl/opt/prefix so that workers running in
        # parallel do not overwrite each other's intermediate files.
        tmpsrc = f"/tmp/myfunc_{self.impl}_{self.opt}_{self.prefix}"
        if "fortran" in self.impl:
            tmpsrc += ".f95"
            func = "myfunc_"
        else:
            tmpsrc += ".c"
            func = "myfunc"
        tmpelf = f"/tmp/myfunc_{self.arch}_{self.impl}_{self.opt}_{self.prefix}.elf"

        suffixes = {
            "asm": ".asm",
            "eqn": ".eqn",
            "src": ".src",
            "const": ".const.jsonl",
            "err": ".error",
        }
        outfiles = {}
        try:
            # Open inside the try so a failure on a later file still closes
            # the ones that were already opened.
            for key, suffix in suffixes.items():
                outfiles[key] = open(pjoin(outdir, self.prefix + suffix), "w")
            self._process_queue(outfiles, tmpsrc, tmpelf, func)
        finally:
            for handle in outfiles.values():
                handle.close()

    def _process_queue(self, outfiles, tmpsrc, tmpelf, func):
        """Compile and disassemble queued samples, writing results to ``outfiles``."""
        line_no = 0  # index of the next successful record
        while True:
            data = self.q.get()
            if data == QUEUE_END:
                # Queue is closed, break from inf loop
                break
            n, expr, expr_const, pref = data
            impl = Implementor(expr, constants=expr_const, dtype=self.dtype)
            try:
                code = impl.implement(self.impl)
                self.compiler(code, tmpelf, arch=self.arch, src=tmpsrc, opt=self.opt)
                disasm = self.disassembler(tmpelf, expr_constants=expr_const,
                                           match_constants=True)
                asm = disasm.disassemble(func)
                if len(disasm.constants) < len(expr_const):
                    # Not every expression constant was matched in the
                    # binary; log the sample and skip it.
                    print(n, "constants not identified", disasm.constants, expr_const,
                          file=outfiles["err"])
                    continue
            except DecodeError as e:
                print(n, "impl error", e, expr, expr_const, pref, file=outfiles["err"])
                continue
            outfiles["asm"].write(asm + "\n")
            outfiles["eqn"].write(pref + "\n")
            outfiles["src"].write(f"==== pick={n} line={line_no} ====\n" + code + "\n")
            outfiles["const"].write(json.dumps(expr_const) + "\n")
            line_no += 1
if __name__ == "__main__":
    # Command-line driver: samples lines from a prefix-notation equation file,
    # simplifies and constant-folds each one, and fans the result out to one
    # EquationCompiler worker process per (impl, opt) combination.
    import argparse
    from queue import Full

    # Maximum number of pending samples per worker queue. A bounded queue
    # provides back-pressure without polling Queue.qsize(), which is not
    # implemented on every platform.
    QUEUE_DEPTH = 6

    parser = argparse.ArgumentParser("Compile prefix to asm->eqn dataset")
    parser.add_argument("-f", "--file", required=True, help="Input file")
    parser.add_argument("--outdir", required=True, help="Output directory")
    parser.add_argument("--prefix", required=True, help="File prefix")
    parser.add_argument("--impl", nargs="+", required=True,
                        choices=["dag_c", "cse_c", "dag_fortran", "cse_fortran"])
    parser.add_argument("--pick", type=float, required=True,
                        help="Ratio of samples to pick (0 to 1)")
    parser.add_argument("--start", type=int, default=0, help="Start from index")
    parser.add_argument("--count", type=int, default=0, help="Process only these many")
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--min-tokens", help="Minimum tokens in equations", type=int, default=5)
    parser.add_argument("--min-ops", help="Minimum ops in equations", type=int, default=5)
    parser.add_argument("--dtype", help="Implementation datatype", type=str,
                        choices=["double", "float"], default="double")
    parser.add_argument("--arch", help="Target architecture", type=str,
                        choices=["arm32", "aarch64", "x64"], default="arm32")
    parser.add_argument("-O", "--opt", nargs="+", type=int, choices=[0, 1, 2, 3], default=[0],
                        help="Optimization level (s)")
    # Dont show warnings
    logging.getLogger("cle").setLevel(logging.ERROR)
    args = parser.parse_args()
    random.seed(args.seed)

    eqcompilers = [EquationCompiler(mp.Queue(maxsize=QUEUE_DEPTH), args.arch, impl, opt,
                                    args.outdir, args.prefix, dtype=args.dtype)
                   for impl in args.impl
                   for opt in args.opt]
    pool = [mp.Process(target=eqc.run, args=()) for eqc in eqcompilers]
    for proc in pool:
        proc.start()

    def feed(eqc, proc, item):
        """Put ``item`` on the worker's queue; return False if the worker died."""
        while proc.is_alive():
            try:
                eqc.q.put(item, timeout=1)
                return True
            except Full:
                # Queue still full: loop to re-check that the consumer lives.
                pass
        return False

    def feed_all(workers, item):
        """Feed every worker and return the ones that are still alive."""
        alive = []
        for eqc, proc in workers:
            if feed(eqc, proc, item):
                alive.append((eqc, proc))
            else:
                logging.error("worker impl=%s opt=%s exited early; dropping it",
                              eqc.impl, eqc.opt)
                # Do not wait at interpreter exit for data that nobody
                # will ever read from this queue.
                eqc.q.cancel_join_thread()
        return alive

    workers = list(zip(eqcompilers, pool))
    count = 0
    with open(args.file, "r") as prefixf:
        for n, line in tqdm(enumerate(prefixf), desc="Parsing file"):
            # Skip for start lines and with some probability. The random draw
            # is only made for lines at or past --start, which keeps the
            # selection reproducible for a given seed.
            if n < args.start or random.random() > args.pick:
                continue
            comps = line.strip().split("\t")
            # NOTE(review): if "Y'" is absent, find() returns -1 and this
            # slices from index 2 -- confirm every input line has the marker.
            pref = comps[0][comps[0].find("Y'")+3:]
            prefl = pref.split(" ")
            if len(prefl) < args.min_tokens:
                continue
            try:
                expr = parse_prefix_to_sympy(prefl)
                with timeout(10):
                    expr = sp.simplify(expr)
                if not sympy_expr_ok(expr):
                    # Simplified is bad
                    continue
                expr, expr_const = constant_fold(expr)
                pref = " ".join(sympy_to_prefix(expr))
            except Exception:
                # Best effort: unparsable or slow samples are skipped, but
                # KeyboardInterrupt/SystemExit are no longer swallowed.
                # NOTE(review): assumes timeout() raises an Exception
                # subclass -- confirm in .util.
                continue
            if sp.count_ops(expr) < args.min_ops:
                continue
            workers = feed_all(workers, (n, expr, expr_const, pref))
            if not workers:
                # Every worker is gone; parsing further would be pointless.
                break
            count += 1
            if args.count > 0 and count >= args.count:
                break
    # Close queues
    feed_all(workers, QUEUE_END)
    for proc in pool:
        proc.join()
|