File size: 3,781 Bytes
7145fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import sys
import random
import os
import re
from tqdm import tqdm

def filter_poly(asm, eqn):
    """Return True when either line contains a rejected (non-polynomial) token.

    Each argument is stripped and split on single spaces; a line matches when
    any resulting token is exactly one of the rejected function or constant
    names. ``eqn`` is only inspected when ``asm`` has no match.
    """
    rejects = {"ln", "exp", "sin", "cos", "sqrt", "tan", "asin", "acos", "atan", "E", "pi", "cot"}

    def has_rejected(line):
        # Whole-token comparison: "sinx" does not match "sin".
        for token in line.strip().split(" "):
            if token in rejects:
                return True
        return False

    if has_rejected(asm):
        return True
    return has_rejected(eqn)

def filter_bigint(asm, eqn):
    """Return True when ``asm`` contains ``CONST=`` followed by 4+ digits.

    ``eqn`` is accepted so the signature parallels the other filter, but it is
    not inspected.
    """
    found = re.search(r"CONST=[0-9]{4,}", asm)
    return found is not None

if __name__ == "__main__":
    # Script entry point: read the parallel <inprefix>.asm / .eqn / .const.jsonl
    # files line by line, drop duplicate and filtered lines, shuffle with a
    # fixed seed, and write train/valid/test files plus a splits.txt index
    # into --outdir.
    import argparse

    # The text must be passed as `description=`; the first positional argument
    # of ArgumentParser is `prog`, which would put this sentence in the usage
    # line as the program name.
    parser = argparse.ArgumentParser(
        description="Deduplicate ASM and split files into train/test/valid")
    parser.add_argument("--inprefix", required=True, help="Prefix of input files")
    parser.add_argument("--outdir", required=True)
    parser.add_argument("--split", type=float, default=0.05)
    parser.add_argument("--seed", type=int, default=1225)
    parser.add_argument("--filter", choices=["poly", "bigint"], default=None)
    parser.add_argument("--no-separate-eqn", action="store_true")

    args = parser.parse_args()

    # eqn line -> list of (source line index, asm line, const line)
    eq_mapped = {}
    # Flat list of (source line index, asm, eqn, const); used with --no-separate-eqn
    combined_ds = []
    # Hashes of the asm lines kept so far. Only the hash is stored, so two
    # distinct lines whose hashes collide would be treated as duplicates.
    asm_hash = set()
    removed = 0

    # Compiled once instead of on every input line.
    float_re = re.compile(r"[0-9]\.[0-9]")

    with open(args.inprefix + ".asm", "r") as asmf, \
            open(args.inprefix + ".eqn", "r") as eqnf, \
            open(args.inprefix + ".const.jsonl", "r") as constf:
        # zip stops at the shortest of the three files.
        for i, (asm, eqn, const) in tqdm(enumerate(zip(asmf, eqnf, constf)),
                                         desc="Read files", leave=False):
            h = hash(asm)
            if h in asm_hash:
                # Skip this repeated line
                removed += 1
                continue

            if float_re.search(eqn):
                # Float not represented, remove
                removed += 1
                continue

            if args.filter == "poly" and filter_poly(asm, eqn):
                removed += 1
                continue
            if args.filter == "bigint" and filter_bigint(asm, eqn):
                removed += 1
                continue

            # Only lines that survive every filter are recorded as seen.
            asm_hash.add(h)
            if args.no_separate_eqn:
                combined_ds.append((i, asm, eqn, const))
            else:
                eq_mapped.setdefault(eqn, []).append((i, asm, const))

    print("Removed", removed)

    if args.no_separate_eqn:
        dataset = combined_ds
    else:
        # Shuffle and split by unique equation so every asm line sharing an
        # equation ends up in the same split.
        dataset = list(eq_mapped.keys())

    random.seed(args.seed)
    random.shuffle(dataset)

    N = len(dataset)
    Ntest = int(N * args.split)

    # valid and test each get Ntest entries from the tail; train gets the rest.
    splits = {
        "train": dataset[:N-2*Ntest],
        "valid": dataset[N-2*Ntest:N-Ntest],
        "test": dataset[N-Ntest:]
    }
    splitidxs = {s: [] for s in splits}

    # Create the output directory if it is missing rather than failing on the
    # first open() below.
    os.makedirs(args.outdir, exist_ok=True)

    # The context manager closes splits.txt even if a write below raises.
    with open(os.path.join(args.outdir, "splits.txt"), "w") as idxf:
        for s in splits:
            asmfn = os.path.join(args.outdir, f"{s}.asm")
            eqnfn = os.path.join(args.outdir, f"{s}.eqn")
            constfn = os.path.join(args.outdir, f"{s}.const.jsonl")
            with open(asmfn, "w") as asmf, open(eqnfn, "w") as eqnf, \
                    open(constfn, "w") as constf:
                if args.no_separate_eqn:
                    for i, asm, eqn, const in splits[s]:
                        asmf.write(asm)
                        eqnf.write(eqn)
                        constf.write(const)
                        splitidxs[s].append(i)
                else:
                    for eqn in splits[s]:
                        # The equation line is repeated once per asm line so
                        # the three output files stay line-aligned.
                        for i, asm, const in eq_mapped[eqn]:
                            asmf.write(asm)
                            eqnf.write(eqn)
                            constf.write(const)
                            splitidxs[s].append(i)
            print("Split", s, len(splitidxs[s]))
            # Index file: output line number within the split -> input line number.
            idxf.write(f"==== {s} ====\n")
            for j, i in enumerate(splitidxs[s]):
                idxf.write(f"{j}: {i}\n")
            idxf.write("\n")