Upload 11 files
Browse files- combined_sampler.py +30 -0
- hkpoly_evaluation_phase1.py +106 -0
- hkpoly_evaluation_phase2.py +114 -0
- loss.py +377 -0
- model.py +207 -0
- rb_evaluation_phase1.py +148 -0
- rb_evaluation_phase2.py +164 -0
- requirements.txt +202 -0
- train_combined.py +273 -0
- train_combined_fusion.py +301 -0
- utils.py +117 -0
combined_sampler.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.utils.data.sampler import Sampler
|
| 5 |
+
from tqdm import *
|
| 6 |
+
|
| 7 |
+
class BalancedSampler(Sampler):
    """Yield dataset indices in class-balanced groups.

    For every batch, ``batch_size // images_per_class`` distinct classes are
    drawn without replacement, and ``images_per_class`` indices are drawn
    (with replacement) from each of them.

    Args:
        data_source: dataset exposing ``all_labels`` (one label per sample)
            and ``__len__``.
        batch_size: number of indices that make up one batch.
        images_per_class: number of indices drawn per sampled class.

    Raises:
        ValueError: from ``np.random.choice`` during iteration when more
            classes are requested per batch than exist in the data.
    """

    def __init__(self, data_source, batch_size, images_per_class=3):
        self.data_source = data_source
        self.ys = np.array(data_source.all_labels)
        self.num_groups = batch_size // images_per_class
        self.batch_size = batch_size
        self.num_instances = images_per_class
        self.num_samples = len(self.ys)
        # Sample from the label values that actually occur instead of from
        # range(num_classes); otherwise label sets that are not exactly
        # 0..C-1 produce an empty index pool and np.random.choice raises.
        self.classes = np.unique(self.ys)
        self.num_classes = len(self.classes)

    def __len__(self):
        # Reports the dataset size. __iter__ yields
        # num_batches * num_groups * num_instances indices, which can be
        # smaller than this value.
        return self.num_samples

    def __iter__(self):
        num_batches = len(self.data_source) // self.batch_size
        ret = []
        for _ in range(num_batches):
            sampled_classes = np.random.choice(self.classes, self.num_groups, replace=False)
            for class_label in sampled_classes:
                class_idxs = np.nonzero(self.ys == class_label)[0]
                class_sel = np.random.choice(class_idxs, size=self.num_instances, replace=True)
                ret.extend(np.random.permutation(class_sel))
        return iter(ret)
|
hkpoly_evaluation_phase1.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# script to evaluate the HKPolyU testing dataset on the finetuned model after phase 1
|
| 2 |
+
import torch
|
| 3 |
+
from datasets.hkpoly_test import hktest
|
| 4 |
+
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from model import SwinModel_domain_agnostic as Model
|
| 8 |
+
from sklearn.metrics import roc_curve, auc
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
def calculate_tar_at_far(fpr, tpr, target_fars):
    """Look up the true-accept rate at each requested false-accept rate.

    Args:
        fpr: NumPy array of false-positive rates (the ROC x-axis).
        tpr: NumPy array of true-positive rates aligned with ``fpr``.
        target_fars: iterable of false-accept rates to evaluate.

    Returns:
        dict mapping each requested rate to a TAR value. When the rate occurs
        in ``fpr`` the ``tpr`` entry at its first occurrence is used;
        otherwise the value is linearly interpolated with ``np.interp``.
    """
    return {
        far: tpr[np.where(fpr == far)][0] if far in fpr else np.interp(far, fpr, tpr)
        for far in target_fars
    }
|
| 20 |
+
|
| 21 |
+
if __name__ == '__main__':
    # Evaluate the phase-1 fine-tuned model on the HKPolyU test split: embed
    # every contactless (CL) / contact-based (CB) pair, score all CL x CB
    # combinations, then report ROC AUC, EER, TAR@FAR and recall@k.
    device = torch.device('cuda')
    data = hktest(split='test')
    dataloader = torch.utils.data.DataLoader(data, batch_size=16, num_workers=1, pin_memory=True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase1_ft_hkpoly.pt", map_location=torch.device('cpu'))
    # strict=False: mismatching keys between checkpoint and model are tolerated.
    model.load_state_dict(checkpoint, strict=False)
    model.eval()

    cl_feats, cb_feats = [], []
    cl_feats_unnormed, cb_feats_unnormed = [], []
    batch_labels = []
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(dataloader):
            x_cl, x_cb = x_cl.to(device), x_cb.to(device)
            x_cl_feat, _ = model.get_embeddings(x_cl, 'contactless')
            x_cb_feat, _ = model.get_embeddings(x_cb, 'contactbased')
            # Raw embeddings feed the Euclidean term, normalised ones the
            # dot-product term of the score.
            cl_feats_unnormed.append(x_cl_feat.cpu().detach().numpy())
            cb_feats_unnormed.append(x_cb_feat.cpu().detach().numpy())
            cl_feats.append(l2_norm(x_cl_feat).cpu().detach().numpy())
            cb_feats.append(l2_norm(x_cb_feat).cpu().detach().numpy())
            batch_labels.append(label.cpu().detach().numpy())

    cl_feats = np.concatenate(cl_feats)
    cb_feats = np.concatenate(cb_feats)
    cl_feats_unnormed = np.concatenate(cl_feats_unnormed)
    cb_feats_unnormed = np.concatenate(cb_feats_unnormed)
    # Each dataset item carries one label shared by its CL and CB image.
    cl_label = torch.from_numpy(np.concatenate(batch_labels))
    cb_label = torch.from_numpy(np.concatenate(batch_labels))

    # CB2CL
    # score = dot(normalised CL, normalised CB) - 0.1 * Euclidean distance.
    # NOTE: the broadcasted difference materialises an (N, N, D) array.
    squared_diff = np.sum(np.square(cl_feats_unnormed[:, np.newaxis] - cb_feats_unnormed), axis=2)
    distance = -1 * np.sqrt(squared_diff)
    similarities = np.dot(cl_feats, np.transpose(cb_feats))
    scores_mat = similarities + 0.1 * distance

    scores = scores_mat.flatten().tolist()
    # A pair is genuine when the CL and CB labels are equal.
    labels = torch.eq(cl_label.view(-1, 1) - cb_label.view(1, -1), 0.0).flatten().tolist()

    fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

    # TAR@FAR=1e-2 takes the first ROC point at or above the target FAR.
    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.01)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.01)
    tar_far_102 = tpr[upper_fpr_idx]
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx], thresh[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx], thresh[upper_fpr_idx])

    # TAR@FAR=1e-3 and 1e-4 average the two ROC points bracketing the target.
    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.001)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.0001)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    # EER is read off as the FPR at the ROC point where FNR and FPR are closest.
    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")

    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")

    # Convert the score matrix once and reuse it for every recall cut-off.
    scores_tensor = torch.from_numpy(scores_mat)
    for k in (1, 10, 50, 100):
        print(f"R@{k} for CB2CL: {compute_recall_at_k(scores_tensor, cl_label, cb_label, k) * 100} %")
|
hkpoly_evaluation_phase2.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# script to evaluate HKPolyU testing dataset on finetuned model after phase 2
|
| 2 |
+
import torch
|
| 3 |
+
from datasets.hkpoly_test import hktest
|
| 4 |
+
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from model import SwinModel_Fusion as Model
|
| 8 |
+
from sklearn.metrics import roc_curve, auc
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
def calculate_tar_at_far(fpr, tpr, target_fars):
    """Look up the true-accept rate at each requested false-accept rate.

    Args:
        fpr: NumPy array of false-positive rates (the ROC x-axis).
        tpr: NumPy array of true-positive rates aligned with ``fpr``.
        target_fars: iterable of false-accept rates to evaluate.

    Returns:
        dict mapping each requested rate to a TAR value. When the rate occurs
        in ``fpr`` the ``tpr`` entry at its first occurrence is used;
        otherwise the value is linearly interpolated with ``np.interp``.
    """
    return {
        far: tpr[np.where(fpr == far)][0] if far in fpr else np.interp(far, fpr, tpr)
        for far in target_fars
    }
