"""Find exact and near duplicates of test-set lines inside a training set.

Each line of the ``--test`` file is looked up in the ``--train`` file:

* an identical line counts as an exact duplicate and is written to the result
  file as ``"<test_idx> <train_idx> 0 0.0"``;
* otherwise, when ``--distance`` is given together with ``--result``, the
  closest training line by token-level Levenshtein distance is written as
  ``"<test_idx> <train_idx> <distance> <normalised distance>"``.

When ``--train`` and ``--test`` name the same path the script runs in
self-check mode and a line is never matched against itself.
"""
import argparse
import sys
from contextlib import ExitStack

from tqdm import tqdm
from Levenshtein import distance

# Token-count gap above which the edit distance is estimated, not computed.
_LENGTH_GAP_CUTOFF = 10
# Starting "best" distance; candidates at or above it are never reported.
_INITIAL_BEST = 100


def _line_key(line):
    """Return the lookup key for a raw line: its text without the line ending."""
    return line.rstrip("\n")


def _tokenize(line):
    """Split a raw line into the space-separated tokens used for distances."""
    return line.strip().split(" ")


def _read_train(path):
    """Read the training file.

    Returns a ``(token_lists, index_by_line)`` pair. ``index_by_line`` maps
    each distinct line to the index of its *last* occurrence in the file.
    """
    token_lists = []
    index_by_line = {}
    with open(path, "r") as handle:
        for idx, line in tqdm(enumerate(handle), desc="Read train", leave=False):
            # Key on the text itself: keying on hash(line) would report two
            # different lines with colliding hashes as exact duplicates.
            index_by_line[_line_key(line)] = idx
            token_lists.append(_tokenize(line))
    return token_lists, index_by_line


def _read_test(path):
    """Read the test file and return its raw lines."""
    with open(path, "r") as handle:
        return [line for line in tqdm(handle, desc="Read test", leave=False)]


def _nearest(test_tokens, train, skip_index=None):
    """Find the training entry closest to ``test_tokens``.

    ``skip_index`` is a training index to leave out of the search; it is used
    in self-check mode so that a line is not compared with itself.

    Returns ``(normalised_distance, distance, train_index)``. If no candidate
    is closer than ``_INITIAL_BEST`` the result is
    ``(_INITIAL_BEST, _INITIAL_BEST, -1)``.
    """
    best_avg, best_dist, best_j = _INITIAL_BEST, _INITIAL_BEST, -1
    test_len = len(test_tokens)
    for j, train_tokens in enumerate(train):
        if j == skip_index:
            continue
        gap = abs(len(train_tokens) - test_len)
        if gap > _LENGTH_GAP_CUTOFF:
            # HACK to speed it up: very different lengths cannot be near
            # duplicates, so use a cheap estimate instead of the real distance.
            dist = gap * 2
        else:
            dist = distance(train_tokens, test_tokens)
        if dist < best_dist:
            best_avg = dist / (len(train_tokens) + test_len)
            best_dist, best_j = dist, j
    return best_avg, best_dist, best_j


def main(argv=None):
    """Run the duplicate search; ``argv`` defaults to the process arguments."""
    # The text is the description; passed positionally it would become `prog`
    # and replace the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="Find duplicates in the dataset ASM"
    )
    parser.add_argument("--train", required=True)
    parser.add_argument("--test", required=True)
    parser.add_argument("--result", required=False)
    parser.add_argument("--distance", action="store_true", default=False)
    args = parser.parse_args(argv)

    train, train_index = _read_train(args.train)
    test = _read_test(args.test)

    selfcheck = args.train == args.test
    # Dont compute distances if there is no result file to write them to.
    search_distance = bool(args.result) and args.distance

    exact = 0
    with ExitStack() as stack:
        # ExitStack closes (and flushes) the result file even on errors.
        result_file = (
            stack.enter_context(open(args.result, "w")) if args.result else None
        )

        for i, test_line in tqdm(enumerate(test), desc="Test", total=len(test)):
            j = train_index.get(_line_key(test_line))
            if j is not None and (not selfcheck or j != i):
                # Found exact match (and, in self-check mode, not the line itself).
                exact += 1
                if result_file:
                    result_file.write(f"{i} {j} 0 0.0\n")
                continue

            # If not, then search for the nearest training line.
            if search_distance:
                avg, dist, nearest_j = _nearest(
                    _tokenize(test_line),
                    train,
                    skip_index=i if selfcheck else None,
                )
                result_file.write(f"{i} {nearest_j} {dist} {avg}\n")

    print("Exact duplicates:", exact)


if __name__ == "__main__":
    main()