"""Deduplicate ASM/equation examples and split them into train/valid/test.

Reads three line-aligned files sharing a prefix (``<prefix>.asm``,
``<prefix>.eqn`` and ``<prefix>.const.jsonl``), drops repeated ASM lines,
examples whose equation contains a decimal literal, and (optionally)
examples rejected by the selected filter.  The survivors are shuffled with a
fixed seed and written to ``train``/``valid``/``test`` files in the output
directory, together with ``splits.txt`` listing the original line index of
every written example.
"""
import argparse
import os
import random
import re
import sys
from itertools import zip_longest

try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is purely cosmetic, so the script stays usable
    # when tqdm is not installed.
    def tqdm(iterable, **_kwargs):
        """Return *iterable* unchanged (stand-in for ``tqdm.tqdm``)."""
        return iterable


# Tokens that make the "poly" filter reject an example.
_POLY_REJECTS = frozenset({
    "ln", "exp", "sin", "cos", "sqrt", "tan",
    "asin", "acos", "atan", "E", "pi", "cot",
})
# A ``CONST=`` operand with four or more digits.
_BIGINT_RE = re.compile(r"CONST=[0-9]{4,}")
# A digit, a dot, a digit: a decimal literal inside an equation.
_FLOAT_RE = re.compile(r"[0-9]\.[0-9]")


def filter_poly(asm, eqn):
    """Return True if *asm* or *eqn* contains a token from the reject set.

    Each line is stripped and split on single spaces; a token must match an
    entry of the reject set exactly (substrings do not count).
    """
    return any(
        not _POLY_REJECTS.isdisjoint(text.strip().split(" "))
        for text in (asm, eqn)
    )


def filter_bigint(asm, eqn):
    """Return True if *asm* has a ``CONST=`` value of four or more digits.

    *eqn* is accepted so both filters share one signature; it is not
    inspected.
    """
    return _BIGINT_RE.search(asm) is not None


def _with_newline(line):
    """Return *line* terminated by a newline.

    The last line of an input file may lack one; written back verbatim it
    would be glued to the following record after shuffling.
    """
    return line if line.endswith("\n") else line + "\n"


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(
        description="Deduplicate ASM and split files into train/test/valid"
    )
    parser.add_argument("--inprefix", required=True,
                        help="Prefix of input files")
    parser.add_argument("--outdir", required=True)
    parser.add_argument("--split", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--filter", choices=["poly", "bigint"], default=None)
    parser.add_argument("--no-separate-eqn", action="store_true")
    args = parser.parse_args(argv)
    # valid and test each take ``split`` of the data; above 0.5 the slice
    # bounds go negative and train would overlap test.
    if not 0.0 <= args.split <= 0.5:
        parser.error("--split must be between 0 and 0.5")
    return args


def _read_examples(inprefix, filter_name, keep_flat):
    """Read and filter the three input files.

    Returns ``(flat, by_eqn, removed)``.  When *keep_flat* is true, ``flat``
    holds ``(index, asm, eqn, const)`` tuples and ``by_eqn`` stays empty;
    otherwise ``by_eqn`` maps each equation line to its list of
    ``(index, asm, const)`` tuples and ``flat`` stays empty.  ``removed`` is
    the number of discarded examples.
    """
    flat = []
    by_eqn = {}
    # The ASM strings themselves (not their hashes) are stored, so a hash
    # collision can never discard a distinct example.
    seen_asm = set()
    removed = 0

    with open(inprefix + ".asm", "r") as asmf, \
            open(inprefix + ".eqn", "r") as eqnf, \
            open(inprefix + ".const.jsonl", "r") as constf:
        rows = enumerate(zip_longest(asmf, eqnf, constf))
        for i, (asm, eqn, const) in tqdm(rows, desc="Read files",
                                         leave=False):
            if asm is None or eqn is None or const is None:
                # Stop at the shortest file, but say so instead of
                # silently truncating.
                print(f"Warning: input files differ in length; "
                      f"stopping after {i} lines", file=sys.stderr)
                break

            asm = _with_newline(asm)
            eqn = _with_newline(eqn)
            const = _with_newline(const)

            rejected = (
                asm in seen_asm                     # repeated ASM line
                or _FLOAT_RE.search(eqn)            # float not represented
                or (filter_name == "poly" and filter_poly(asm, eqn))
                or (filter_name == "bigint" and filter_bigint(asm, eqn))
            )
            if rejected:
                removed += 1
                continue

            # Only kept examples are remembered for deduplication.
            seen_asm.add(asm)
            if keep_flat:
                flat.append((i, asm, eqn, const))
            else:
                by_eqn.setdefault(eqn, []).append((i, asm, const))

    return flat, by_eqn, removed


def _split_dataset(dataset, fraction):
    """Slice *dataset* into train/valid/test.

    valid and test each receive ``int(len(dataset) * fraction)`` items taken
    from the end; train receives the rest.
    """
    total = len(dataset)
    held_out = int(total * fraction)
    return {
        "train": dataset[:total - 2 * held_out],
        "valid": dataset[total - 2 * held_out:total - held_out],
        "test": dataset[total - held_out:],
    }


def _write_splits(outdir, splits, by_eqn, keep_flat):
    """Write the per-split data files and ``splits.txt`` into *outdir*.

    When *keep_flat* is false each split member is an equation line and all
    examples grouped under it in *by_eqn* are written to that split.
    """
    with open(os.path.join(outdir, "splits.txt"), "w") as idxf:
        for name, members in splits.items():
            if keep_flat:
                rows = members
            else:
                rows = [
                    (i, asm, eqn, const)
                    for eqn in members
                    for i, asm, const in by_eqn[eqn]
                ]

            indices = []
            with open(os.path.join(outdir, f"{name}.asm"), "w") as asmf, \
                    open(os.path.join(outdir, f"{name}.eqn"), "w") as eqnf, \
                    open(os.path.join(outdir, f"{name}.const.jsonl"),
                         "w") as constf:
                for i, asm, eqn, const in rows:
                    asmf.write(asm)
                    eqnf.write(eqn)
                    constf.write(const)
                    indices.append(i)

            print("Split", name, len(indices))
            idxf.write(f"==== {name} ====\n")
            idxf.writelines(f"{j}: {i}\n" for j, i in enumerate(indices))
            idxf.write("\n")


def main(argv=None):
    """Run the deduplicate-and-split pipeline.

    *argv* is the argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    flat, by_eqn, removed = _read_examples(
        args.inprefix, args.filter, args.no_separate_eqn
    )
    print("Removed", removed)

    # Shuffle either individual examples or whole equation groups.
    dataset = flat if args.no_separate_eqn else list(by_eqn)
    random.seed(args.seed)
    random.shuffle(dataset)

    splits = _split_dataset(dataset, args.split)
    os.makedirs(args.outdir, exist_ok=True)
    _write_splits(args.outdir, splits, by_eqn, args.no_separate_eqn)


if __name__ == "__main__":
    main()