|
| 20 |
+
|
| 21 |
+
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens, shard_size=20):
    """Score every contactless / contact-based token pair with the fusion model.

    The full pair grid is processed in ``shard_size`` x ``shard_size`` tiles so
    that only one tile of expanded pairs is passed to the model at a time.

    Args:
        model: object exposing ``combine_features(a, b)`` which returns
            ``(similarity_scores, distances)`` with one value per pair.
        cl_tokens: list of per-batch contactless token tensors; concatenated
            along dim 0 (each is indexed here as ``(batch, tokens, dim)``).
        cb_tokens: list of per-batch contact-based token tensors, same layout.
        shard_size: tile edge length; defaults to the previous fixed value 20.

    Returns:
        CPU tensor of shape ``(num_cl, num_cb)`` holding
        ``similarity - 0.1 * distance`` for each pair.
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)
    num_cl, num_cb = cl_tokens.shape[0], cb_tokens.shape[0]
    # Per-sample token shapes are read from the data instead of being
    # hard-coded to (197, 1024).
    cl_shape = cl_tokens.shape[1:]
    cb_shape = cb_tokens.shape[1:]
    similarity_matrix = torch.zeros((num_cl, num_cb))
    # The result is detached anyway, so no autograd graph needs to be built.
    with torch.no_grad():
        for i_start in tqdm(range(0, num_cl, shard_size)):
            shard_i = cl_tokens[i_start:i_start + shard_size]
            n_i = shard_i.shape[0]
            for j_start in range(0, num_cb, shard_size):
                shard_j = cb_tokens[j_start:j_start + shard_size]
                n_j = shard_j.shape[0]
                # Expand with the actual tile sizes: the last tile is smaller
                # than shard_size whenever the sample count is not a multiple
                # of it, and expanding/reshaping with shard_size then fails.
                pairwise_i = shard_i.unsqueeze(1).expand(-1, n_j, -1, -1)
                pairwise_j = shard_j.unsqueeze(0).expand(n_i, -1, -1, -1)
                similarity_scores, distances = model.combine_features(
                    pairwise_i.reshape(-1, *cl_shape),
                    pairwise_j.reshape(-1, *cb_shape),
                )
                scores = (similarity_scores - 0.1 * distances).reshape(n_i, n_j)
                similarity_matrix[i_start:i_start + n_i, j_start:j_start + n_j] = scores.cpu().detach()
    return similarity_matrix
|
| 42 |
+
|
| 43 |
+
if __name__ == '__main__':
    # Evaluate the phase-2 (fusion) fine-tuned model on the HKPolyU test
    # split: extract tokens for every contactless (CL) / contact-based (CB)
    # pair, score all CL x CB combinations with the fusion head, then report
    # ROC AUC, EER, TAR@FAR and recall@k.
    device = torch.device('cuda')
    data = hktest(split='test')
    dataloader = torch.utils.data.DataLoader(data, batch_size=16, num_workers=1, pin_memory=True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase2_ft_hkpoly.pt", map_location=torch.device('cpu'))
    # strict=False: mismatching keys between checkpoint and model are tolerated.
    model.load_state_dict(checkpoint, strict=False)
    model.eval()

    cl_tokens, cb_tokens, batch_labels = [], [], []
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(dataloader):
            x_cl, x_cb = x_cl.to(device), x_cb.to(device)
            cl_tokens.append(model.get_tokens(x_cl, 'contactless'))
            cb_tokens.append(model.get_tokens(x_cb, 'contactbased'))
            batch_labels.append(label.cpu().detach().numpy())

    # Each dataset item carries one label shared by its CL and CB image.
    cl_label = torch.from_numpy(np.concatenate(batch_labels))
    cb_label = torch.from_numpy(np.concatenate(batch_labels))

    # CB2CL
    # Scoring is inference only; no_grad avoids building autograd graphs.
    with torch.no_grad():
        scores_mat = get_fused_cross_score_matrix(model, cl_tokens, cb_tokens)
    scores = scores_mat.cpu().detach().numpy().flatten().tolist()
    # A pair is genuine when the CL and CB labels are equal.
    labels = torch.eq(cl_label.view(-1, 1) - cb_label.view(1, -1), 0.0).flatten().tolist()

    fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

    # TAR@FAR=1e-2 takes the first ROC point at or above the target FAR.
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.01)
    tar_far_102 = tpr[upper_fpr_idx]

    # TAR@FAR=1e-3 and 1e-4 average the two ROC points bracketing the target.
    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.001)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < 0.0001)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

    # EER is read off as the FPR at the ROC point where FNR and FPR are closest.
    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")

    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")

    # Only the reported cut-offs are evaluated.
    for k in (1, 10, 50, 100):
        print(f"R@{k} for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, k) * 100} %")
|
loss.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pytorch_metric_learning import losses
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.init
|
| 5 |
+
import torchvision.models as models
|
| 6 |
+
from torch.autograd import Variable
|
| 7 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
| 8 |
+
from torch.nn.utils.weight_norm import weight_norm
|
| 9 |
+
import torch.backends.cudnn as cudnn
|
| 10 |
+
from torch.nn.utils.clip_grad import clip_grad_norm
|
| 11 |
+
import numpy as np
|
| 12 |
+
import os
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import itertools
|
| 15 |
+
|
| 16 |
+
# Enables autograd anomaly detection globally as soon as this module is imported.
# NOTE(review): PyTorch describes this as a debugging mode that adds runtime
# overhead — confirm it is meant to stay enabled for regular training runs.
torch.autograd.set_detect_anomaly(True)
|
| 17 |
+
|
| 18 |
+
class DualMSLoss_FineGrained(nn.Module):
    """
    Multi-similarity style contrastive loss over contactless (CL) and
    contact-based (CB) embeddings.

    ``forward`` L2-normalises both embedding sets, builds the CL-CL, CB-CB and
    CB-CL similarity matrices and sums a hard-pair-mined loss over each of
    them. All intermediate tensors are created on the device of the
    similarity matrix, so the loss runs on whichever device its inputs use.
    """

    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained, self).__init__()
        # Stored only; no method of this class reads it.
        self.max_violation = max_violation
        self.thresh = 0.5
        # NOTE: the ``margin`` argument is accepted for interface
        # compatibility; the effective margin has always been fixed at 0.7.
        self.margin = 0.7
        self.scale_pos = 2
        self.scale_neg = 40.0

    @staticmethod
    def _label_match(label, device):
        """Return a float (n, n) matrix that is 1 where two labels are equal."""
        return torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(device)

    def _mined_ms_loss(self, sim_mat, is_pos, is_neg):
        """Shared loss body: mine hard pairs per row and sum the log terms.

        ``is_pos`` / ``is_neg`` are boolean masks (same shape as ``sim_mat``)
        selecting the positive and negative pairs.
        """
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # Large sentinels keep unselected entries out of the row-wise min/max
        # and make them fail the mining comparisons below.
        P_sim = torch.where(is_pos, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(is_neg, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        # A positive is hard when it is within ``margin`` of the row's hardest
        # negative; a negative is hard when within ``margin`` of the row's
        # hardest positive.
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg
        return pos_loss + neg_loss

    def ms_sample(self, sim_mat, label):
        """Loss for a cross-modal (n, n) similarity matrix; equal labels are positives."""
        pos_mask = self._label_match(label, sim_mat.device)
        neg_mask = 1 - pos_mask
        return self._mined_ms_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """Loss for a same-modality (n, n) similarity matrix, ignoring self pairs."""
        pos_mask = self._label_match(label, sim_mat.device)
        # Adding the identity lifts the diagonal to 2, so self pairs match
        # neither the ``== 1`` (positive) nor the ``== 0`` (negative) test.
        pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=sim_mat.device)
        return self._mined_ms_loss(sim_mat, pos_mask == 1, pos_mask == 0)

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Loss for an (n, n-1) similarity matrix whose diagonal was removed."""
        pos_mask = self._label_match(label, sim_mat.device)
        n_sha = pos_mask.shape[0]
        # Drop the diagonal of the label mask so that it lines up with
        # ``sim_mat`` (row-major order of the off-diagonal entries).
        off_diagonal = ~torch.eye(n_sha, dtype=torch.bool, device=sim_mat.device)
        pos_mask = pos_mask[off_diagonal].reshape(n_sha, n_sha - 1)
        neg_mask = 1 - pos_mask
        return self._mined_ms_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Accumulate token-level cosine similarities into a (B, B) matrix.

        Both inputs are indexed as ``(B, T, D)``. Tokens are processed in
        shards of ``shard_size`` to bound the size of the intermediate
        ``(B, B, shard, shard)`` tensor.

        NOTE(review): both tensors are sliced with the same token range, so
        only token pairs that fall in the same shard are accumulated, while
        the sum is divided by ``T * T`` — confirm this is intended.
        """
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        for start_idx in range(0, T, shard_size):
            end_idx = min(start_idx + shard_size, T)

            # Same token window from both tensors.
            shard_tensor1 = tensor1[:, start_idx:end_idx, :]
            shard_tensor2 = tensor2[:, start_idx:end_idx, :]

            # Broadcast to (B, 1, S, 1, D) against (1, B, 1, S, D).
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)

            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)

            # Sum over both token axes for every sample pair.
            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        average_sim_matrix /= (T * T)

        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device):
        """Return the summed CB-CL, CL-CL and CB-CB losses as a scalar tensor.

        ``x_cl_tokens``, ``x_cb_tokens`` and ``device`` are accepted for
        interface compatibility and are not used.
        """
        cl_normed = self.l2_norm(x_contactless)
        cb_normed = self.l2_norm(x_contactbased)
        sim_mat_clcl = F.linear(cl_normed, cl_normed)
        sim_mat_cbcb = F.linear(cb_normed, cb_normed)
        sim_mat_cbcl = F.linear(cb_normed, cl_normed)

        # Every term is evaluated on the matrix and on its transpose.
        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)
        return loss4 + loss2 + loss3

    def l2_norm(self, input):
        """Scale each row of a 2-D tensor to unit L2 norm (1e-12 added under the root)."""
        normp = torch.sum(torch.pow(input, 2), 1).add_(1e-12)
        norm = torch.sqrt(normp)
        return torch.div(input, norm.view(-1, 1).expand_as(input))
|
| 134 |
+
|
| 135 |
+
class DualMSLoss_FineGrained_domain_agnostic(nn.Module):
    """
    Multi-similarity style contrastive loss over contactless (CL) and
    contact-based (CB) embeddings, plus a domain-classification term.

    ``forward`` L2-normalises both embedding sets, builds the CL-CL, CB-CB and
    CB-CL similarity matrices, sums a hard-pair-mined loss over each of them
    and adds three times the cross-entropy of the domain predictions. All
    intermediate tensors are created on the device of the similarity matrix,
    so the loss runs on whichever device its inputs use.
    """

    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained_domain_agnostic, self).__init__()
        # Stored only; no method of this class reads it.
        self.max_violation = max_violation
        self.thresh = 0.5
        # NOTE: the ``margin`` argument is accepted for interface
        # compatibility; the effective margin has always been fixed at 0.5.
        self.margin = 0.5
        self.scale_pos = 2
        self.scale_neg = 40.0
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _label_match(label, device):
        """Return a float (n, n) matrix that is 1 where two labels are equal."""
        return torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0).float().to(device)

    def _mined_ms_loss(self, sim_mat, is_pos, is_neg):
        """Shared loss body: mine hard pairs per row and sum the log terms.

        ``is_pos`` / ``is_neg`` are boolean masks (same shape as ``sim_mat``)
        selecting the positive and negative pairs.
        """
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # Large sentinels keep unselected entries out of the row-wise min/max
        # and make them fail the mining comparisons below.
        P_sim = torch.where(is_pos, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(is_neg, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        # A positive is hard when it is within ``margin`` of the row's hardest
        # negative; a negative is hard when within ``margin`` of the row's
        # hardest positive.
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp, torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp, torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg
        return pos_loss + neg_loss

    def ms_sample(self, sim_mat, label):
        """Loss for a cross-modal (n, n) similarity matrix; equal labels are positives."""
        pos_mask = self._label_match(label, sim_mat.device)
        neg_mask = 1 - pos_mask
        return self._mined_ms_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """Loss for a same-modality (n, n) similarity matrix, ignoring self pairs."""
        pos_mask = self._label_match(label, sim_mat.device)
        # Adding the identity lifts the diagonal to 2, so self pairs match
        # neither the ``== 1`` (positive) nor the ``== 0`` (negative) test.
        pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=sim_mat.device)
        return self._mined_ms_loss(sim_mat, pos_mask == 1, pos_mask == 0)

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """Loss for an (n, n-1) similarity matrix whose diagonal was removed."""
        pos_mask = self._label_match(label, sim_mat.device)
        n_sha = pos_mask.shape[0]
        # Drop the diagonal of the label mask so that it lines up with
        # ``sim_mat`` (row-major order of the off-diagonal entries).
        off_diagonal = ~torch.eye(n_sha, dtype=torch.bool, device=sim_mat.device)
        pos_mask = pos_mask[off_diagonal].reshape(n_sha, n_sha - 1)
        neg_mask = 1 - pos_mask
        return self._mined_ms_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Accumulate token-level cosine similarities into a (B, B) matrix.

        Both inputs are indexed as ``(B, T, D)``. Tokens are processed in
        shards of ``shard_size`` to bound the size of the intermediate
        ``(B, B, shard, shard)`` tensor.

        NOTE(review): both tensors are sliced with the same token range, so
        only token pairs that fall in the same shard are accumulated, while
        the sum is divided by ``T * T`` — confirm this is intended.
        """
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        for start_idx in range(0, T, shard_size):
            end_idx = min(start_idx + shard_size, T)

            # Same token window from both tensors.
            shard_tensor1 = tensor1[:, start_idx:end_idx, :]
            shard_tensor2 = tensor2[:, start_idx:end_idx, :]

            # Broadcast to (B, 1, S, 1, D) against (1, B, 1, S, D).
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)

            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded, shard_tensor2_expanded, dim=-1)

            # Sum over both token axes for every sample pair.
            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        average_sim_matrix /= (T * T)

        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device, domain_class_cl, domain_class_cb, domain_class_cl_gt, domain_class_cb_gt):
        """Return the metric-learning losses plus 3x the domain cross-entropy.

        ``x_cl_tokens``, ``x_cb_tokens`` and ``device`` are accepted for
        interface compatibility and are not used. ``domain_class_*`` are the
        domain predictions and ``domain_class_*_gt`` their targets; both
        pairs are concatenated (CL first) before the cross-entropy.
        """
        cl_normed = self.l2_norm(x_contactless)
        cb_normed = self.l2_norm(x_contactbased)
        sim_mat_clcl = F.linear(cl_normed, cl_normed)
        sim_mat_cbcb = F.linear(cb_normed, cb_normed)
        sim_mat_cbcl = F.linear(cb_normed, cl_normed)

        # Every term is evaluated on the matrix and on its transpose.
        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)

        pred = torch.cat([domain_class_cl, domain_class_cb])
        gt = torch.cat([domain_class_cl_gt, domain_class_cb_gt])
        domain_class_loss = self.criterion(pred, gt)
        return loss4 + loss2 + loss3 + (3 * domain_class_loss)

    def l2_norm(self, input):
        """Scale each row of a 2-D tensor to unit L2 norm (1e-12 added under the root)."""
        normp = torch.sum(torch.pow(input, 2), 1).add_(1e-12)
        norm = torch.sqrt(normp)
        return torch.div(input, norm.view(-1, 1).expand_as(input))
|
| 258 |
+
|
| 259 |
+
class DualMSLoss_FineGrained_domain_agnostic_ft(nn.Module):
    """Hard-mined multi-similarity loss over two sets of embeddings.

    ``forward`` sums three terms, each evaluated on a cosine-similarity
    matrix and on its transpose: contactless vs. contactless, contact-based
    vs. contact-based (both ignoring self-pairs), and contact-based vs.
    contactless.

    Every helper tensor is created on the device of the similarity matrix it
    is combined with, so the loss works on CPU as well as on GPU.
    """

    def __init__(self, margin=0, max_violation=False):
        super(DualMSLoss_FineGrained_domain_agnostic_ft, self).__init__()
        self.max_violation = max_violation
        self.thresh = 0.5
        # NOTE(review): ``margin`` is accepted for interface compatibility
        # only. The mining margin has always been pinned to 0.7 regardless of
        # the argument; honouring the argument would change training results.
        self.margin = 0.7  # 0.1
        self.scale_pos = 2
        self.scale_neg = 40.0
        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _label_match(label, device):
        """Return a float matrix with 1.0 where ``label[i] == label[j]``."""
        same = torch.eq(label.view(-1, 1) - label.view(1, -1), 0.0)
        return same.float().to(device)

    def _mined_ms_loss(self, sim_mat, pos_sel, neg_sel):
        """Multi-similarity loss restricted to hard pairs.

        Args:
            sim_mat: similarity matrix.
            pos_sel: boolean mask of entries treated as positive pairs.
            neg_sel: boolean mask of entries treated as negative pairs.

        Returns:
            Scalar tensor: positive term plus negative term.
        """
        pos_exp = torch.exp(-self.scale_pos * (sim_mat - self.thresh))
        neg_exp = torch.exp(self.scale_neg * (sim_mat - self.thresh))
        # Entries outside a selection get a sentinel that can never win the
        # row-wise min / max and never satisfies the mining inequality.
        P_sim = torch.where(pos_sel, sim_mat, torch.ones_like(pos_exp) * 1e16)
        N_sim = torch.where(neg_sel, sim_mat, torch.ones_like(neg_exp) * -1e16)
        min_P_sim, _ = torch.min(P_sim, dim=1, keepdim=True)
        max_N_sim, _ = torch.max(N_sim, dim=1, keepdim=True)
        # A positive is hard when it is within ``margin`` of the hardest
        # negative of its row, and vice versa for negatives.
        hard_P_sim = torch.where(P_sim - self.margin < max_N_sim, pos_exp,
                                 torch.zeros_like(pos_exp)).sum(dim=-1)
        hard_N_sim = torch.where(N_sim + self.margin > min_P_sim, neg_exp,
                                 torch.zeros_like(neg_exp)).sum(dim=-1)
        pos_loss = torch.log(1 + hard_P_sim).sum() / self.scale_pos
        neg_loss = torch.log(1 + hard_N_sim).sum() / self.scale_neg
        return pos_loss + neg_loss

    def ms_sample(self, sim_mat, label):
        """MS loss where every same-label pair (diagonal included) is positive."""
        pos_mask = self._label_match(label, sim_mat.device)
        neg_mask = 1 - pos_mask
        return self._mined_ms_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def ms_sample_cbcb_clcl(self, sim_mat, label):
        """MS loss for a same-modality matrix, ignoring self-pairs.

        Adding the identity lifts the diagonal of the mask to 2, so diagonal
        entries are selected neither as positives (== 1) nor as negatives
        (== 0).
        """
        pos_mask = self._label_match(label, sim_mat.device)
        pos_mask = pos_mask + torch.eye(pos_mask.shape[0], device=pos_mask.device)
        return self._mined_ms_loss(sim_mat, pos_mask == 1, pos_mask == 0)

    def ms_sample_cbcb_clcl_trans(self, sim_mat, label):
        """MS loss for a matrix whose diagonal has already been removed.

        The label mask is reduced to its off-diagonal entries and reshaped to
        ``(n, n - 1)``; ``sim_mat`` must have that same shape.
        """
        pos_mask = self._label_match(label, sim_mat.device)
        n_sha = pos_mask.shape[0]
        off_diagonal = ~torch.eye(n_sha, dtype=torch.bool, device=pos_mask.device)
        pos_mask = pos_mask[off_diagonal].reshape(n_sha, n_sha - 1)
        neg_mask = 1 - pos_mask
        return self._mined_ms_loss(sim_mat, pos_mask == 1, neg_mask == 1)

    def compute_sharded_cosine_similarity(self, tensor1, tensor2, shard_size):
        """Accumulate token-level cosine similarities shard by shard.

        Args:
            tensor1: tensor of shape ``(B, T, D)``.
            tensor2: tensor indexed with the same token slices as ``tensor1``.
            shard_size: number of tokens taken per shard along dimension 1.

        Returns:
            ``(B, B)`` matrix: the accumulated sums divided by ``T * T``.

        NOTE(review): only token pairs that fall in the same shard are
        accumulated, while the divisor is always ``T * T``. The result is a
        true average over all token pairs only when ``shard_size >= T`` --
        confirm whether cross-shard pairs were meant to be included.
        """
        B, T, D = tensor1.shape
        average_sim_matrix = torch.zeros((B, B), device=tensor1.device)

        for start_idx in range(0, T, shard_size):
            end_idx = min(start_idx + shard_size, T)

            shard_tensor1 = tensor1[:, start_idx:end_idx, :]
            shard_tensor2 = tensor2[:, start_idx:end_idx, :]

            # Broadcast to (B, B, s, s, D) so that each sample pair and each
            # token pair inside the shard gets one similarity value.
            shard_tensor1_expanded = shard_tensor1.unsqueeze(1).unsqueeze(3)
            shard_tensor2_expanded = shard_tensor2.unsqueeze(0).unsqueeze(2)
            shard_cos_sim = F.cosine_similarity(shard_tensor1_expanded,
                                                shard_tensor2_expanded, dim=-1)

            average_sim_matrix += torch.sum(shard_cos_sim, dim=[2, 3])

        average_sim_matrix /= (T * T)
        return average_sim_matrix

    def forward(self, x_contactless, x_contactbased, x_cl_tokens, x_cb_tokens, labels, device, domain_class_cl, domain_class_cb, domain_class_cl_gt, domain_class_cb_gt):
        """Return the summed MS loss for one batch.

        Only ``x_contactless``, ``x_contactbased`` and ``labels`` are used;
        the remaining parameters are kept so existing call sites keep working.
        """
        cl_unit = self.l2_norm(x_contactless)
        cb_unit = self.l2_norm(x_contactbased)

        sim_mat_clcl = F.linear(cl_unit, cl_unit)
        sim_mat_cbcb = F.linear(cb_unit, cb_unit)
        sim_mat_cbcl = F.linear(cb_unit, cl_unit)

        loss2 = self.ms_sample_cbcb_clcl(sim_mat_clcl, labels) + self.ms_sample_cbcb_clcl(sim_mat_clcl.t(), labels)
        loss3 = self.ms_sample_cbcb_clcl(sim_mat_cbcb, labels) + self.ms_sample_cbcb_clcl(sim_mat_cbcb.t(), labels)
        loss4 = self.ms_sample(sim_mat_cbcl, labels) + self.ms_sample(sim_mat_cbcl.t(), labels)

        return loss4 + loss2 + loss3

    def l2_norm(self, input):
        """Divide each row of a 2-D tensor by ``sqrt(sum(row ** 2) + 1e-12)``."""
        input_size = input.size()
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        output = _output.view(input_size)
        return output
