File size: 2,394 Bytes
7145fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import sys
from tqdm import tqdm
from Levenshtein import distance

def _read_train(path):
    """Load the reference ("train") file and index its lines by content.

    Each line is stripped and split on single spaces into a token list.

    Returns:
        (train, index) where ``train[k]`` is the token list of line ``k`` and
        ``index`` maps the token tuple of a line to the ascending list of all
        line numbers with exactly that content.
    """
    train = []
    index = {}
    with open(path, "r") as tf:
        for idx, line in tqdm(enumerate(tf), desc="Read train", leave=False):
            tokens = line.strip().split(" ")
            train.append(tokens)
            # The tuple shares its string objects with ``tokens``, so keying by
            # content is cheap. Unlike ``hash(line)`` it cannot collide, and it
            # ignores a missing trailing newline on the last line of a file.
            index.setdefault(tuple(tokens), []).append(idx)
    return train, index


def _exact_match(index, tokens, i, selfcheck):
    """Return the line number of a train line identical to ``tokens``, or None.

    The latest occurrence is preferred. When ``selfcheck`` is true, line ``i``
    itself is never returned, so a line only matches a *different* copy.
    """
    for j in reversed(index.get(tuple(tokens), ())):
        if not selfcheck or j != i:
            return j
    return None


def _nearest(train, tokens, skip=-1):
    """Find the train line with the smallest Levenshtein distance to ``tokens``.

    Distances are computed over token lists, not characters. Line ``skip`` is
    ignored (used so a line cannot match itself in self-check mode). On ties
    the earliest line wins.

    Returns:
        (j, dist, avgdist) where ``avgdist`` is ``dist`` divided by the combined
        token count of both lines, or ``(-1, inf, inf)`` if no candidate exists.
    """
    n = len(tokens)
    best_j, best_dist, best_avg = -1, float("inf"), float("inf")
    for j, cand in enumerate(train):
        if j == skip:
            continue
        # The edit distance is never smaller than the length difference, so a
        # candidate whose length gap already reaches the best distance cannot
        # win. This pruning is exact: it never reports a distance that was
        # estimated rather than computed.
        if abs(len(cand) - n) >= best_dist:
            continue
        dist = distance(cand, tokens)
        if dist < best_dist:
            # len(cand) + n >= 2 because "".split(" ") == [""], so no div by 0.
            best_j, best_dist, best_avg = j, dist, dist / (len(cand) + n)
    return best_j, best_dist, best_avg


def main():
    """Command-line entry point: report exact and near duplicates.

    For every line of ``--test``, look for a ``--train`` line with identical
    tokens. When one is found, it is counted and, if ``--result`` is set,
    written as ``"<test_idx> <train_idx> 0 0.0"``.

    Otherwise, if ``--distance`` is given and a result file is set, the nearest
    train line by token-level Levenshtein distance is written as
    ``"<test_idx> <train_idx> <dist> <dist / total_tokens>"``.

    When train and test are the same file, a line is never matched with
    itself. The number of exact matches is printed at the end.
    """
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="Find duplicates in the dataset ASM")
    parser.add_argument(
        "--train", required=True,
        help="reference file, one space-tokenized sample per line")
    parser.add_argument(
        "--test", required=True,
        help="file whose lines are checked against --train")
    parser.add_argument(
        "--result", required=False,
        help="optional output file receiving one match record per line")
    parser.add_argument(
        "--distance", action="store_true", default=False,
        help="for lines without an exact match, also search the nearest "
             "train line (only done when --result is given)")
    args = parser.parse_args()

    train, train_index = _read_train(args.train)
    with open(args.test, "r") as tf:
        test = list(tqdm(tf, desc="Read test", leave=False))

    # Compare the files themselves, so "data/x" and "./data/x" are both
    # recognised as a self-check.
    selfcheck = os.path.samefile(args.train, args.test)
    # Distances are expensive and only useful when there is somewhere to write
    # them, so they are skipped without a result file.
    searchdist = bool(args.result) and args.distance

    rf = open(args.result, "w") if args.result else None
    exact = 0
    try:
        for i, testline in tqdm(enumerate(test), desc="Test", total=len(test)):
            tokens = testline.strip().split(" ")

            j = _exact_match(train_index, tokens, i, selfcheck)
            if j is not None:
                exact += 1
                if rf is not None:
                    rf.write(f"{i} {j} 0 0.0\n")
                continue

            if searchdist:
                minj, mindist, minavgdist = _nearest(
                    train, tokens, skip=i if selfcheck else -1)
                rf.write(f"{i} {minj} {mindist} {minavgdist}\n")
    finally:
        if rf is not None:
            rf.close()

    print("Exact duplicates:", exact)


if __name__ == "__main__":
    main()