# REMEND / remend / find_duplicates.py
# Author: udiboy1209
# Commit message: Add REMEND python module
# Commit: 7145fd6
import sys
from tqdm import tqdm
from Levenshtein import distance
def _nearest_train(testl, train, skip=None):
    """Find the train token sequence closest to *testl* by edit distance.

    Args:
        testl: tokens of the test line.
        train: list of token lists, one per train line.
        skip: index into *train* to ignore (used so a line is not matched
            against itself when train and test are the same file).

    Returns:
        ``(minj, mindist, minavgdist)``: index of the first closest train
        line, its distance, and that distance divided by the two lines'
        combined token count.  Only distances below 100 can be selected; if
        none qualify, the sentinel ``(-1, 100, 100)`` is returned.

    Lines whose token counts differ by more than 10 are not compared exactly:
    their "distance" is taken as twice the length difference (a speed hack,
    not the true edit distance).
    """
    minavgdist, mindist, minj = 100, 100, -1
    ntest = len(testl)
    for j, trainl in enumerate(train):
        if j == skip:
            continue
        lendiff = abs(len(trainl) - ntest)
        # Edit distance is never below the length difference (and the HACK
        # estimate is twice it), so this line cannot beat the current best.
        # Skipping it avoids the expensive distance() call without changing
        # the result.
        if lendiff >= mindist:
            continue
        if lendiff > 10:
            dist = lendiff * 2  # HACK to speed it up
        else:
            dist = distance(trainl, testl)
        if dist < mindist:
            minavgdist, mindist, minj = dist / (len(trainl) + ntest), dist, j
    return minj, mindist, minavgdist


def main():
    """Report duplicates of the ``--test`` lines within the ``--train`` file.

    Each line of both files is treated as a space-separated token sequence
    (after stripping surrounding whitespace).  A test line whose tokens equal
    those of a train line counts as an exact duplicate.

    When ``--result`` is given, one line per reported test line is written to
    it as ``"<test_idx> <train_idx> <dist> <avgdist>"``.  Exact duplicates are
    written as ``"<i> <j> 0 0.0"``.  With ``--distance``, every other test line
    is written with its nearest train line (see ``_nearest_train``).
    ``--distance`` is ignored without ``--result``.

    If ``--train`` and ``--test`` refer to the same file, a line is never
    matched against itself.  The number of exact duplicates is printed to
    stdout.
    """
    import argparse
    import contextlib
    import os

    # The text is a description; passed positionally it would become `prog`
    # and replace the script name in the usage line.
    parser = argparse.ArgumentParser(description="Find duplicates in the dataset ASM")
    parser.add_argument("--train", required=True)
    parser.add_argument("--test", required=True)
    parser.add_argument("--result", required=False)
    parser.add_argument("--distance", action="store_true", default=False)
    args = parser.parse_args()

    train = []  # token list per train line
    # hash(stripped line) -> index of the LAST train line with that hash.
    # Only the hash is kept to save memory; matches are re-verified below.
    train_hash = {}
    with open(args.train, "r", encoding="utf-8") as tf:
        for idx, line in tqdm(enumerate(tf), desc="Read train", leave=False):
            stripped = line.strip()
            train_hash[hash(stripped)] = idx
            train.append(stripped.split(" "))
    with open(args.test, "r", encoding="utf-8") as tf:
        test = list(tqdm(tf, desc="Read test", leave=False))

    # samefile also recognises different spellings of the same path.
    selfcheck = os.path.samefile(args.train, args.test)
    # Dont compute distances if there is no result file to write them to.
    searchdist = args.distance if args.result else False

    out = (
        open(args.result, "w", encoding="utf-8")
        if args.result
        else contextlib.nullcontext()
    )
    exact = 0
    with out as rf:  # rf is None when no result file was requested

        def reswrite(s):
            if rf is not None:
                rf.write(s)

        for i, testline in tqdm(enumerate(test), desc="Test", total=len(test)):
            stripped = testline.strip()
            testl = stripped.split(" ")
            j = train_hash.get(hash(stripped))
            # Compare tokens as well, so a hash collision is never reported
            # as an exact duplicate; in self-check mode a line is not its own
            # duplicate.
            if j is not None and train[j] == testl and (not selfcheck or j != i):
                exact += 1
                reswrite(f"{i} {j} 0 0.0\n")
                continue
            # If not, then search
            if searchdist:
                minj, mindist, minavgdist = _nearest_train(
                    testl, train, skip=i if selfcheck else None
                )
                reswrite(f"{i} {minj} {mindist} {minavgdist}\n")
    print("Exact duplicates:", exact)


if __name__ == "__main__":
    main()