|
model.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torchvision import datasets, transforms
|
| 8 |
+
from torch.optim.lr_scheduler import StepLR
|
| 9 |
+
import torchvision.models as models
|
| 10 |
+
import timm
|
| 11 |
+
from pprint import pprint
|
| 12 |
+
import numpy as np
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from torch.utils.data.sampler import BatchSampler
|
| 15 |
+
from gradient_reversal.module import GradientReversal
|
| 16 |
+
|
| 17 |
+
class SwinModel(nn.Module):
    """ViT-L/16 backbone with an MLP projection head for both modalities.

    ``swin_cb`` is the same module object as ``swin_cl``, and ``linear_cl``
    projects the pooled features of both modalities.
    """

    def __init__(self):
        super().__init__()
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        # Alias, not a copy: both names refer to one set of weights.
        self.swin_cb = self.swin_cl

        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))
        # Not used by any method of this class; kept so the parameter names
        # of the module do not change.
        self.linear_cb = nn.Linear(1024, 1024)

    def _set_encoder_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        for encoder in (self.swin_cl, self.swin_cb):
            for param in encoder.parameters():
                param.requires_grad = trainable

    def freeze_encoder(self):
        """Stop gradient updates to the backbone."""
        self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """Re-enable gradient updates to the backbone."""
        self._set_encoder_trainable(True)

    def get_embeddings(self, image, ftype):
        """Embed one batch of images.

        Args:
            image: input batch passed straight to the backbone.
            ftype: ``"contactless"`` selects ``swin_cl``; any other value
                selects ``swin_cb``.

        Returns:
            ``(feat, tokens)``: the projected mean over dimension 1 of the
            backbone output, and the raw backbone output.
        """
        # assumes the backbone returns per-token features (B, T, 1024) -- TODO confirm
        swin = self.swin_cl if ftype == "contactless" else self.swin_cb
        tokens = swin(image)
        # The projection is applied to the pooled vector only; projecting
        # every token as well produced a value that was never used.
        feat = self.linear_cl(tokens.mean(dim=1))
        return feat, tokens

    def forward(self, x_cl, x_cb):
        """Return ``(feat_cl, feat_cb, tokens_cl, tokens_cb)`` for a batch pair."""
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cb(x_cb)

        x_cl = self.linear_cl(x_cl_tokens.mean(dim=1))
        x_cb = self.linear_cl(x_cb_tokens.mean(dim=1))

        return x_cl, x_cb, x_cl_tokens, x_cb_tokens
|
| 64 |
+
|
| 65 |
+
class SwinModel_domain_agnostic(nn.Module):
    """Shared backbone and projection head plus a domain classifier.

    ``swin_cb`` is the same module object as ``swin_cl`` and ``linear_cl``
    projects both modalities. ``classify`` sits behind a gradient-reversal
    layer and maps the pooled backbone features to 8 logits.
    """

    def __init__(self):
        super().__init__()
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        # Alias, not a copy: both names refer to one set of weights.
        self.swin_cb = self.swin_cl  # timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)

        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))
        # Not used by any method of this class; kept so the parameter names
        # of the module do not change.
        self.linear_cb = nn.Linear(1024, 1024)
        self.classify = nn.Sequential(GradientReversal(alpha=0.6),  # original 0.8
                                      nn.Linear(1024, 512),
                                      nn.ReLU(),
                                      nn.Linear(512, 8))

    def _set_encoder_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        for encoder in (self.swin_cl, self.swin_cb):
            for param in encoder.parameters():
                param.requires_grad = trainable

    def freeze_encoder(self):
        """Stop gradient updates to the backbone."""
        self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """Re-enable gradient updates to the backbone."""
        self._set_encoder_trainable(True)

    def get_embeddings(self, image, ftype):
        """Embed one batch of images.

        Args:
            image: input batch passed straight to the backbone.
            ftype: ``"contactless"`` selects ``swin_cl``; any other value
                selects ``swin_cb``.

        Returns:
            ``(feat, tokens)``: the projected mean over dimension 1 of the
            backbone output, and the raw backbone output.
        """
        # assumes the backbone returns per-token features (B, T, 1024) -- TODO confirm
        swin = self.swin_cl if ftype == "contactless" else self.swin_cb
        tokens = swin(image)
        # The projection is applied to the pooled vector only; projecting
        # every token as well produced a value that was never used.
        feat = self.linear_cl(tokens.mean(dim=1))
        return feat, tokens

    def forward(self, x_cl, x_cb):
        """Run both inputs through the backbone, projection and classifier.

        Returns:
            ``(feat_cl, feat_cb, tokens_cl, tokens_cb, domain_logits_cl,
            domain_logits_cb)``. The domain logits are computed from the
            pooled backbone features, before the projection head.
        """
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cb(x_cb)

        x_cl_mean = x_cl_tokens.mean(dim=1)
        x_cb_mean = x_cb_tokens.mean(dim=1)

        x_cl = self.linear_cl(x_cl_mean)
        x_cb = self.linear_cl(x_cb_mean)

        domain_class_cl = self.classify(x_cl_mean)
        domain_class_cb = self.classify(x_cb_mean)

        return x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl, domain_class_cb
|
| 119 |
+
|
| 120 |
+
class SwinModel_Fusion(nn.Module):
    """ViT-L/16 backbone followed by a transformer that fuses a token pair.

    Two token sequences are concatenated around a learned separator token,
    passed through a two-layer transformer encoder, and the two halves are
    mean-pooled and compared.
    """

    def __init__(self):
        super().__init__()
        self.feature_dim = 1024
        self.swin_cl = timm.create_model('vit_large_patch16_224_in21k', pretrained=True, num_classes=0)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.feature_dim, nhead=4, dropout=0.5, batch_first=True, norm_first=True, activation="gelu")
        self.fusion = nn.TransformerEncoder(self.encoder_layer, num_layers=2)
        self.sep_token = nn.Parameter(torch.randn(1, 1, self.feature_dim))
        self.output_logit_mlp = nn.Sequential(nn.Linear(1024, 512),
                                              nn.ReLU(),
                                              nn.Dropout(),
                                              nn.Linear(512, 1))
        self.linear_cl = nn.Sequential(nn.Linear(1024, 1024),
                                       nn.ReLU(),
                                       nn.Linear(1024, 1024))

    @staticmethod
    def _select_submodule_state(state_dict, name):
        """Collect the entries whose key contains ``name``.

        The ``"<name>."`` substring is removed from each selected key so the
        result can be fed to the sub-module's own ``load_state_dict``.
        """
        return {key.replace(name + ".", ""): value
                for key, value in state_dict.items() if name in key}

    def load_pretrained_models(self, swin_cl_path, fusion_ckpt_path):
        """Load backbone and fusion weights from two checkpoint files.

        Args:
            swin_cl_path: checkpoint whose ``swin_cl`` entries initialise the
                backbone.
            fusion_ckpt_path: checkpoint providing ``encoder_layer``,
                ``fusion``, ``sep_token`` and ``output_logit_mlp`` entries.

        Checkpoints are read onto the CPU so that files saved on a GPU also
        load on hosts without one; ``load_state_dict`` then copies the values
        into the existing parameters on whatever device they live on.
        """
        swin_cl_state_dict = torch.load(swin_cl_path, map_location="cpu")
        self.swin_cl.load_state_dict(self._select_submodule_state(swin_cl_state_dict, "swin_cl"))

        fusion_params = torch.load(fusion_ckpt_path, map_location="cpu")
        self.encoder_layer.load_state_dict(self._select_submodule_state(fusion_params, "encoder_layer"))
        self.fusion.load_state_dict(self._select_submodule_state(fusion_params, "fusion"))

        # Keep the separator token on the device the model is already on.
        self.sep_token = nn.Parameter(fusion_params["sep_token"].to(self.sep_token.device))

        self.output_logit_mlp.load_state_dict(self._select_submodule_state(fusion_params, "output_logit_mlp"))

    def l2_norm(self, input):
        """Divide each row of a 2-D tensor by ``sqrt(sum(row ** 2) + 1e-12)``."""
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return _output

    def combine_features(self, fingerprint_1_tokens, fingerprint_2_tokens):
        """Fuse two token sequences and score the pair.

        Args:
            fingerprint_1_tokens: tensor of shape ``[B, T1, feature_dim]``.
            fingerprint_2_tokens: tensor of shape ``[B, T2, feature_dim]``.

        Returns:
            ``(similarities, distances)``, each of shape ``[B]``: the cosine
            similarity and the Euclidean distance between the mean-pooled
            fused representations of the two sequences.
        """
        # fingerprint_1_tokens = self.linear_cl(fingerprint_1_tokens)
        # fingerprint_2_tokens = self.linear_cl(fingerprint_2_tokens)
        batch_size = fingerprint_1_tokens.shape[0]
        # The split point follows the actual length of the first sequence
        # (197 for the ViT backbone used here) instead of a fixed constant.
        first_len = fingerprint_1_tokens.shape[1]
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        combine_features = torch.cat((fingerprint_1_tokens, sep_token, fingerprint_2_tokens), dim=1)
        fused_match_representation = self.fusion(combine_features)
        # Position ``first_len`` holds the separator and is skipped.
        fingerprint_1 = fused_match_representation[:, :first_len, :].mean(dim=1)
        fingerprint_2 = fused_match_representation[:, first_len + 1:, :].mean(dim=1)

        fingerprint_1_norm = self.l2_norm(fingerprint_1)
        fingerprint_2_norm = self.l2_norm(fingerprint_2)

        similarities = torch.sum(fingerprint_1_norm * fingerprint_2_norm, axis=1)

        differences = fingerprint_1 - fingerprint_2
        squared_differences = differences ** 2
        sum_squared_differences = torch.sum(squared_differences, axis=1)
        distances = torch.sqrt(sum_squared_differences)
        return similarities, distances

    def get_tokens(self, image, ftype):
        """Return the backbone output for ``image``; ``ftype`` is not used."""
        swin = self.swin_cl
        tokens = swin(image)
        return tokens

    def freeze_backbone(self):
        """Stop gradient updates to the backbone."""
        for param in self.swin_cl.parameters():
            param.requires_grad = False

    def forward(self, x_cl, x_cb):
        """Return the backbone outputs for both inputs."""
        x_cl_tokens = self.swin_cl(x_cl)
        x_cb_tokens = self.swin_cl(x_cb)
        return x_cl_tokens, x_cb_tokens
