import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from torch import Tensor


class ContrastiveLoss(nn.Module):
    """Temperature-scaled contrastive loss over paired sequence halves.

    Embeddings are expected as [B, 2T, D], where the two halves along the
    time dimension correspond to the two groups being contrasted.
    """

    def __init__(self, temperature: float = 0.25, distance_metric: str = 'cosine'):
        super().__init__()
        self.temperature = temperature
        self.distance_metric = distance_metric

    def compute_similarity(self, embeddings: Tensor) -> Tensor:
        """Compute a temperature-scaled similarity matrix within one embedding set."""
        if self.distance_metric == 'cosine':
            # L2-normalize so the dot product equals cosine similarity.
            embeddings = F.normalize(embeddings, p=2, dim=-1)
            sim = torch.matmul(embeddings, embeddings.transpose(-1, -2))
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
        return sim / self.temperature

    def compute_cross_similarity(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor:
        """Compute a temperature-scaled similarity matrix between two embedding sets."""
        if self.distance_metric == 'cosine':
            embeddings1 = F.normalize(embeddings1, p=2, dim=-1)
            embeddings2 = F.normalize(embeddings2, p=2, dim=-1)
            sim = torch.matmul(embeddings1, embeddings2.transpose(-1, -2))
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
        return sim / self.temperature

    def pairwise_and_no_diag(self, m: Tensor) -> Tensor:
        """Return all ordered pairs (i, j) where m is True at both i and j, excluding i == j.

        m: [B, T] boolean mask; result: [B, T, T] boolean mask.
        """
        m_i = m.unsqueeze(2)  # [B, T, 1]
        m_j = m.unsqueeze(1)  # [B, 1, T]
        out = m_i & m_j
        # Remove self-pairs on the diagonal.
        diag = torch.eye(m.size(1), dtype=torch.bool, device=m.device).unsqueeze(0)
        return out & ~diag
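
    # Illustration (assumed toy input, not from the original source): for
    # m = [[True, True, False]], pairwise_and_no_diag(m) yields
    #   [[[False,  True, False],
    #     [ True, False, False],
    #     [False, False, False]]]
    # i.e., every ordered pair of True positions except self-pairs.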

    def forward(self, embeddings: Tensor, anchors: Tensor,
                enrollment_embeddings: Optional[Tensor] = None,
                enrollment_embeddings_mask: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            embeddings: [B, 2T, D] main embeddings.
            anchors: [B, 2T] boolean mask indicating anchor positions.
            enrollment_embeddings: optional [B, 2T, D] enrollment embeddings
                used to form positive pairs.
            enrollment_embeddings_mask: optional [B, 2T] boolean mask of valid
                enrollment positions.

        Returns:
            Scalar contrastive loss.
        """
        if enrollment_embeddings is not None and enrollment_embeddings_mask is not None:
            return self._forward_with_enrollment(embeddings, anchors,
                                                 enrollment_embeddings,
                                                 enrollment_embeddings_mask)
        else:
            return self._forward_original(embeddings, anchors)
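
    # Block layout shared by both loss variants (sketch): with rows/columns
    # split at T, the [2T, 2T] masks are assembled as
    #   pos_mask = [[P11, 0  ],      neg_mask = [[0,     N12],
    #               [0,   P22]]                  [N12.T, 0  ]]
    # so positives stay within a half and negatives cross the two halves.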

    def _forward_with_enrollment(self, embeddings, anchors,
                                 enrollment_embeddings, enrollment_embeddings_mask):
        """Forward pass using enrollment embeddings as positives."""
        B, two_T, D = embeddings.shape
        T = two_T // 2

        # Positives are scored against the enrollment set ...
        cross_sim = self.compute_cross_similarity(embeddings, enrollment_embeddings)
        # ... while negatives are scored within the main embeddings.
        self_sim = self.compute_similarity(embeddings)

        # Split the anchor and enrollment masks into the two halves.
        m1 = anchors[:, :T]
        m2 = anchors[:, T:]
        enroll_m1 = enrollment_embeddings_mask[:, :T]
        enroll_m2 = enrollment_embeddings_mask[:, T:]

        # Positive pairs: anchors in each half matched against valid
        # enrollment positions in the same half (block-diagonal layout).
        pos_mask_1to1 = m1.unsqueeze(2) & enroll_m1.unsqueeze(1)  # [B, T, T]
        pos_mask_2to2 = m2.unsqueeze(2) & enroll_m2.unsqueeze(1)  # [B, T, T]
        pos_mask = torch.cat([
            torch.cat([pos_mask_1to1, torch.zeros_like(pos_mask_1to1)], dim=2),
            torch.cat([torch.zeros_like(pos_mask_2to2), pos_mask_2to2], dim=2)
        ], dim=1)  # [B, 2T, 2T]

        # Negative pairs: anchors in one half against anchors in the other
        # (off-diagonal blocks), scored on the self-similarity matrix.
        cross = m1.unsqueeze(2) & m2.unsqueeze(1)
        neg_mask = torch.cat([
            torch.cat([torch.zeros_like(cross), cross], dim=2),
            torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2)
        ], dim=1)

        # Exclude same-index pairs from both masks. For neg_mask this is
        # defensive (its diagonal blocks are already zero); for pos_mask it
        # drops each frame's own enrollment position.
        identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0)
        neg_mask &= ~identity_mask
        pos_mask &= ~identity_mask
        if pos_mask.any():
            # Numerators: exponentiated positive (cross) similarities.
            pos_sim = cross_sim[pos_mask]
            pos_exp = torch.exp(pos_sim)

            # Per-row sum of exponentiated negative similarities.
            exp_self_sim = torch.exp(self_sim)
            neg_exp_sum = torch.sum(exp_self_sim * neg_mask.float(), dim=2)  # [B, 2T]

            # Map each positive pair back to its (batch, row) so it can be
            # paired with that row's negative sum.
            pos_indices = torch.nonzero(pos_mask, as_tuple=False)
            batch_idx = pos_indices[:, 0]
            row_idx = pos_indices[:, 1]
            neg_sums_for_pos = neg_exp_sum[batch_idx, row_idx]

            # InfoNCE-style ratio: each positive against positive + negatives.
            denominators = pos_exp + neg_sums_for_pos
            loss = -torch.log(pos_exp / denominators)
            total_loss = loss.mean()
        else:
            # No valid positives in the batch: return a zero that still
            # participates in autograd.
            total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)

        return total_loss

    def _forward_original(self, embeddings, pos_indicator_mask):
        """Original forward pass, kept for backward compatibility."""
        B, two_T, D = embeddings.shape
        T = two_T // 2
        sim = self.compute_similarity(embeddings)

        # Split the indicator mask into the two halves.
        m1 = pos_indicator_mask[:, :T]
        m2 = pos_indicator_mask[:, T:]

        # Positive pairs live within each half, with self-pairs excluded.
        pos_block1 = self.pairwise_and_no_diag(m1)
        pos_block2 = self.pairwise_and_no_diag(m2)
        pos_mask = torch.cat([
            torch.cat([pos_block1, torch.zeros_like(pos_block1)], dim=2),
            torch.cat([torch.zeros_like(pos_block2), pos_block2], dim=2)
        ], dim=1)  # [B, 2T, 2T]

        # Negative pairs cross the two halves (off-diagonal blocks).
        cross = m1.unsqueeze(2) & m2.unsqueeze(1)
        neg_mask = torch.cat([
            torch.cat([torch.zeros_like(cross), cross], dim=2),
            torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2)
        ], dim=1)

        # Defensive: exclude the diagonal, although both masks already avoid
        # same-index pairs by construction.
        identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0)
        pos_mask &= ~identity_mask
        neg_mask &= ~identity_mask

        if pos_mask.any():
            exp_sim = torch.exp(sim)

            # Numerators: exponentiated positive similarities.
            pos_sim = sim[pos_mask]
            pos_exp = torch.exp(pos_sim)

            # Scaled mean (rather than sum) of exponentiated negative
            # similarities per row; the factor of 10 rescales the mean.
            neg_exp_avg = 10 * torch.mean(exp_sim * neg_mask.float(), dim=2)  # [B, 2T]

            # Pair each positive with the negative statistic of its row.
            pos_indices = torch.nonzero(pos_mask, as_tuple=False)
            batch_idx = pos_indices[:, 0]
            row_idx = pos_indices[:, 1]
            neg_avgs_for_pos = neg_exp_avg[batch_idx, row_idx]

            denominators = pos_exp + neg_avgs_for_pos
            loss = -torch.log(pos_exp / denominators)
            total_loss = loss.mean()
        else:
            # No valid positives: return an autograd-compatible zero.
            total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)
        return total_loss
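

# Minimal usage sketch (illustrative only; the shapes and random data below
# are assumptions for demonstration, not part of the module itself):
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, D = 2, 4, 8
    loss_fn = ContrastiveLoss(temperature=0.25)

    embeddings = torch.randn(B, 2 * T, D, requires_grad=True)
    anchors = torch.rand(B, 2 * T) > 0.5  # boolean anchor mask

    # Original mode: positives come from within each half of `embeddings`.
    loss = loss_fn(embeddings, anchors)
    loss.backward()
    print("loss (original mode):", loss.item())

    # Enrollment mode: positives are cross-similarities against an
    # enrollment set; negatives still come from `embeddings` itself.
    enrollment = torch.randn(B, 2 * T, D)
    enrollment_mask = torch.rand(B, 2 * T) > 0.5
    loss = loss_fn(embeddings.detach(), anchors,
                   enrollment_embeddings=enrollment,
                   enrollment_embeddings_mask=enrollment_mask)
    print("loss (enrollment mode):", loss.item())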