|
import sys |
|
import random |
|
import os |
|
import re |
|
from tqdm import tqdm |
|
|
|
def filter_poly(asm, eqn):
    """Return True when either the ASM or the equation string contains a
    transcendental/irrational token (ln, exp, trig, sqrt, E, pi, ...),
    i.e. the sample is not a plain polynomial and should be filtered out.

    Both strings are tokenized by splitting on single spaces.
    """
    banned = {"ln", "exp", "sin", "cos", "sqrt", "tan", "asin", "acos", "atan", "E", "pi", "cot"}
    tokens = set(asm.strip().split(" ")) | set(eqn.strip().split(" "))
    # Non-empty intersection with the banned vocabulary means "reject".
    return not banned.isdisjoint(tokens)
|
|
|
def filter_bigint(asm, eqn):
    """Return True when the ASM string embeds a constant of four or more
    digits (``CONST=NNNN...``), marking the sample for removal.

    ``eqn`` is unused; the parameter exists so this filter shares the
    (asm, eqn) call signature used by the other filters.
    """
    return re.search(r"CONST=[0-9]{4,}", asm) is not None
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Deduplicate ASM and split files into train/test/valid")
    parser.add_argument("--inprefix", required=True, help="Prefix of input files")
    parser.add_argument("--outdir", required=True)
    parser.add_argument("--split", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--filter", choices=["poly", "bigint"], default=None)
    parser.add_argument("--no-separate-eqn", action="store_true")

    args = parser.parse_args()

    # eq_mapped groups samples by equation so an equation never straddles
    # two splits; combined_ds keeps flat records when --no-separate-eqn.
    eq_mapped = {}
    combined_ds = []
    asm_hash = set()  # hashes of ASM lines already accepted (dedup)
    removed = 0

    # The three parallel files are zipped line-by-line: one sample per line.
    with open(args.inprefix + ".asm", "r") as asmf, \
            open(args.inprefix + ".eqn", "r") as eqnf, \
            open(args.inprefix + ".const.jsonl", "r") as constf:
        for i, (asm, eqn, const) in tqdm(enumerate(zip(asmf, eqnf, constf)),
                                         desc="Read files", leave=False):
            h = hash(asm)
            if h in asm_hash:
                # Exact-duplicate ASM program (modulo hash collisions).
                removed += 1
                continue

            if re.search(r"[0-9]\.[0-9]", eqn):
                # Drop equations containing floating-point literals.
                removed += 1
                continue

            if args.filter == "poly" and filter_poly(asm, eqn):
                removed += 1
                continue
            if args.filter == "bigint" and filter_bigint(asm, eqn):
                removed += 1
                continue

            asm_hash.add(h)
            if args.no_separate_eqn:
                combined_ds.append((i, asm, eqn, const))
            else:
                eq_mapped.setdefault(eqn, []).append((i, asm, const))

    print("Removed", removed)

    # The shuffled unit is either a flat record or a unique equation key.
    if args.no_separate_eqn:
        dataset = combined_ds
    else:
        dataset = list(eq_mapped.keys())

    random.seed(args.seed)
    random.shuffle(dataset)

    N = len(dataset)
    Ntest = int(N * args.split)

    # valid and test each get Ntest units; train gets the remainder.
    splits = {
        "train": dataset[:N - 2 * Ntest],
        "valid": dataset[N - 2 * Ntest:N - Ntest],
        "test": dataset[N - Ntest:],
    }
    splitidxs = {s: [] for s in splits}

    # FIX: create the output directory up front instead of crashing with
    # FileNotFoundError on the first open() when it does not exist.
    os.makedirs(args.outdir, exist_ok=True)

    # FIX: splits.txt is now opened in a context manager so the handle is
    # closed (and flushed) even if writing a split file raises; the
    # original leaked it on any error path.
    with open(os.path.join(args.outdir, "splits.txt"), "w") as idxf:
        for s in splits:
            asmfn = os.path.join(args.outdir, f"{s}.asm")
            eqnfn = os.path.join(args.outdir, f"{s}.eqn")
            constfn = os.path.join(args.outdir, f"{s}.const.jsonl")
            with open(asmfn, "w") as asmf, open(eqnfn, "w") as eqnf, \
                    open(constfn, "w") as constf:
                if args.no_separate_eqn:
                    for i, asm, eqn, const in splits[s]:
                        asmf.write(asm)
                        eqnf.write(eqn)
                        constf.write(const)
                        splitidxs[s].append(i)
                else:
                    # Expand each equation back into all of its samples.
                    for eqn in splits[s]:
                        for i, asm, const in eq_mapped[eqn]:
                            asmf.write(asm)
                            eqnf.write(eqn)
                            constf.write(const)
                            splitidxs[s].append(i)
            print("Split", s, len(splitidxs[s]))
            # Record the original line indices that landed in this split.
            idxf.write(f"==== {s} ====\n")
            for j, i in enumerate(splitidxs[s]):
                idxf.write(f"{j}: {i}\n")
            idxf.write("\n")
|
|
|
|