|
rb_evaluation_phase1.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from datasets.rb_loader import RB_loader
|
| 3 |
+
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from model import SwinModel_domain_agnostic as Model
|
| 7 |
+
from sklearn.metrics import roc_curve, auc
|
| 8 |
+
import json
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
def _fpr_bracket(fpr, far):
    """Return ``(lower, upper)`` ROC indices that bracket the target ``far``.

    ``lower`` is the last index whose false-positive rate is below ``far``;
    ``upper`` is the first index whose rate is at or above it.
    """
    lower = max(i for i, val in enumerate(fpr) if val < far)
    upper = min(i for i, val in enumerate(fpr) if val >= far)
    return lower, upper


if __name__ == '__main__':
    # Fall back to the CPU when no GPU is present instead of failing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = RB_loader(split='test')
    dataloader = torch.utils.data.DataLoader(data, batch_size=16, num_workers=1, pin_memory=True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase1_scratch.pt", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint, strict=False)

    model.eval()
    cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
    cl_feats_unnormed, cb_feats_unnormed = [], []
    print("Computing Test Recall")
    with torch.no_grad():
        for (x_cl, x_cb, target, cl_fname, cb_fname) in tqdm(dataloader):
            x_cl, x_cb = x_cl.to(device), x_cb.to(device)
            x_cl, _ = model.get_embeddings(x_cl, ftype="contactless")
            x_cb, _ = model.get_embeddings(x_cb, ftype="contactbased")
            cl_feats_unnormed.append(x_cl.cpu().detach().numpy())
            cb_feats_unnormed.append(x_cb.cpu().detach().numpy())
            cl_feats.append(l2_norm(x_cl).cpu().detach().numpy())
            cb_feats.append(l2_norm(x_cb).cpu().detach().numpy())
            target = target.cpu().detach().numpy()
            cl_labels.append(target)
            cb_labels.append(target)

    cl_feats = torch.from_numpy(np.concatenate(cl_feats))
    cb_feats = torch.from_numpy(np.concatenate(cb_feats))
    cl_labels = torch.from_numpy(np.concatenate(cl_labels))
    cb_labels = torch.from_numpy(np.concatenate(cb_labels))
    cl_feats_unnormed = torch.from_numpy(np.concatenate(cl_feats_unnormed))
    cb_feats_unnormed = torch.from_numpy(np.concatenate(cb_feats_unnormed))

    # Collapse the contact-based side to one entry per label by averaging
    # all of its features (normalised and raw alike).
    unique_labels, indices = torch.unique(cb_labels, return_inverse=True)
    label_slots = range(len(unique_labels))
    cb_feats = torch.stack([cb_feats[indices == i].mean(dim=0) for i in label_slots])
    cb_feats_unnormed = torch.stack([cb_feats_unnormed[indices == i].mean(dim=0) for i in label_slots])
    cb_labels = unique_labels

    # Contactless rows against contact-based columns
    # <---------------------------------------->
    cl_feats = cl_feats.numpy()
    cb_feats = cb_feats.numpy()
    cb_feats_unnormed = cb_feats_unnormed.numpy()
    cl_feats_unnormed = cl_feats_unnormed.numpy()

    # NOTE(review): the broadcast below materialises an
    # (n_cl, n_cb, feature_dim) array; watch memory on large test sets.
    squared_diff = np.sum(np.square(cl_feats_unnormed[:, np.newaxis] - cb_feats_unnormed), axis=2)
    distance = -1 * np.sqrt(squared_diff)
    similarities = np.dot(cl_feats, np.transpose(cb_feats))
    # Final score: cosine similarity minus 0.1 x Euclidean distance.
    scores_mat = similarities + 0.1 * distance
    scores = scores_mat.flatten().tolist()

    ids = torch.eq(cl_labels.view(-1, 1) - cb_labels.view(1, -1), 0.0).flatten().tolist()
    ids_mod = [1 if is_match else 0 for is_match in ids]
    fpr, tpr, thresh = roc_curve(ids_mod, scores, drop_intermediate=True)

    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.01)
    # At FAR=1e-2 the upper bracket is reported as is (no averaging).
    tar_far_102 = tpr[upper_fpr_idx]
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx], thresh[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx], thresh[upper_fpr_idx])

    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    print(tpr[lower_fpr_idx], lower_fpr_idx, fpr[lower_fpr_idx])
    print(tpr[upper_fpr_idx], upper_fpr_idx, fpr[upper_fpr_idx])

    fnr = 1 - tpr
    # EER is read at the ROC point where FNR and FPR are closest.
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    print(f"R@1 for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_labels, cb_labels, 1) * 100} %")
    print(f"R@10 for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_labels, cb_labels, 10) * 100} %")
    print(f"R@50 for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_labels, cb_labels, 50) * 100} %")
    print(f"R@100 for CB2CL: {compute_recall_at_k(torch.from_numpy(scores_mat), cl_labels, cb_labels, 100) * 100} %")

    ################################################################################

    # CL2CL
    scores = torch.from_numpy(np.dot(cl_feats, np.transpose(cl_feats)))
    # Only the strict upper triangle is scored: each unordered pair once,
    # self-pairs excluded.
    row, col = torch.triu_indices(row=scores.size(0), col=scores.size(1), offset=1)
    scores = scores[row, col].numpy().flatten().tolist()
    # Built on the CPU: the mask is only converted to a Python list.
    labels = torch.eq(cl_labels.view(-1, 1) - cl_labels.view(1, -1), 0.0).float()
    labels = labels[row, col].tolist()
    fpr, tpr, _ = roc_curve(labels, scores)

    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.01)
    tar_far_102 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.001)
    tar_far_103 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2
    lower_fpr_idx, upper_fpr_idx = _fpr_bracket(fpr, 0.0001)
    tar_far_104 = (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2

    fnr = 1 - tpr
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for CL2CL: {roc_auc * 100} %")
    print(f"EER for CL2CL: {EER * 100} %")
    print(f"TAR@FAR=10^-2 for CL2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CL2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CL2CL: {tar_far_104 * 100} %")
    cl_labels = cl_labels.cpu().detach().numpy()
    recall_score = Prev_RetMetric([cl_feats, cl_feats], [cl_labels, cl_labels], cl2cl=True)
    print(f"R@1 for CL2CL: {recall_score.recall_k(k=1) * 100} %")
    print(f"R@10 for CL2CL: {recall_score.recall_k(k=10) * 100} %")
    print(f"R@50 for CL2CL: {recall_score.recall_k(k=50) * 100} %")
    print(f"R@100 for CL2CL: {recall_score.recall_k(k=100) * 100} %")
|
rb_evaluation_phase2.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from datasets.rb_loader_cl import RB_loader_cl
|
| 3 |
+
from datasets.rb_loader_cb import RB_loader_cb
|
| 4 |
+
from utils import Prev_RetMetric, l2_norm, compute_recall_at_k
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from model import SwinModel_Fusion as Model
|
| 8 |
+
from sklearn.metrics import roc_curve, auc
|
| 9 |
+
import json
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens):
    """Score every (row, column) pair of two token collections.

    Args:
        model: object exposing ``combine_features(tokens_a, tokens_b)`` that
            returns ``(similarities, distances)`` for a batch of pairs.
        cl_tokens: list of token tensors, concatenated along dimension 0 to
            form the rows.
        cb_tokens: list of token tensors, concatenated along dimension 0 to
            form the columns.

    Returns:
        CPU tensor of shape ``(n_rows, n_cols)`` holding
        ``similarity - 0.1 * distance`` for each pair.
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)

    batch_size_cl = cl_tokens.shape[0]
    batch_size_cb = cb_tokens.shape[0]
    # Pairs are scored in tiles of at most shard_size x shard_size.
    shard_size = 20
    similarity_matrix = torch.zeros((batch_size_cl, batch_size_cb))

    # The scores are detached right after they are computed, so no autograd
    # graph needs to be recorded for the fusion forward passes.
    with torch.no_grad():
        for i_start in tqdm(range(0, batch_size_cl, shard_size)):
            i_end = min(i_start + shard_size, batch_size_cl)
            shard_i = cl_tokens[i_start:i_end]
            for j_start in range(0, batch_size_cb, shard_size):
                j_end = min(j_start + shard_size, batch_size_cb)
                shard_j = cb_tokens[j_start:j_end]

                # Broadcast both shards to (rows, cols, tokens, dim) so that
                # every row is paired with every column.
                pairwise_i = shard_i.unsqueeze(1).expand(-1, shard_j.shape[0], -1, -1)
                pairwise_j = shard_j.unsqueeze(0).expand(shard_i.shape[0], -1, -1, -1)

                # The token count is read from the data (197 for the ViT
                # backbone) rather than being hard-coded.
                similarity_scores, distances = model.combine_features(
                    pairwise_i.reshape(-1, shard_i.shape[1], shard_i.shape[-1]),
                    pairwise_j.reshape(-1, shard_j.shape[1], shard_j.shape[-1])
                )
                scores = similarity_scores - 0.1 * distances  # -0.1
                scores = scores.reshape(shard_i.shape[0], shard_j.shape[0])
                similarity_matrix[i_start:i_end, j_start:j_end] = scores.cpu().detach()
    return similarity_matrix
|
| 40 |
+
|
| 41 |
+
def _tar_at_far(fpr, tpr, far, average=True):
    """Read the true-accept rate at a target false-accept rate off a ROC curve.

    Args:
        fpr: false-positive rates from ``roc_curve``.
        tpr: matching true-positive rates.
        far: target false-accept rate.
        average: when true, return the mean of the TPRs at the two ROC points
            bracketing ``far``; otherwise return the TPR at the first point
            whose FPR is at or above ``far``.
    """
    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < far)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= far)
    if not average:
        return tpr[upper_fpr_idx]
    return (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2


def _collect_tokens(model, dataloader, device, ftype):
    """Run ``model.get_tokens`` over a loader; return token and label lists."""
    tokens, labels = [], []
    with torch.no_grad():
        for (images, target) in tqdm(dataloader):
            tokens.append(model.get_tokens(images.to(device), ftype))
            labels.append(target.cpu().detach().numpy())
    return tokens, labels


def _report(tag, scores, labels, scores_mat, row_label, col_label):
    """Print ROC-AUC, EER, TAR@FAR and recall@k lines for one protocol."""
    fpr, tpr, thresh = roc_curve(labels, scores, drop_intermediate=True)

    # At FAR=1e-2 the upper bracket is reported as is (no averaging).
    tar_far_102 = _tar_at_far(fpr, tpr, 0.01, average=False)
    tar_far_103 = _tar_at_far(fpr, tpr, 0.001)
    tar_far_104 = _tar_at_far(fpr, tpr, 0.0001)

    fnr = 1 - tpr
    # EER is read at the ROC point where FNR and FPR are closest.
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)
    print(f"ROCAUC for {tag}: {roc_auc * 100} %")
    print(f"EER for {tag}: {EER * 100} %")
    print(f"TAR@FAR=10^-2 for {tag}: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for {tag}: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for {tag}: {tar_far_104 * 100} %")

    for k in (1, 10, 50, 100):
        print(f"R@{k} for {tag}: {compute_recall_at_k(scores_mat, row_label, col_label, k) * 100} %")


def main():
    """Evaluate the phase-2 fusion model on the RB test split."""
    # Fall back to the CPU when no GPU is present instead of failing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_cl = RB_loader_cl(split="test")
    data_cb = RB_loader_cb(split="test")
    dataloader_cb = torch.utils.data.DataLoader(data_cb, batch_size=16, num_workers=1, pin_memory=True)
    dataloader_cl = torch.utils.data.DataLoader(data_cl, batch_size=16, num_workers=1, pin_memory=True)
    model = Model().to(device)
    checkpoint = torch.load("ridgeformer_checkpoints/phase2_scratch.pt", map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint, strict=False)

    model.eval()
    print("Computing Test Recall")

    cb_feats, cb_labels = _collect_tokens(model, dataloader_cb, device, 'contactbased')
    cl_feats, cl_labels = _collect_tokens(model, dataloader_cl, device, 'contactless')

    cl_label = torch.from_numpy(np.concatenate(cl_labels))
    cb_label = torch.from_numpy(np.concatenate(cb_labels))

    # CB2CL <---------------------------------------->
    scores_mat = get_fused_cross_score_matrix(model, cl_feats, cb_feats)
    scores = scores_mat.cpu().detach().numpy().flatten().tolist()
    labels = torch.eq(cl_label.view(-1, 1) - cb_label.view(1, -1), 0.0).flatten().tolist()
    _report("CB2CL", scores, labels, scores_mat, cl_label, cb_label)

    # CL2CL -------------------------
    scores_mat = get_fused_cross_score_matrix(model, cl_feats, cl_feats)
    # Only the strict upper triangle is scored: each unordered pair once,
    # self-pairs excluded.
    row, col = torch.triu_indices(row=scores_mat.size(0), col=scores_mat.size(1), offset=1)
    scores = scores_mat[row, col].cpu().detach().numpy().flatten().tolist()
    # Built on the CPU: the mask is only converted to a Python list.
    labels = torch.eq(cl_label.view(-1, 1) - cl_label.view(1, -1), 0.0).float()
    labels = labels[row, col].flatten().tolist()
    _report("CL2CL", scores, labels, scores_mat, cl_label, cl_label)


if __name__ == '__main__':
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file may be used to create an environment using:
|
| 2 |
+
# $ conda create --name <env> --file <this file>
|
| 3 |
+
# platform: linux-64
|
| 4 |
+
_libgcc_mutex=0.1=main
|
| 5 |
+
_openmp_mutex=5.1=1_gnu
|
| 6 |
+
absl-py=2.1.0=pypi_0
|
| 7 |
+
addict=2.4.0=pypi_0
|
| 8 |
+
aliyun-python-sdk-core=2.15.0=pypi_0
|
| 9 |
+
aliyun-python-sdk-kms=2.16.2=pypi_0
|
| 10 |
+
attrs=23.2.0=pypi_0
|
| 11 |
+
blas=1.0=mkl
|
| 12 |
+
bzip2=1.0.8=h5eee18b_6
|
| 13 |
+
ca-certificates=2024.7.2=h06a4308_0
|
| 14 |
+
cachetools=5.4.0=pypi_0
|
| 15 |
+
certifi=2024.2.2=pypi_0
|
| 16 |
+
cffi=1.16.0=pypi_0
|
| 17 |
+
charset-normalizer=2.1.1=pypi_0
|
| 18 |
+
click=8.1.7=pypi_0
|
| 19 |
+
colorama=0.4.6=pypi_0
|
| 20 |
+
coloredlogs=15.0.1=pypi_0
|
| 21 |
+
contourpy=1.1.1=pypi_0
|
| 22 |
+
crcmod=1.7=pypi_0
|
| 23 |
+
cryptography=42.0.5=pypi_0
|
| 24 |
+
cuda-cudart=11.8.89=0
|
| 25 |
+
cuda-cudart_linux-64=12.4.127=hd681fbe_0
|
| 26 |
+
cuda-cupti=11.8.87=0
|
| 27 |
+
cuda-libraries=11.8.0=0
|
| 28 |
+
cuda-nvrtc=11.8.89=0
|
| 29 |
+
cuda-nvtx=11.8.86=0
|
| 30 |
+
cuda-opencl=12.4.127=h6a678d5_0
|
| 31 |
+
cuda-runtime=11.8.0=0
|
| 32 |
+
cuda-version=12.4=hbda6634_3
|
| 33 |
+
cycler=0.12.1=pypi_0
|
| 34 |
+
entrypoints=0.4=pypi_0
|
| 35 |
+
ffmpeg=4.3=hf484d3e_0
|
| 36 |
+
flatbuffers=24.3.25=pypi_0
|
| 37 |
+
fonttools=4.53.1=pypi_0
|
| 38 |
+
freetype=2.12.1=h4a9f257_0
|
| 39 |
+
fsspec=2024.3.1=pypi_0
|
| 40 |
+
gmp=6.2.1=h295c915_3
|
| 41 |
+
gnutls=3.6.15=he1e5248_0
|
| 42 |
+
google-auth=2.33.0=pypi_0
|
| 43 |
+
google-auth-oauthlib=1.0.0=pypi_0
|
| 44 |
+
grpcio=1.65.4=pypi_0
|
| 45 |
+
httpcore=1.0.5=pypi_0
|
| 46 |
+
httpx=0.27.0=pypi_0
|
| 47 |
+
huggingface-hub=0.22.1=pypi_0
|
| 48 |
+
humanfriendly=10.0=pypi_0
|
| 49 |
+
idna=3.6=pypi_0
|
| 50 |
+
imageio=2.34.2=pypi_0
|
| 51 |
+
importlib-metadata=7.1.0=pypi_0
|
| 52 |
+
importlib-resources=6.4.0=pypi_0
|
| 53 |
+
intel-openmp=2023.1.0=hdb19cb5_46306
|
| 54 |
+
jinja2=3.1.3=pypi_0
|
| 55 |
+
jmespath=0.10.0=pypi_0
|
| 56 |
+
joblib=1.4.2=pypi_0
|
| 57 |
+
jpeg=9e=h5eee18b_2
|
| 58 |
+
jsonschema=4.21.1=pypi_0
|
| 59 |
+
jsonschema-specifications=2023.12.1=pypi_0
|
| 60 |
+
kaleido=0.2.1=pypi_0
|
| 61 |
+
kiwisolver=1.4.5=pypi_0
|
| 62 |
+
lame=3.100=h7b6447c_0
|
| 63 |
+
lcms2=2.12=h3be6417_0
|
| 64 |
+
ld_impl_linux-64=2.38=h1181459_1
|
| 65 |
+
lerc=3.0=h295c915_0
|
| 66 |
+
libcublas=11.11.3.6=0
|
| 67 |
+
libcufft=10.9.0.58=0
|
| 68 |
+
libcufile=1.9.1.3=h99ab3db_1
|
| 69 |
+
libcurand=10.3.5.147=h99ab3db_1
|
| 70 |
+
libcusolver=11.4.1.48=0
|
| 71 |
+
libcusparse=11.7.5.86=0
|
| 72 |
+
libdeflate=1.17=h5eee18b_1
|
| 73 |
+
libffi=3.4.4=h6a678d5_1
|
| 74 |
+
libgcc-ng=11.2.0=h1234567_1
|
| 75 |
+
libgomp=11.2.0=h1234567_1
|
| 76 |
+
libiconv=1.16=h5eee18b_3
|
| 77 |
+
libidn2=2.3.4=h5eee18b_0
|
| 78 |
+
libjpeg-turbo=2.0.0=h9bf148f_0
|
| 79 |
+
libnpp=11.8.0.86=0
|
| 80 |
+
libnvfatbin=12.4.127=h7934f7d_2
|
| 81 |
+
libnvjitlink=12.4.99=0
|
| 82 |
+
libnvjpeg=11.9.0.86=0
|
| 83 |
+
libpng=1.6.39=h5eee18b_0
|
| 84 |
+
libstdcxx-ng=11.2.0=h1234567_1
|
| 85 |
+
libtasn1=4.19.0=h5eee18b_0
|
| 86 |
+
libtiff=4.5.1=h6a678d5_0
|
| 87 |
+
libunistring=0.9.10=h27cfd23_0
|
| 88 |
+
libwebp-base=1.3.2=h5eee18b_0
|
| 89 |
+
llvm-openmp=14.0.6=h9e868ea_0
|
| 90 |
+
llvmlite=0.41.1=pypi_0
|
| 91 |
+
lz4-c=1.9.4=h6a678d5_1
|
| 92 |
+
markdown=3.6=pypi_0
|
| 93 |
+
markdown-it-py=3.0.0=pypi_0
|
| 94 |
+
markupsafe=2.1.5=pypi_0
|
| 95 |
+
matplotlib=3.7.5=pypi_0
|
| 96 |
+
mdit-py-plugins=0.4.0=pypi_0
|
| 97 |
+
mkl=2023.1.0=h213fc3f_46344
|
| 98 |
+
mmcv=2.1.0=dev_0
|
| 99 |
+
mmdet=3.3.0=dev_0
|
| 100 |
+
mmengine=0.10.3=pypi_0
|
| 101 |
+
model-index=0.1.11=pypi_0
|
| 102 |
+
mpc=1.1.0=h10f8cd9_1
|
| 103 |
+
mpfr=4.0.2=hb69a4c5_1
|
| 104 |
+
mpmath=1.3.0=py38h06a4308_0
|
| 105 |
+
ncurses=6.4=h6a678d5_0
|
| 106 |
+
nettle=3.7.3=hbbd107a_1
|
| 107 |
+
networkx=3.1=py38h06a4308_0
|
| 108 |
+
numba=0.58.1=pypi_0
|
| 109 |
+
numpy=1.24.4=pypi_0
|
| 110 |
+
nvidia-cublas-cu11=11.11.3.6=pypi_0
|
| 111 |
+
nvidia-cuda-cupti-cu11=11.8.87=pypi_0
|
| 112 |
+
nvidia-cuda-nvrtc-cu11=11.8.89=pypi_0
|
| 113 |
+
nvidia-cuda-runtime-cu11=11.8.89=pypi_0
|
| 114 |
+
nvidia-cudnn-cu11=8.7.0.84=pypi_0
|
| 115 |
+
nvidia-cufft-cu11=10.9.0.58=pypi_0
|
| 116 |
+
nvidia-curand-cu11=10.3.0.86=pypi_0
|
| 117 |
+
nvidia-cusolver-cu11=11.4.1.48=pypi_0
|
| 118 |
+
nvidia-cusparse-cu11=11.7.5.86=pypi_0
|
| 119 |
+
nvidia-nccl-cu11=2.19.3=pypi_0
|
| 120 |
+
nvidia-nvtx-cu11=11.8.86=pypi_0
|
| 121 |
+
oauthlib=3.2.2=pypi_0
|
| 122 |
+
ocl-icd=2.3.2=h5eee18b_1
|
| 123 |
+
onnxruntime=1.18.1=pypi_0
|
| 124 |
+
opencv-python=4.10.0.84=pypi_0
|
| 125 |
+
opencv-python-headless=4.10.0.84=pypi_0
|
| 126 |
+
opendatalab=0.0.10=pypi_0
|
| 127 |
+
openh264=2.1.1=h4ff587b_0
|
| 128 |
+
openjpeg=2.4.0=h9ca470c_2
|
| 129 |
+
openmim=0.3.9=pypi_0
|
| 130 |
+
openssl=3.0.14=h5eee18b_0
|
| 131 |
+
openxlab=0.0.37=pypi_0
|
| 132 |
+
ordered-set=4.1.0=pypi_0
|
| 133 |
+
oss2=2.17.0=pypi_0
|
| 134 |
+
packaging=24.0=pypi_0
|
| 135 |
+
pandas=2.0.3=pypi_0
|
| 136 |
+
pillow=9.0.1=pypi_0
|
| 137 |
+
pip=23.3.1=pypi_0
|
| 138 |
+
pkgutil-resolve-name=1.3.10=pypi_0
|
| 139 |
+
platformdirs=4.2.0=pypi_0
|
| 140 |
+
plotly=5.23.0=pypi_0
|
| 141 |
+
pooch=1.8.2=pypi_0
|
| 142 |
+
protobuf=5.27.3=pypi_0
|
| 143 |
+
pyasn1=0.6.0=pypi_0
|
| 144 |
+
pyasn1-modules=0.4.0=pypi_0
|
| 145 |
+
pycocotools=2.0.7=pypi_0
|
| 146 |
+
pycparser=2.21=pypi_0
|
| 147 |
+
pygments=2.17.2=pypi_0
|
| 148 |
+
pymatting=1.1.12=pypi_0
|
| 149 |
+
pyparsing=3.1.2=pypi_0
|
| 150 |
+
python=3.8.19=h955ad1f_0
|
| 151 |
+
python-dateutil=2.9.0.post0=pypi_0
|
| 152 |
+
pytorch-cuda=11.8=h7e8668a_5
|
| 153 |
+
pytorch-metric-learning=2.5.0=pypi_0
|
| 154 |
+
pytorch-mutex=1.0=cuda
|
| 155 |
+
pytz=2023.4=pypi_0
|
| 156 |
+
pywavelets=1.4.1=pypi_0
|
| 157 |
+
pyyaml=6.0.1=py38h5eee18b_0
|
| 158 |
+
readline=8.2=h5eee18b_0
|
| 159 |
+
referencing=0.34.0=pypi_0
|
| 160 |
+
rembg=2.0.58=pypi_0
|
| 161 |
+
requests=2.28.2=pypi_0
|
| 162 |
+
requests-oauthlib=2.0.0=pypi_0
|
| 163 |
+
rich=13.4.2=pypi_0
|
| 164 |
+
rpds-py=0.18.0=pypi_0
|
| 165 |
+
rsa=4.9=pypi_0
|
| 166 |
+
safetensors=0.4.2=pypi_0
|
| 167 |
+
scikit-image=0.19.3=pypi_0
|
| 168 |
+
scikit-learn=1.3.2=pypi_0
|
| 169 |
+
scipy=1.10.1=pypi_0
|
| 170 |
+
setuptools=60.2.0=pypi_0
|
| 171 |
+
shapely=2.0.3=pypi_0
|
| 172 |
+
six=1.16.0=pypi_0
|
| 173 |
+
sqlite=3.45.3=h5eee18b_0
|
| 174 |
+
sympy=1.12=py38h06a4308_0
|
| 175 |
+
tabulate=0.9.0=pypi_0
|
| 176 |
+
tbb=2021.8.0=hdb19cb5_0
|
| 177 |
+
tenacity=9.0.0=pypi_0
|
| 178 |
+
tensorboard=2.14.0=pypi_0
|
| 179 |
+
tensorboard-data-server=0.7.2=pypi_0
|
| 180 |
+
termcolor=2.4.0=pypi_0
|
| 181 |
+
terminaltables=3.1.10=pypi_0
|
| 182 |
+
threadpoolctl=3.5.0=pypi_0
|
| 183 |
+
tifffile=2023.7.10=pypi_0
|
| 184 |
+
timm=0.5.0=dev_0
|
| 185 |
+
tk=8.6.14=h39e8969_0
|
| 186 |
+
tomli=2.0.1=pypi_0
|
| 187 |
+
torch=2.2.2+cu118=pypi_0
|
| 188 |
+
torchaudio=2.2.2+cu118=pypi_0
|
| 189 |
+
torchvision=0.17.2+cu118=pypi_0
|
| 190 |
+
tqdm=4.66.5=pypi_0
|
| 191 |
+
triton=2.2.0=pypi_0
|
| 192 |
+
typing-extensions=4.10.0=pypi_0
|
| 193 |
+
tzdata=2024.1=pypi_0
|
| 194 |
+
urllib3=1.26.18=pypi_0
|
| 195 |
+
werkzeug=3.0.3=pypi_0
|
| 196 |
+
wheel=0.41.2=pypi_0
|
| 197 |
+
xz=5.4.6=h5eee18b_1
|
| 198 |
+
yaml=0.2.5=h7b6447c_0
|
| 199 |
+
yapf=0.40.2=pypi_0
|
| 200 |
+
zipp=3.18.1=pypi_0
|
| 201 |
+
zlib=1.2.13=h5eee18b_1
|
| 202 |
+
zstd=1.5.5=hc292b87_2
|
train_combined.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torchvision import datasets, transforms
|
| 9 |
+
from torch.optim.lr_scheduler import StepLR, MultiStepLR
|
| 10 |
+
from datasets.hkpoly_test import hktest
|
| 11 |
+
from datasets.original_combined_train import Combined_original
|
| 12 |
+
from datasets.rb_loader import RB_loader
|
| 13 |
+
from loss import DualMSLoss_FineGrained_domain_agnostic_ft, DualMSLoss_FineGrained, DualMSLoss_FineGrained_domain_agnostic
|
| 14 |
+
import timm
|
| 15 |
+
from utils import Prev_RetMetric, RetMetric, compute_recall_at_k, l2_norm, compute_sharded_cosine_similarity, count_parameters
|
| 16 |
+
from pprint import pprint
|
| 17 |
+
import numpy as np
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from combined_sampler import BalancedSampler
|
| 20 |
+
from torch.utils.data.sampler import BatchSampler
|
| 21 |
+
from torch.nn.parallel import DataParallel
|
| 22 |
+
from model import SwinModel_domain_agnostic as Model
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
from sklearn.metrics import roc_curve, auc
|
| 25 |
+
import json
|
| 26 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 27 |
+
|
| 28 |
+
def train(args, model, device, train_loader, test_loader, optimizers, epoch, loss_func, pl_arg, stepping, log_writer):
    """Run one training epoch and return the mean batch loss.

    Args:
        args: Parsed CLI namespace; ``log_interval`` and ``dry_run`` are read.
        model: Module called as ``model(x_cl, x_cb)``; it must return the six
            values that are forwarded to ``loss_func``.
        device: Device every batch tensor is moved to.
        train_loader: Iterable yielding
            ``(x_cl, x_cb, target, category_cl, category_cb)`` batches.
        test_loader: Unused here; kept for signature compatibility.
        optimizers: Iterable of optimizers, all zeroed and stepped per batch.
        epoch: Unused here; kept for signature compatibility.
        loss_func: Callable producing a scalar loss tensor.
        pl_arg: Unused here; kept for signature compatibility.
        stepping: Step counter, returned unchanged.
        log_writer: Unused here; kept for signature compatibility.

    Returns:
        ``(mean_loss, stepping)`` where ``mean_loss`` is a Python float
        (``nan`` when no batch was processed).
    """
    model.train()
    step_losses = []
    for batch_idx, (x_cl, x_cb, target, category_cl, category_cb) in enumerate(pbar := tqdm(train_loader)):
        x_cl = x_cl.to(device)
        x_cb = x_cb.to(device)
        target = target.to(device)
        category_cl = category_cl.to(device)
        category_cb = category_cb.to(device)

        for optimizer in optimizers:
            optimizer.zero_grad()

        x_cl, x_cb, x_cl_tokens, x_cb_tokens, domain_class_cl, domain_class_cb = model(x_cl, x_cb)
        loss = loss_func(x_cl, x_cb, x_cl_tokens, x_cb_tokens, target, device,
                         domain_class_cl, domain_class_cb, category_cl, category_cb)
        loss.backward()

        for optimizer in optimizers:
            optimizer.step()

        # Store a plain float: keeping the tensor would keep the whole
        # autograd graph of every batch alive until the end of the epoch.
        # Recording before the dry-run check guarantees at least one entry.
        step_losses.append(loss.item())

        if batch_idx % args.log_interval == 0:
            if args.dry_run:
                break
        pbar.set_description(f"Loss {loss}")

    mean_loss = sum(step_losses) / len(step_losses) if step_losses else float("nan")
    return mean_loss, stepping
|
| 46 |
+
|
| 47 |
+
def l2_norm(input):
    """Scale every row of ``input`` to unit Euclidean length.

    The squared entries are summed along dimension 1 and a small epsilon
    (1e-12) is added before the square root so an all-zero row does not
    divide by zero. The result has the same shape as ``input``.

    Note: this local definition shadows the ``l2_norm`` imported from
    ``utils`` at the top of the file.
    """
    original_shape = input.size()
    squared_sum = input.pow(2).sum(1)
    squared_sum.add_(1e-12)
    row_norm = squared_sum.sqrt()
    scaled = input / row_norm.view(-1, 1).expand_as(input)
    return scaled.view(original_shape)
|
| 55 |
+
|
| 56 |
+
def _tar_at_far(fpr, tpr, far):
    """Return the TAR at ``far``, averaged over the two ROC points bracketing it."""
    lower_fpr_idx = max(i for i, val in enumerate(fpr) if val < far)
    upper_fpr_idx = min(i for i, val in enumerate(fpr) if val >= far)
    return (tpr[lower_fpr_idx] + tpr[upper_fpr_idx]) / 2


def hkpoly_test_fn(model,device,test_loader,epoch,plot_argument):
    """Evaluate contact-based vs. contactless (CB2CL) matching on ``test_loader``.

    Embeddings are extracted with ``model.get_embeddings``, L2-normalised and
    compared by dot product. The score matrix and a ROC plot are written
    under ``combined_models_scores/``; ROC-AUC, EER, TAR@FAR and recall@k are
    printed.

    Args:
        model: Module exposing ``get_embeddings(x, domain)``.
        device: Device the batches are moved to.
        test_loader: Iterable yielding ``(x_cl, x_cb, label)`` batches.
        epoch: Value embedded in the output file names.
        plot_argument: Sequence whose first four string items are embedded in
            the output file names.

    Returns:
        ``(recall@1, EER, TAR@FAR=1e-2, TAR@FAR=1e-3, TAR@FAR=1e-4)``, all
        in percent.
    """
    model.eval()
    cl_feats, cb_feats, cl_labels, cb_labels = [], [], [], []
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(test_loader):
            x_cl, x_cb, label = x_cl.to(device), x_cb.to(device), label.to(device)
            x_cl_feat, _ = model.get_embeddings(x_cl,'contactless')
            x_cb_feat, _ = model.get_embeddings(x_cb,'contactbased')
            x_cl_feat = l2_norm(x_cl_feat).cpu().detach().numpy()
            x_cb_feat = l2_norm(x_cb_feat).cpu().detach().numpy()
            label = label.cpu().detach().numpy()
            cl_feats.append(x_cl_feat)
            cb_feats.append(x_cb_feat)
            # Both modalities of a pair share the same identity label.
            cl_labels.append(label)
            cb_labels.append(label)

    cl_feats = np.concatenate(cl_feats)
    cb_feats = np.concatenate(cb_feats)
    cl_label = torch.from_numpy(np.concatenate(cl_labels))
    cb_label = torch.from_numpy(np.concatenate(cb_labels))

    # np.save / savefig do not create missing directories.
    os.makedirs("combined_models_scores", exist_ok=True)
    run_tag = plot_argument[0]+"_"+plot_argument[1]+"_"+plot_argument[2]+"_"+plot_argument[3]

    # CB2CL
    scores = np.dot(cl_feats,np.transpose(cb_feats))
    np.save("combined_models_scores/task1_cb2cl_score_matrix_"+str(epoch)+"_"+run_tag+".npy", scores)
    scores = scores.flatten().tolist()
    # A pair is genuine when both labels are equal.
    labels = torch.eq(cl_label.view(-1,1) - cb_label.view(1,-1),0.0).flatten().tolist()

    fpr,tpr,_ = roc_curve(labels,scores)
    tar_far_102 = _tar_at_far(fpr, tpr, 0.01)
    tar_far_103 = _tar_at_far(fpr, tpr, 0.001)
    tar_far_104 = _tar_at_far(fpr, tpr, 0.0001)
    fnr = 1 - tpr
    # EER is read at the ROC point where FNR and FPR are closest.
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)

    fig = plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--', label='No Skill')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve CB2CL task1')
    plt.legend(loc="lower right")
    plt.savefig("combined_models_scores/roc_curve_cb2cl_task1_"+"_"+run_tag+str(epoch)+".png", dpi=300, bbox_inches='tight')
    # Without closing, one figure per call stays alive for the whole run.
    plt.close(fig)

    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100

    cl_label = cl_label.cpu().detach().numpy()
    cb_label = cb_label.cpu().detach().numpy()
    recall_score = Prev_RetMetric([cb_feats,cl_feats],[cb_label,cl_label],cl2cl = False)
    cl2cbk1 = recall_score.recall_k(k=1) * 100
    print(f"R@1 for CB2CL: {cl2cbk1} %")
    print(f"R@10 for CB2CL: {recall_score.recall_k(k=10) * 100} %")
    print(f"R@50 for CB2CL: {recall_score.recall_k(k=50) * 100} %")
    print(f"R@100 for CB2CL: {recall_score.recall_k(k=100) * 100} %")

    return cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104
|
| 130 |
+
|
| 131 |
+
def main():
    """Train the domain-agnostic Swin model and evaluate it after every epoch.

    Parses the command line, builds the balanced training loader and the
    ``hktest`` evaluation loader, warm-starts the model from
    ``ridgeformer_checkpoints/phase1_scratch.pt``, then for each epoch
    trains, evaluates CB2CL matching, logs to TensorBoard and saves a
    checkpoint. A summary of the best epochs is printed at the end.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    # The original used ``type=list`` (which splits a string into characters)
    # and an undefined default; one or more paths are now required.
    parser.add_argument('--manifest-list', type=str, nargs='+', required=True,
                        help='list of manifest files from different datasets to train on')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr_linear', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--lr_swin', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    # NOTE(review): --gamma is parsed but the scheduler below uses a fixed
    # gamma of 0.7 - confirm which value is intended before wiring it in.
    parser.add_argument('--gamma', type=float, default=0.9, metavar='M',
                        help='Learning rate step gamma (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--warmup', type=int, default=2, metavar='N',
                        help='warm up rate for feature extractor')
    parser.add_argument('--model-name', type=str, default="ridgeformer",
                        help='Name of the model for checkpointing')
    args = parser.parse_args()

    checkpoint_save_path = "ridgeformer_checkpoints/"
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # makedirs also creates the parent "experiment_logs/" when it is missing.
    os.makedirs("experiment_logs/"+args.model_name, exist_ok=True)

    log_writer = SummaryWriter("experiment_logs/"+args.model_name+"/",comment = str(args.batch_size)+str(args.lr_linear)+str(args.lr_swin))

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")

    print("loading Normal RGB images -----------------------------")
    train_dataset = Combined_original(args.manifest_list,split="train")
    val_dataset = hktest(split="test")

    balanced_sampler = BalancedSampler(train_dataset, batch_size = args.batch_size, images_per_class = 2)
    batch_sampler = BatchSampler(balanced_sampler, batch_size = args.batch_size, drop_last = True)

    train_kwargs = {'batch_sampler': batch_sampler}
    test_kwargs = {'batch_size': args.test_batch_size}

    if use_cuda:
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    model = Model().to(device)
    ckpt = torch.load("ridgeformer_checkpoints/phase1_scratch.pt", map_location=torch.device('cpu'))
    # strict=False: keys missing from / extra in the checkpoint are tolerated.
    model.load_state_dict(ckpt,strict=False)
    print("Number of Trainable Parameters: - ", count_parameters(model))

    loss_func = DualMSLoss_FineGrained_domain_agnostic()

    optimizer_swin = optim.AdamW(
        [
            {"params": model.swin_cl.parameters(), "lr":args.lr_swin},
            {"params": model.classify.parameters(), "lr":args.lr_linear},
            {"params": model.linear_cl.parameters(), "lr":args.lr_linear},
            {"params": model.linear_cb.parameters(), "lr":args.lr_linear},
        ],
        weight_decay=0.000001,
        lr=args.lr_swin)

    scheduler_swin = MultiStepLR(optimizer_swin, milestones = [100], gamma=0.7)

    plot_args = [args.model_name,str(args.batch_size),str(args.lr_linear),str(args.lr_swin)]

    cb2cl_lst = []
    eer_cb2cl_lst = []
    cbcltf102_lst, cbcltf103_lst, cbcltf104_lst = [], [], []
    stepping = 1
    for epoch in range(1, args.epochs + 1):
        print(f"running epoch------ {epoch}")
        if (epoch > args.warmup):
            print("Training with Swin")
            model.unfreeze_encoder()
        else:
            print("Training only linear")
            model.freeze_encoder()

        avg_step_loss,stepping = train(args, model, device, train_loader, test_loader, [optimizer_swin], epoch, loss_func, plot_args, stepping, log_writer)

        print(f"Learning Rate for {epoch} for swin = {scheduler_swin.get_last_lr()}")

        log_writer.add_scalar('Swin_LR/epoch',scheduler_swin.get_last_lr()[0],epoch)

        if (epoch > args.warmup):
            scheduler_swin.step()

        # hkpoly_test_fn in this file returns five CB2CL values; the original
        # ten-way unpacking raised ValueError after the first epoch.
        cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104 = hkpoly_test_fn(model, device, test_loader, epoch, plot_args)
        cb2cl_lst.append(cl2cbk1)
        eer_cb2cl_lst.append(eer_cb2cl)
        cbcltf102_lst.append(cbcltf102)
        cbcltf103_lst.append(cbcltf103)
        cbcltf104_lst.append(cbcltf104)

        log_writer.add_scalars('recall@1/epoch',{'CB2CL':cl2cbk1},epoch)
        log_writer.add_scalars('EER/epoch',{'CB2CL':eer_cb2cl},epoch)
        log_writer.add_scalars('TARFAR10^-2/epoch',{'CB2CL':cbcltf102},epoch)
        log_writer.add_scalars('TARFAR10^-3/epoch',{'CB2CL':cbcltf103},epoch)
        log_writer.add_scalars('TARFAR10^-4/epoch',{'CB2CL':cbcltf104},epoch)
        log_writer.add_scalar('AvgLoss/epoch',avg_step_loss,epoch)

        torch.save(model.state_dict(), checkpoint_save_path + "combinedtrained_hkpolytest_" + args.model_name + "_" + str(args.lr_linear) + "_" + str(args.lr_swin) + "_" + str(args.batch_size) + str(epoch) + "_" + str(cl2cbk1) + ".pt")
    log_writer.close()

    if not cb2cl_lst:
        # No epoch ran (e.g. --epochs 0); max()/min() below would fail.
        return

    print(f"Maximum recall@1 for CB2CL: {max(cb2cl_lst)} at epoch {cb2cl_lst.index(max(cb2cl_lst))+1}")
    print(f"Minimum EER for CB2CL: {min(eer_cb2cl_lst)} at epoch {eer_cb2cl_lst.index(min(eer_cb2cl_lst))+1}")
    print(f"Maximum TAR@FAR=10^-2 for CB2CL: {max(cbcltf102_lst)} at epoch {cbcltf102_lst.index(max(cbcltf102_lst))+1}")
    print(f"Maximum TAR@FAR=10^-3 for CB2CL: {max(cbcltf103_lst)} at epoch {cbcltf103_lst.index(max(cbcltf103_lst))+1}")
    print(f"Maximum TAR@FAR=10^-4 for CB2CL: {max(cbcltf104_lst)} at epoch {cbcltf104_lst.index(max(cbcltf104_lst))+1}")


