# REMEND / remend / deduplicate_split.py
# Commit 7145fd6 by udiboy1209: "Add REMEND python module"
import sys
import random
import os
import re
from tqdm import tqdm
def filter_poly(asm, eqn):
    """Return True when the pair mentions a non-polynomial token.

    Each of *asm* and *eqn* is stripped and split on single spaces. The pair is
    flagged if any token is one of the function names ln, exp, sin, cos, sqrt,
    tan, asin, acos, atan, cot or one of the constants E, pi. *asm* is checked
    before *eqn*, and checking stops at the first hit.
    """
    non_poly = frozenset((
        "ln", "exp", "sin", "cos", "sqrt", "tan",
        "asin", "acos", "atan", "E", "pi", "cot",
    ))
    for text in (asm, eqn):
        tokens = text.strip().split(" ")
        if not non_poly.isdisjoint(tokens):
            return True
    return False
def filter_bigint(asm, eqn):
    """Return True when *asm* contains ``CONST=`` followed by four or more digits.

    *eqn* is accepted so the signature matches ``filter_poly``, but it is not
    inspected.
    """
    found = re.search(r"CONST=[0-9]{4,}", asm)
    return found is not None
# Floats are not represented, so rows whose equation contains a decimal
# number such as "0.5" are removed.
_FLOAT_RE = re.compile(r"[0-9]\.[0-9]")


def _parse_args():
    """Parse and validate the command-line options."""
    import argparse
    parser = argparse.ArgumentParser("Deduplicate ASM and split files into train/test/valid")
    parser.add_argument("--inprefix", required=True, help="Prefix of input files")
    parser.add_argument("--outdir", required=True)
    parser.add_argument("--split", type=float, default=0.05,
                        help="Fraction of the items placed in each of valid and test")
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--filter", choices=["poly", "bigint"], default=None)
    parser.add_argument("--no-separate-eqn", action="store_true")
    args = parser.parse_args()
    # valid and test each take int(N * split) items from the end of the
    # shuffled list. Above 0.5 the train slice end goes negative and train
    # would overlap valid/test.
    if not 0.0 <= args.split <= 0.5:
        parser.error("--split must be between 0 and 0.5")
    return args


def _read_dataset(inprefix, filter_name, separate_eqn):
    """Read the line-aligned input files and keep unique rows that pass the filters.

    Reads ``<inprefix>.asm``, ``<inprefix>.eqn`` and ``<inprefix>.const.jsonl``
    in lockstep. A line is dropped when its asm text was already kept, when
    its equation contains a decimal number, or when the optional
    ``filter_name`` ("poly" or "bigint") filter flags it.

    Returns:
        (items, eq_mapped, removed), where ``removed`` counts dropped lines.
        If ``separate_eqn`` is False, ``items`` is a list of
        (line_index, asm, eqn, const) tuples and ``eq_mapped`` is empty.
        Otherwise ``items`` lists the distinct equations in first-seen order,
        and ``eq_mapped`` maps each equation to its (line_index, asm, const)
        tuples.
    """
    items = []
    eq_mapped = {}
    seen_asm = set()
    removed = 0
    with open(inprefix + ".asm", "r") as asmf, \
            open(inprefix + ".eqn", "r") as eqnf, \
            open(inprefix + ".const.jsonl", "r") as constf:
        # NOTE(review): zip() stops at the shortest file, so a length mismatch
        # between the three inputs is silently truncated.
        rows = enumerate(zip(asmf, eqnf, constf))
        for i, (asm, eqn, const) in tqdm(rows, desc="Read files", leave=False):
            # Track the asm text itself rather than hash(asm); a hash collision
            # would otherwise drop a distinct line as a "duplicate".
            if asm in seen_asm:
                removed += 1
                continue
            if _FLOAT_RE.search(eqn):
                removed += 1
                continue
            if filter_name == "poly" and filter_poly(asm, eqn):
                removed += 1
                continue
            if filter_name == "bigint" and filter_bigint(asm, eqn):
                removed += 1
                continue
            seen_asm.add(asm)
            if separate_eqn:
                eq_mapped.setdefault(eqn, []).append((i, asm, const))
            else:
                items.append((i, asm, eqn, const))
    if separate_eqn:
        # Splitting over equations keeps every asm variant of one equation in
        # the same split.
        items = list(eq_mapped)
    return items, eq_mapped, removed


def _split_dataset(dataset, frac):
    """Slice *dataset* into train/valid/test; valid and test get int(len * frac) items each."""
    n = len(dataset)
    n_test = int(n * frac)
    train_end = n - 2 * n_test
    valid_end = n - n_test
    return {
        "train": dataset[:train_end],
        "valid": dataset[train_end:valid_end],
        "test": dataset[valid_end:],
    }


def _iter_rows(items, eq_mapped, separate_eqn):
    """Yield (line_index, asm, eqn, const) rows for the items of one split."""
    if not separate_eqn:
        yield from items
        return
    for eqn in items:
        for i, asm, const in eq_mapped[eqn]:
            yield i, asm, eqn, const


def _write_splits(outdir, splits, eq_mapped, separate_eqn):
    """Write <split>.asm/.eqn/.const.jsonl for every split, plus splits.txt.

    For each split, splits.txt lists "<row in split file>: <line in input>".
    """
    os.makedirs(outdir, exist_ok=True)
    with open(os.path.join(outdir, "splits.txt"), "w") as idxf:
        for name, items in splits.items():
            line_indices = []
            with open(os.path.join(outdir, f"{name}.asm"), "w") as asmf, \
                    open(os.path.join(outdir, f"{name}.eqn"), "w") as eqnf, \
                    open(os.path.join(outdir, f"{name}.const.jsonl"), "w") as constf:
                for i, asm, eqn, const in _iter_rows(items, eq_mapped, separate_eqn):
                    asmf.write(asm)
                    eqnf.write(eqn)
                    constf.write(const)
                    line_indices.append(i)
            print("Split", name, len(line_indices))
            idxf.write(f"==== {name} ====\n")
            for j, i in enumerate(line_indices):
                idxf.write(f"{j}: {i}\n")
            idxf.write("\n")


def main():
    """Deduplicate, filter, shuffle and split the dataset named by --inprefix."""
    args = _parse_args()
    separate_eqn = not args.no_separate_eqn
    dataset, eq_mapped, removed = _read_dataset(args.inprefix, args.filter, separate_eqn)
    print("Removed", removed)
    # Seed immediately before shuffling so a given --seed reproduces the split.
    random.seed(args.seed)
    random.shuffle(dataset)
    splits = _split_dataset(dataset, args.split)
    _write_splits(args.outdir, splits, eq_mapped, separate_eqn)


if __name__ == "__main__":
    main()