|
import sys |
|
from tqdm import tqdm |
|
from Levenshtein import distance |
|
|
|
# Value reported for both the distance and the normalised distance when no
# train line is closer than this; candidates at or beyond it are never chosen.
NO_MATCH_DISTANCE = 100

# When the token counts of two lines differ by more than this, the edit
# distance is not computed and a cheap length-based estimate is used instead.
MAX_LENGTH_GAP = 10


def line_key(line):
    """Return the text of *line* without its trailing newline.

    Used as the exact-duplicate lookup key, so that a final line lacking a
    terminator still matches an otherwise identical line that has one.
    """
    return line.rstrip("\n")


def split_tokens(line):
    """Split a dataset line into its space-separated tokens."""
    return line.strip().split(" ")


def nearest_train_line(test_tokens, train, skip_index=None, dist_fn=None):
    """Find the train line closest to *test_tokens*.

    Args:
        test_tokens: token list of the line being looked up.
        train: list of token lists, one per train line.
        skip_index: index in *train* to ignore (the line itself when the
            train and test files are the same), or None to consider all.
        dist_fn: callable ``(a, b) -> number`` giving the edit distance of
            two token lists; defaults to the module's ``distance`` import.

    Returns:
        A tuple ``(normalised_distance, distance, index)``. When no candidate
        is closer than NO_MATCH_DISTANCE the tuple is
        ``(NO_MATCH_DISTANCE, NO_MATCH_DISTANCE, -1)``.
    """
    best_avg, best_dist, best_idx = NO_MATCH_DISTANCE, NO_MATCH_DISTANCE, -1
    for idx, train_tokens in enumerate(train):
        if idx == skip_index:
            continue
        gap = abs(len(train_tokens) - len(test_tokens))
        if gap > MAX_LENGTH_GAP:
            # Very different lengths: skip the costly computation and use
            # twice the length difference as the distance estimate.
            dist = gap * 2
        else:
            # Resolve the default lazily so the helper stays usable with an
            # injected distance function.
            dist = (dist_fn or distance)(train_tokens, test_tokens)
        # Strict comparison: on ties the earliest train line wins.
        if dist < best_dist:
            total = len(train_tokens) + len(test_tokens)
            best_avg = dist / total if total else 0.0
            best_dist, best_idx = dist, idx
    return best_avg, best_dist, best_idx


def main():
    """Report test lines that duplicate (or nearly duplicate) train lines.

    Prints the number of exact duplicates. When ``--result`` is given, one
    line per match is written as ``<test idx> <train idx> <dist> <avg dist>``;
    with ``--distance`` the closest train line is also searched for every
    test line that has no exact duplicate. ``--distance`` has no effect
    without ``--result``.
    """
    import argparse

    parser = argparse.ArgumentParser("Find duplicates in the dataset ASM")
    parser.add_argument("--train", required=True, help="train file, one sample per line")
    parser.add_argument("--test", required=True, help="test file, one sample per line")
    parser.add_argument("--result", required=False, help="file to write the matches to")
    parser.add_argument(
        "--distance",
        action="store_true",
        default=False,
        help="also search the closest train line (requires --result)",
    )
    args = parser.parse_args()

    train = []
    # Keyed on the line text itself (not on hash(line)) so that two different
    # lines can never be mistaken for duplicates because of a hash collision.
    # For repeated lines the index of the last occurrence is kept.
    train_index = {}
    with open(args.train, "r") as tf:
        for idx, line in tqdm(enumerate(tf), desc="Read train", leave=False):
            train_index[line_key(line)] = idx
            train.append(split_tokens(line))

    with open(args.test, "r") as tf:
        test = list(tqdm(tf, desc="Read test", leave=False))

    # Same file on both sides: a line must not be matched against itself.
    selfcheck = args.train == args.test
    searchdist = bool(args.result) and args.distance

    rf = open(args.result, "w") if args.result else None
    exact = 0
    try:
        for i, testline in tqdm(enumerate(test), desc="Test", total=len(test)):
            j = train_index.get(line_key(testline))
            if j is not None and (not selfcheck or j != i):
                exact += 1
                if rf is not None:
                    rf.write(f"{i} {j} 0 0.0\n")
                continue

            if searchdist:
                avgdist, dist, j = nearest_train_line(
                    split_tokens(testline),
                    train,
                    skip_index=i if selfcheck else None,
                )
                rf.write(f"{i} {j} {dist} {avgdist}\n")
    finally:
        # Make sure buffered results reach the disk even if the scan fails.
        if rf is not None:
            rf.close()

    print("Exact duplicates:", exact)


if __name__ == "__main__":
    main()
|
|