if __name__ == '__main__':
    main()
|
train_combined_fusion.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torchvision import datasets, transforms
|
| 9 |
+
from torch.optim.lr_scheduler import StepLR, MultiStepLR
|
| 10 |
+
from datasets.hkpoly_test import hktest
|
| 11 |
+
from datasets.original_combined_train import Combined_original
|
| 12 |
+
from datasets.rb_loader import RB_loader
|
| 13 |
+
from loss import DualMSLoss_FineGrained_domain_agnostic_ft, DualMSLoss_FineGrained, DualMSLoss_FineGrained_domain_agnostic
|
| 14 |
+
import timm
|
| 15 |
+
from utils import Prev_RetMetric, RetMetric, compute_recall_at_k, l2_norm, compute_sharded_cosine_similarity, count_parameters
|
| 16 |
+
from pprint import pprint
|
| 17 |
+
import numpy as np
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from combined_sampler import BalancedSampler
|
| 20 |
+
from torch.utils.data.sampler import BatchSampler
|
| 21 |
+
from torch.nn.parallel import DataParallel
|
| 22 |
+
from model import SwinModel_Fusion as Model
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
from sklearn.metrics import roc_curve, auc
|
| 25 |
+
import json
|
| 26 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 27 |
+
|
| 28 |
+
def train(args, model, device, train_loader, test_loader, optimizers, epoch, loss_func, pl_arg, stepping, log_writer, checkpoint_save_path):
    """Run one training epoch of the fusion model and return the mean loss.

    For every batch all contactless/contact-based token pairs are scored with
    ``model.combine_features`` and the resulting N x N similarity matrix is
    fed to ``loss_func.ms_sample`` in both directions. Every 50 batches the
    model is evaluated with ``hkpoly_test_fn`` and the metrics are logged.

    Args:
        args: Parsed CLI namespace; ``log_interval`` and ``dry_run`` are read.
        model: Module returning ``(x_cl_tokens, x_cb_tokens)`` and exposing
            ``combine_features``.
        device: Device the batches and the loss are placed on.
        train_loader: Iterable yielding 5-tuples whose first three items are
            ``(x_cl, x_cb, target)``.
        test_loader: Loader passed to ``hkpoly_test_fn``.
        optimizers: Iterable of optimizers, all zeroed and stepped per batch.
        epoch: Forwarded to ``hkpoly_test_fn``.
        loss_func: Object exposing ``ms_sample(sim_matrix, target)``.
        pl_arg: Forwarded to ``hkpoly_test_fn``.
        stepping: Counter used as x-axis for the periodic evaluation logs.
        log_writer: Writer exposing ``add_scalars``.
        checkpoint_save_path: Unused here; kept for signature compatibility.

    Returns:
        ``(mean_loss, stepping)`` where ``mean_loss`` is a Python float
        (``nan`` when no batch was processed).
    """
    model.train()
    step_losses = []
    for batch_idx, (x_cl, x_cb, target, _, _) in enumerate(pbar := tqdm(train_loader)):
        x_cl, x_cb, target = x_cl.to(device), x_cb.to(device), target.to(device)
        for optimizer in optimizers:
            optimizer.zero_grad()
        x_cl_tokens, x_cb_tokens = model(x_cl, x_cb)

        N, M, D = x_cl_tokens.shape

        # Build all N*N (contactless_i, contactbased_j) pairs: row index i
        # varies slowest, column index j fastest.
        x = x_cl_tokens.unsqueeze(1).expand(N, N, M, D).reshape(N * N, M, D)
        y = x_cb_tokens.unsqueeze(0).expand(N, N, M, D).reshape(N * N, M, D)
        sim_matrix, _ = model.combine_features(x, y)
        sim_matrix = sim_matrix.view(N, N).to(device)

        # .to(device) instead of the original hard-coded .cuda(), which
        # failed when running with --no-cuda.
        loss = loss_func.ms_sample(sim_matrix, target).to(device) + loss_func.ms_sample(sim_matrix.t(), target.t()).to(device)
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # Store a float so the autograd graph of each batch can be freed;
        # recording before the dry-run check guarantees at least one entry.
        step_losses.append(loss.item())

        if batch_idx % args.log_interval == 0:
            if args.dry_run:
                break
        pbar.set_description(f"Loss {loss}")

        if (batch_idx + 1)%50 == 0:
            cl2clk1,cl2cbk1,eer_cb2cl,eer_cl2cl,cbcltf102,cbcltf103,cbcltf104,clcltf102,clcltf103,clcltf104 = hkpoly_test_fn(model, device, test_loader, epoch, pl_arg)
            # hkpoly_test_fn switches the model to eval mode; switch back so
            # the remaining batches of this epoch are trained in train mode.
            model.train()
            log_writer.add_scalars('recall@1/step',{'CL2CL':cl2clk1,'CB2CL':cl2cbk1},stepping)
            log_writer.add_scalars('EER/step',{'CL2CL':eer_cl2cl,'CB2CL':eer_cb2cl},stepping)
            log_writer.add_scalars('TARFAR10^-2/step',{'CL2CL':clcltf102,'CB2CL':cbcltf102},stepping)
            log_writer.add_scalars('TARFAR10^-4/step',{'CL2CL':clcltf104,'CB2CL':cbcltf104},stepping)
            stepping+=1

    mean_loss = sum(step_losses) / len(step_losses) if step_losses else float("nan")
    return mean_loss, stepping
|
| 68 |
+
|
| 69 |
+
def l2_norm(input):
    """Return ``input`` with each row divided by its Euclidean norm.

    The norm is ``sqrt(sum(input ** 2, dim=1) + 1e-12)``; the epsilon keeps
    all-zero rows finite. The output shape equals the input shape.

    Note: this local definition shadows the ``l2_norm`` imported from
    ``utils`` at the top of the file.
    """
    shape = input.size()
    sum_of_squares = torch.sum(input * input, 1)
    sum_of_squares += 1e-12
    norms = sum_of_squares.sqrt().view(-1, 1)
    return (input / norms.expand_as(input)).view(shape)
|
| 77 |
+
|
| 78 |
+
def get_fused_cross_score_matrix(model, cl_tokens, cb_tokens):
    """Score every (cl, cb) token pair with the model's fusion head.

    Args:
        model: object exposing `combine_features(a, b)`, which is called with
            two equally sized batches and must return a
            `(similarity_scores, distances)` pair with one value per row.
        cl_tokens: list of token tensors, concatenated along dimension 0.
        cb_tokens: list of token tensors, concatenated along dimension 0.

    Returns:
        A CPU tensor of shape `(num_cl, num_cb)` whose entry `[i, j]` is
        `similarity - 0.1 * distance` for the pair (cl_tokens[i], cb_tokens[j]).
    """
    cl_tokens = torch.cat(cl_tokens)
    cb_tokens = torch.cat(cb_tokens)
    num_cl = cl_tokens.shape[0]
    num_cb = cb_tokens.shape[0]
    shard_size = 20  # pairs are scored in blocks of at most shard_size x shard_size
    similarity_matrix = torch.zeros((num_cl, num_cb))
    for i_start in tqdm(range(0, num_cl, shard_size)):
        shard_i = cl_tokens[i_start:i_start + shard_size]
        rows = shard_i.shape[0]
        for j_start in range(0, num_cb, shard_size):
            shard_j = cb_tokens[j_start:j_start + shard_size]
            cols = shard_j.shape[0]

            # Use the real shard sizes: the last shard is smaller than
            # shard_size whenever the sample count is not a multiple of it,
            # and expanding with the constant would mis-pair the rows.
            pairwise_i = shard_i.unsqueeze(1).expand(-1, cols, -1, -1)
            pairwise_j = shard_j.unsqueeze(0).expand(rows, -1, -1, -1)

            # Flatten the (rows, cols) grid into one batch, keeping whatever
            # per-sample token shape the inputs have.
            similarity_scores, distances = model.combine_features(
                pairwise_i.reshape(rows * cols, *shard_i.shape[1:]),
                pairwise_j.reshape(rows * cols, *shard_j.shape[1:]),
            )
            scores = (similarity_scores - 0.1 * distances).reshape(rows, cols)
            similarity_matrix[i_start:i_start + rows, j_start:j_start + cols] = scores.cpu().detach()
    return similarity_matrix
|
| 100 |
+
|
| 101 |
+
def _tar_at_far(fpr, tpr, far):
    """Return the TAR at a target FAR, averaged over the two ROC points bracketing it."""
    # roc_curve returns increasing false positive rates, so a binary search
    # finds the first point with fpr >= far; the point before it is the last
    # one with fpr < far.
    upper = int(np.searchsorted(fpr, far, side="left"))
    lower = max(upper - 1, 0)
    upper = min(upper, len(fpr) - 1)
    return (tpr[lower] + tpr[upper]) / 2


def hkpoly_test_fn(model,device,test_loader,epoch,plot_argument):
    """Evaluate the fused matcher on `test_loader` and report CB2CL metrics.

    Tokens are extracted for both images of every test sample, all pairs are
    scored with `get_fused_cross_score_matrix`, and the score matrix plus a ROC
    plot are written under ``combined_models_scores/``.

    Args:
        model: model exposing `get_tokens` and `combine_features`; it is put in
            eval mode and left there.
        device: device the input batches are moved to.
        test_loader: iterable yielding `(x_cl, x_cb, label)` batches.
        epoch: value embedded in the output file names.
        plot_argument: sequence of at least three strings embedded in the
            output file names.

    Returns:
        `(recall@1, EER, TAR@FAR=1e-2, TAR@FAR=1e-3, TAR@FAR=1e-4)`, all in
        percent.
    """
    # NOTE(review): the model stays in eval mode after this call; callers that
    # keep training afterwards need to switch it back themselves.
    model.eval()
    cl_feats, cb_feats, batch_labels = [], [], []
    with torch.no_grad():
        for (x_cl, x_cb, label) in tqdm(test_loader):
            x_cl, x_cb = x_cl.to(device), x_cb.to(device)
            cl_feats.append(model.get_tokens(x_cl,'contactless'))
            cb_feats.append(model.get_tokens(x_cb,'contactbased'))
            batch_labels.append(label.cpu().detach().numpy())

        # Both modalities come from the same sample, so they share labels.
        cl_label = torch.from_numpy(np.concatenate(batch_labels))
        cb_label = cl_label

        # CB2CL (scored under no_grad: no autograd graph is needed here)
        scores_mat = get_fused_cross_score_matrix(model, cl_feats, cb_feats)

    os.makedirs("combined_models_scores", exist_ok=True)
    scores_np = scores_mat.cpu().detach().numpy()
    np.save("combined_models_scores/task1_cb2cl_score_matrix_"+str(epoch)+"_"+plot_argument[0]+"_"+plot_argument[1]+"_"+plot_argument[2]+".npy", scores_np)

    # Rows of the score matrix are CL samples and columns are CB samples, so
    # the genuine/impostor mask is built in the same orientation.
    labels = torch.eq(cl_label.view(-1,1), cb_label.view(1,-1)).flatten().numpy()
    scores = scores_np.flatten()

    fpr,tpr,_ = roc_curve(labels,scores)
    tar_far_102 = _tar_at_far(fpr, tpr, 0.01)
    tar_far_103 = _tar_at_far(fpr, tpr, 0.001)
    tar_far_104 = _tar_at_far(fpr, tpr, 0.0001)
    fnr = 1 - tpr
    # EER is read off at the ROC point where FNR and FPR are closest.
    EER = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
    roc_auc = auc(fpr, tpr)

    fig = plt.figure()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--', label='No Skill')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve CB2CL task1')
    plt.legend(loc="lower right")
    plt.savefig("combined_models_scores/roc_curve_cb2cl_task1_"+"_"+plot_argument[0]+"_"+plot_argument[1]+"_"+plot_argument[2]+str(epoch)+".png", dpi=300, bbox_inches='tight')
    # Close the figure: this function runs repeatedly and open figures pile up.
    plt.close(fig)

    print(f"ROCAUC for CB2CL: {roc_auc * 100} %")
    print(f"EER for CB2CL: {EER * 100} %")
    eer_cb2cl = EER * 100
    print(f"TAR@FAR=10^-2 for CB2CL: {tar_far_102 * 100} %")
    print(f"TAR@FAR=10^-3 for CB2CL: {tar_far_103 * 100} %")
    print(f"TAR@FAR=10^-4 for CB2CL: {tar_far_104 * 100} %")
    cbcltf102 = tar_far_102 * 100
    cbcltf103 = tar_far_103 * 100
    cbcltf104 = tar_far_104 * 100

    cl2cbk1 = compute_recall_at_k(scores_mat, cl_label, cb_label, 1) * 100
    print(f"R@1 for CB2CL: {cl2cbk1} %")
    for k in (10, 50, 100):
        print(f"R@{k} for CB2CL: {compute_recall_at_k(scores_mat, cl_label, cb_label, k) * 100} %")
    torch.cuda.empty_cache()

    return cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104
|
| 169 |
+
|
| 170 |
+
def main():
    """Train the fusion parameters and evaluate with `hkpoly_test_fn` each epoch.

    Parses the command line, loads the two pretrained checkpoints into
    `Model`, freezes the backbone, optimises the fusion-related parameter
    groups with AdamW, logs metrics to TensorBoard and saves a checkpoint
    after every epoch.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    # nargs='+' collects one or more paths; type=list would split a single
    # argument string into individual characters.
    parser.add_argument('--manifest-list', type=str, nargs='+', default=mani_lst,
                        help='list of manifest files')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--lr_fusion', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    # NOTE(review): --gamma and --warmup are parsed but not used below; the
    # scheduler uses a fixed gamma of 0.5.
    parser.add_argument('--gamma', type=float, default=0.9, metavar='M',
                        help='Learning rate step gamma (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--warmup', type=int, default=2, metavar='N',
                        help='warm up rate for feature extractor')
    parser.add_argument('--model-name', type=str, default="swinmodel",
                        help='Name of the model for checkpointing')
    args = parser.parse_args()

    # Resolve the device from --no-cuda / availability instead of assuming cuda.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = Model().to(device)
    ckpt_combined_phase1_ft = "ridgeformer_checkpoints/combined_models_check/phase1_ft_hkpoly.pt"
    ckpt_combined_phase2 = "ridgeformer_checkpoints/phase2_scratch.pt"

    model.load_pretrained_models(ckpt_combined_phase1_ft, ckpt_combined_phase2)
    model.freeze_backbone()
    checkpoint_save_path = "ridgeformer_checkpoints/"

    # makedirs also creates "experiment_logs/" itself when it is missing.
    os.makedirs("experiment_logs/"+args.model_name, exist_ok=True)

    log_writer = SummaryWriter("experiment_logs/"+args.model_name+"/",comment = str(args.batch_size)+str(args.lr_fusion))

    torch.manual_seed(args.seed)

    print("loading Normal RGB images -----------------------------")
    train_dataset = Combined_original(args.manifest_list,split="train")
    val_dataset = hktest(split="test")

    balanced_sampler = BalancedSampler(train_dataset, batch_size = args.batch_size, images_per_class = 2)
    batch_sampler = BatchSampler(balanced_sampler, batch_size = args.batch_size, drop_last = True)

    train_kwargs = {'batch_sampler': batch_sampler}
    test_kwargs = {'batch_size': args.test_batch_size}

    if use_cuda:
        cuda_kwargs = {
            'num_workers': 1,
            'pin_memory': True
        }
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(val_dataset, **test_kwargs)

    print("Number of Trainable Parameters: - ", count_parameters(model))

    loss_func = DualMSLoss_FineGrained()
    optimizer_fusion = optim.AdamW(
        [
            {"params": model.output_logit_mlp.parameters(), "lr":args.lr_fusion},
            {"params": model.fusion.parameters(), "lr":args.lr_fusion},
            {"params": model.sep_token, "lr":args.lr_fusion},
            {"params": model.encoder_layer.parameters(), "lr":args.lr_fusion},
        ],
        weight_decay=0.000001,
        lr=args.lr_fusion)

    scheduler = MultiStepLR(optimizer_fusion, milestones = [3,6,9,14], gamma=0.5)

    plot_argument = [args.model_name,str(args.batch_size),str(args.lr_fusion)]
    cb2cl_lst,eer_cb2cl_lst,cbcltf102_lst,cbcltf103_lst,cbcltf104_lst = [],[],[],[],[]
    stepping = 1
    for epoch in range(1, args.epochs + 1):
        print(f"running epoch------ {epoch}")
        avg_step_loss,stepping = train(args, model, device, train_loader, test_loader, [optimizer_fusion], epoch, loss_func, plot_argument,stepping,log_writer, checkpoint_save_path)

        print(f"Learning Rate for {epoch} for linear = {scheduler.get_last_lr()}")
        print(f"Learning Rate for {epoch} for swin = {scheduler.get_last_lr()}")

        log_writer.add_scalar('Liner_LR/epoch',scheduler.get_last_lr()[0],epoch)
        log_writer.add_scalar('Swin_LR/epoch',scheduler.get_last_lr()[0],epoch)

        scheduler.step()

        # hkpoly_test_fn returns five CB2CL metrics; unpacking ten names (as
        # before) raised ValueError on the first evaluation.
        # NOTE(review): train() unpacks ten values from the same function and
        # needs the matching change.
        cl2cbk1,eer_cb2cl,cbcltf102,cbcltf103,cbcltf104 = hkpoly_test_fn(model, device, test_loader, epoch, plot_argument)
        cb2cl_lst.append(cl2cbk1)
        eer_cb2cl_lst.append(eer_cb2cl)
        cbcltf102_lst.append(cbcltf102)
        cbcltf103_lst.append(cbcltf103)
        cbcltf104_lst.append(cbcltf104)

        log_writer.add_scalars('recall@1/epoch',{'CB2CL':cl2cbk1},epoch)
        log_writer.add_scalars('EER/epoch',{'CB2CL':eer_cb2cl},epoch)
        log_writer.add_scalars('TARFAR10^-2/epoch',{'CB2CL':cbcltf102},epoch)
        log_writer.add_scalars('TARFAR10^-4/epoch',{'CB2CL':cbcltf104},epoch)
        log_writer.add_scalar('AvgLoss/epoch',avg_step_loss,epoch)

        torch.save(model.state_dict(), checkpoint_save_path + "combinedtrained_hkpolytest_" + args.model_name + "_" + str(args.lr_fusion) + "_" + str(args.batch_size) + str(epoch) + "_" + str(cl2cbk1) + ".pt")
    log_writer.close()

    summaries = [
        ("Maximum recall@1 for CB2CL", cb2cl_lst, max),
        ("Minimum EER for CB2CL", eer_cb2cl_lst, min),
        ("Maximum TAR@FAR=10^-2 for CB2CL", cbcltf102_lst, max),
        ("Maximum TAR@FAR=10^-3 for CB2CL", cbcltf103_lst, max),
        ("Maximum TAR@FAR=10^-4 for CB2CL", cbcltf104_lst, max),
    ]
    for title, history, pick in summaries:
        if not history:
            # --epochs 0 leaves nothing to summarise.
            continue
        best = pick(history)
        print(f"{title}: {best} at epoch {history.index(best)+1}")


if __name__ == '__main__':
    main()
|
utils.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
|
| 6 |
+
class RetMetric(object):
    """Recall@k computed from a precomputed query-by-gallery similarity matrix.

    Args:
        sim_mat: indexable by query row; `sim_mat[i]` holds the similarities of
            query `i` to every gallery item.
        labels: pair `(gallery_labels, query_labels)`.
    """

    def __init__(self, sim_mat, labels):
        self.gallery_labels, self.query_labels = labels
        self.sim_mat = sim_mat
        self.is_equal_query = False

    def recall_k(self, k=1):
        """Return the fraction of queries whose best positive ranks within the top `k`.

        A query counts as a match when fewer than `k` negatives score strictly
        higher than its best positive. Queries with no positive in the gallery
        count as misses, and an empty similarity matrix yields 0.0.
        """
        m = len(self.sim_mat)
        if m == 0:
            return 0.0

        match_counter = 0

        for i in range(m):
            pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
            if len(pos_sim) == 0:
                # Nothing to retrieve for this query: np.max would raise on
                # the empty selection, so record a miss instead.
                continue
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]

            # When queries double as the gallery the top positive is taken to
            # be the self-match, so the runner-up is used as the threshold.
            if self.is_equal_query and len(pos_sim) > 1:
                thresh = np.sort(pos_sim)[-2]
            else:
                thresh = np.max(pos_sim)

            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
|
| 26 |
+
|
| 27 |
+
class Prev_RetMetric(object):
    """Recall@k computed from feature vectors via a dot-product similarity matrix.

    Args:
        feats: either a two-element list `[gallery_feats, query_feats]`, or a
            single feature array that serves as both gallery and query.
        labels: `[gallery_labels, query_labels]` in the two-set case, otherwise
            a single label array.
        cl2cl: when true, the diagonal of the similarity matrix is zeroed
            (the matrix must then be square).
    """

    def __init__(self, feats, labels, cl2cl=True):
        if isinstance(feats, list) and len(feats) == 2:
            # feats = [gallery_feats, query_feats]
            # labels = [gallery_labels, query_labels]
            self.is_equal_query = False

            self.gallery_feats, self.query_feats = feats
            self.gallery_labels, self.query_labels = labels

        else:
            self.is_equal_query = True
            self.gallery_feats = self.query_feats = feats
            self.gallery_labels = self.query_labels = labels

        self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))
        if cl2cl:
            # Zero the diagonal entries of the similarity matrix.
            self.sim_mat = self.sim_mat * (1 - np.identity(self.sim_mat.shape[0]))

    def recall_k(self, k=1):
        """Return the fraction of queries whose best positive ranks within the top `k`.

        A query counts as a match when fewer than `k` negatives score strictly
        higher than the positive threshold. Queries with no positive count as
        misses, and an empty similarity matrix yields 0.0.
        """
        m = len(self.sim_mat)
        if m == 0:
            return 0.0

        match_counter = 0

        for i in range(m):
            pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
            if len(pos_sim) == 0:
                continue
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]

            # In the shared gallery/query case the runner-up positive is used
            # as the threshold. A class with a single sample has no runner-up
            # (indexing [-2] raised IndexError), so fall back to the maximum,
            # matching RetMetric.
            # NOTE(review): with cl2cl=True the diagonal is already zeroed, so
            # taking [-2] as well may skip one positive too many - confirm.
            if self.is_equal_query and len(pos_sim) > 1:
                thresh = np.sort(pos_sim)[-2]
            else:
                thresh = np.max(pos_sim)

            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
|
| 63 |
+
|
| 64 |
+
def compute_recall_at_k(similarity_matrix, p_labels, g_labels, k):
    """Return the fraction of probes with a same-label gallery item in their top `k`.

    Args:
        similarity_matrix: tensor whose row `i` holds the scores of probe `i`
            against every gallery item (higher means more similar).
        p_labels: tensor with one label per probe.
        g_labels: tensor with one label per gallery item.
        k: number of top-scoring gallery items inspected per probe; values
            larger than the gallery size are clamped.

    Returns:
        A Python float in [0, 1]; 0.0 when there are no probes, no gallery
        items, or `k` is not positive.
    """
    num_probes = p_labels.size(0)
    if num_probes == 0 or k <= 0:
        return 0.0

    scores = similarity_matrix[:num_probes]
    top_k = min(k, scores.size(1))
    if top_k == 0:
        return 0.0

    # One batched top-k replaces a full sort per probe; the order among
    # exactly tied scores is unspecified in both approaches.
    top_idx = torch.topk(scores, top_k, dim=1).indices.to(g_labels.device)
    probe_col = p_labels.reshape(-1, 1).to(g_labels.device)
    hits = (g_labels[top_idx] == probe_col).any(dim=1)
    return hits.sum().item() / num_probes
|
| 76 |
+
|
| 77 |
+
def count_parameters(model):
    """Return the total element count of the parameters of `model` that require gradients."""
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            # Frozen parameters are not counted.
            continue
        trainable_total += param.numel()
    return trainable_total
|
| 79 |
+
|
| 80 |
+
def l2_norm(input):
    """Scale `input` to unit L2 norm along dimension 1.

    Args:
        input: tensor with at least two dimensions. For a 2-D `(N, D)` tensor
            every row is normalised, as before; tensors of higher rank are
            normalised along dimension 1 as well (the previous implementation
            raised on them).

    Returns:
        A tensor of the same shape as `input`.
    """
    # 1e-12 is added to the squared norm so all-zero vectors map to zeros
    # instead of producing a division by zero.
    squared_norm = torch.pow(input, 2).sum(dim=1, keepdim=True) + 1e-12
    # keepdim=True lets broadcasting do the work of the former view/expand_as.
    return input / torch.sqrt(squared_norm)
|
| 89 |
+
|
| 90 |
+
def _pooled_unit_tokens(tokens, shard_size, show_progress=False):
    """Return, per sample, the sum over tokens of the L2-normalised token vectors."""
    count, _, dim = tokens.shape
    pooled = tokens.new_zeros((count, dim))
    starts = range(0, count, shard_size)
    # Normalise shard by shard so only one shard-sized copy exists at a time.
    for start in (tqdm(starts) if show_progress else starts):
        shard = tokens[start:start + shard_size]
        # eps=1e-8 is cosine_similarity's default clamp on vector norms.
        pooled[start:start + shard_size] = F.normalize(shard, dim=-1, eps=1e-8).sum(dim=1)
    return pooled


def compute_sharded_cosine_similarity(tensor1, tensor2, shard_size):
    """Average token-to-token cosine similarity between every pair of samples.

    Entry `[i, j]` of the result is the mean of `cos(tensor1[i, t], tensor2[j, s])`
    over all token pairs `(t, s)`.

    Args:
        tensor1: tensor of shape `(B1, T1, D)`.
        tensor2: tensor of shape `(B2, T2, D)`; it no longer has to match
            `tensor1` in batch size or token count.
        shard_size: number of samples normalised at a time (bounds the size of
            temporary tensors).

    Returns:
        A `(B1, B2)` tensor on `tensor1.device` in the default float dtype.
    """
    B1, T1, _ = tensor1.shape
    B2, T2, _ = tensor2.shape
    average_sim_matrix = torch.zeros((B1, B2), device=tensor1.device)

    # The sum over all (t, s) of cos(x_t, y_s) equals the dot product of the
    # summed unit vectors, so the (T1 x T2) grid never has to be materialised.
    pooled1 = _pooled_unit_tokens(tensor1, shard_size, show_progress=True)
    pooled2 = _pooled_unit_tokens(tensor2, shard_size)

    # copy_ keeps the preallocated dtype/device while taking the matmul result.
    average_sim_matrix.copy_(pooled1 @ pooled2.t())

    # Normalize by the total number of token pairs
    average_sim_matrix /= (T1 * T2)

    return average_sim_matrix
|