# SE_DiCoW/contrastive_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from torch import Tensor


class ContrastiveLoss(nn.Module):
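    """InfoNCE-style contrastive loss over paired two-speaker embeddings.

    Sequences are laid out as [B, 2T, D]: frames 0:T belong to the first
    speaker and frames T:2T to the second. Positives are same-speaker pairs
    (or anchor/enrollment pairs when enrollment embeddings are given);
    negatives are cross-speaker pairs. For a positive pair (i, j) with
    temperature tau, the per-pair loss is

        -log( exp(s_ij / tau) / (exp(s_ij / tau) + N_i) )

    where s_ij is the cosine similarity and N_i aggregates the exponentiated
    similarities between anchor i and its negatives.
    """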
    def __init__(self, temperature=0.25, distance_metric='cosine'):
        super().__init__()
self.temperature = temperature
self.distance_metric = distance_metric

    def compute_similarity(self, embeddings):
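        """Temperature-scaled pairwise similarity within a single embedding set."""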
if self.distance_metric == 'cosine':
embeddings = F.normalize(embeddings, p=2, dim=-1) # [B, 2T, D]
sim = torch.matmul(embeddings, embeddings.transpose(-1, -2)) # [B, 2T, 2T]
else:
raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
return sim / self.temperature

    def compute_cross_similarity(self, embeddings1, embeddings2):
"""Compute similarity between two different embedding sets"""
if self.distance_metric == 'cosine':
embeddings1 = F.normalize(embeddings1, p=2, dim=-1) # [B, 2T, D]
embeddings2 = F.normalize(embeddings2, p=2, dim=-1) # [B, 2T, D]
sim = torch.matmul(embeddings1, embeddings2.transpose(-1, -2)) # [B, 2T, 2T]
else:
raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
return sim / self.temperature

    def pairwise_and_no_diag(self, m):
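        """Pairwise boolean AND of a mask with the diagonal removed.

        Given m of shape [B, T], returns a [B, T, T] tensor whose (i, j)
        entry is True iff m[:, i] and m[:, j] are both True and i != j.
        """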
m_i = m.unsqueeze(2) # [B, T, 1]
m_j = m.unsqueeze(1) # [B, 1, T]
out = m_i & m_j # [B, T, T]
diag = torch.eye(m.size(1), dtype=torch.bool, device=m.device).unsqueeze(0)
return out & ~diag

    def forward(self, embeddings, anchors, enrollment_embeddings: Optional[Tensor] = None,
enrollment_embeddings_mask: Optional[Tensor] = None):
"""
Args:
embeddings: [B, 2T, D] - main embeddings
anchors: [B, 2T] - boolean mask indicating anchor positions
enrollment_embeddings: Optional[B, 2T, D] - enrollment embeddings for positive pairs
enrollment_embeddings_mask: Optional[B, 2T] - boolean mask for valid enrollment positions
Returns:
Scalar contrastive loss
"""
# Use enrollment embeddings if provided
if enrollment_embeddings is not None and enrollment_embeddings_mask is not None:
return self._forward_with_enrollment(embeddings, anchors, enrollment_embeddings, enrollment_embeddings_mask)
else:
# Fall back to original behavior
return self._forward_original(embeddings, anchors)

    def _forward_with_enrollment(self, embeddings, anchors, enrollment_embeddings, enrollment_embeddings_mask):
"""Forward pass using enrollment embeddings as positives"""
B, two_T, D = embeddings.shape
T = two_T // 2
# Compute similarity between main embeddings and enrollment embeddings
cross_sim = self.compute_cross_similarity(embeddings, enrollment_embeddings) # [B, 2T, 2T]
# Compute similarity within main embeddings for negatives
self_sim = self.compute_similarity(embeddings) # [B, 2T, 2T]
# Split anchor mask
m1 = anchors[:, :T] # [B, T]
m2 = anchors[:, T:] # [B, T]
# Split enrollment mask
enroll_m1 = enrollment_embeddings_mask[:, :T] # [B, T]
enroll_m2 = enrollment_embeddings_mask[:, T:] # [B, T]
# Create positive mask: anchor positions can match with corresponding enrollment positions
# First speaker (positions 0:T) matches with enrollment first speaker (positions 0:T)
pos_mask_1to1 = m1.unsqueeze(2) & enroll_m1.unsqueeze(1) # [B, T, T]
# Second speaker (positions T:2T) matches with enrollment second speaker (positions T:2T)
pos_mask_2to2 = m2.unsqueeze(2) & enroll_m2.unsqueeze(1) # [B, T, T]
# Build full positive mask
pos_mask = torch.cat([
torch.cat([pos_mask_1to1, torch.zeros_like(pos_mask_1to1)], dim=2), # [B, T, 2T]
torch.cat([torch.zeros_like(pos_mask_2to2), pos_mask_2to2], dim=2) # [B, T, 2T]
], dim=1) # [B, 2T, 2T]
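        # Resulting [B, 2T, 2T] block layout (T x T blocks):
        #   [ 1to1   0   ]
        #   [  0    2to2 ]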
# Create negative mask: cross-speaker pairs within main embeddings
cross = m1.unsqueeze(2) & m2.unsqueeze(1) # [B, T, T]
neg_mask = torch.cat([
torch.cat([torch.zeros_like(cross), cross], dim=2), # [B, T, 2T]
torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2) # [B, T, 2T]
], dim=1) # [B, 2T, 2T]
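        # Negatives occupy the off-diagonal (cross-speaker) blocks:
        #   [   0      cross ]
        #   [ cross^T    0   ]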
# Exclude self-pairs in negative mask
identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0) # [1, 2T, 2T]
neg_mask &= ~identity_mask
        # Exclude diagonal elements from the positive mask as well. In the
        # cross-similarity matrix the diagonal pairs each anchor with its own
        # enrollment frame, so this also drops same-index anchor/enrollment pairs.
pos_mask &= ~identity_mask
# Compute contrastive loss
if pos_mask.any():
# Get positive similarities from cross-similarity matrix
pos_sim = cross_sim[pos_mask] # [num_pos_pairs]
pos_exp = torch.exp(pos_sim) # [num_pos_pairs]
# Compute negative exponentials from self-similarity matrix
exp_self_sim = torch.exp(self_sim) # [B, 2T, 2T]
neg_exp_sum = torch.sum(exp_self_sim * neg_mask.float(), dim=2) # [B, 2T]
# Get the negative sums corresponding to each positive pair
pos_indices = torch.nonzero(pos_mask, as_tuple=False) # [num_pos_pairs, 3]
batch_idx = pos_indices[:, 0] # [num_pos_pairs]
row_idx = pos_indices[:, 1] # [num_pos_pairs]
# Get negative sums for each positive pair's anchor
neg_sums_for_pos = neg_exp_sum[batch_idx, row_idx] # [num_pos_pairs]
# Compute denominators: exp(pos) + sum(exp(neg)) for each positive pair
denominators = pos_exp + neg_sums_for_pos # [num_pos_pairs]
# InfoNCE loss: -log(exp(pos) / denominator)
loss = -torch.log(pos_exp / denominators)
total_loss = loss.mean()
else:
# No positive pairs found, return zero loss
total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)
return total_loss

    def _forward_original(self, embeddings, pos_indicator_mask):
"""Original forward pass for backward compatibility"""
B, two_T, D = embeddings.shape
T = two_T // 2
sim = self.compute_similarity(embeddings) # [B, 2T, 2T]
# Split input mask
m1 = pos_indicator_mask[:, :T] # [B, T]
m2 = pos_indicator_mask[:, T:] # [B, T]
# Positive mask (same speaker pairs, diagonal excluded)
pos_block1 = self.pairwise_and_no_diag(m1) # [B, T, T]
pos_block2 = self.pairwise_and_no_diag(m2) # [B, T, T]
pos_mask = torch.cat([
torch.cat([pos_block1, torch.zeros_like(pos_block1)], dim=2), # [B, T, 2T]
torch.cat([torch.zeros_like(pos_block2), pos_block2], dim=2) # [B, T, 2T]
], dim=1) # [B, 2T, 2T]
# Negative mask (cross-speaker pairs where both are active)
cross = m1.unsqueeze(2) & m2.unsqueeze(1) # [B, T, T]
neg_mask = torch.cat([
torch.cat([torch.zeros_like(cross), cross], dim=2), # [B, T, 2T]
torch.cat([cross.transpose(1, 2), torch.zeros_like(cross)], dim=2) # [B, T, 2T]
], dim=1) # [B, 2T, 2T]
# Identity mask (exclude [i, i] self-pairs)
identity_mask = torch.eye(two_T, dtype=torch.bool, device=embeddings.device).unsqueeze(0) # [1, 2T, 2T]
pos_mask &= ~identity_mask
neg_mask &= ~identity_mask
# Fully vectorized InfoNCE computation
if pos_mask.any():
            # Exponentiate all similarities once and reuse them below
            exp_sim = torch.exp(sim)  # [B, 2T, 2T]
# Get positive similarities
pos_sim = sim[pos_mask] # [num_pos_pairs]
pos_exp = torch.exp(pos_sim) # [num_pos_pairs]
            # Scaled mean of the exponentials of each position's negatives.
            # Note: this is 10 * mean over all 2T columns (masked-out entries
            # contribute zeros), i.e. (10 / 2T) * the sum, unlike the plain
            # sum used in the enrollment path.
            neg_exp_avg = 10 * torch.mean(exp_sim * neg_mask.float(), dim=2)  # [B, 2T]
            # Locate the anchor row of each positive pair
            pos_indices = torch.nonzero(pos_mask, as_tuple=False)  # [num_pos_pairs, 3]
            batch_idx = pos_indices[:, 0]  # [num_pos_pairs]
            row_idx = pos_indices[:, 1]  # [num_pos_pairs]
            # Gather the negative term for each positive pair's anchor row
            neg_avgs_for_pos = neg_exp_avg[batch_idx, row_idx]  # [num_pos_pairs]
            # Denominator per positive pair: exp(pos) + aggregated negative term
            denominators = pos_exp + neg_avgs_for_pos  # [num_pos_pairs]
# InfoNCE loss: -log(exp(pos) / denominator)
loss = -torch.log(pos_exp / denominators)
total_loss = loss.mean()
else:
# No positive pairs found, return zero loss
total_loss = torch.tensor(0.0, device=embeddings.device, requires_grad=True)
return total_loss
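

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): the shapes and
    # mask values below are illustrative assumptions, not real training data.
    torch.manual_seed(0)
    B, T, D = 2, 6, 16
    loss_fn = ContrastiveLoss(temperature=0.25)

    embeddings = torch.randn(B, 2 * T, D, requires_grad=True)
    # Mark a few active anchor frames in each speaker half
    anchors = torch.zeros(B, 2 * T, dtype=torch.bool)
    anchors[:, :3] = True       # speaker 1 frames
    anchors[:, T:T + 3] = True  # speaker 2 frames

    # Original path: positives are same-speaker pairs within `embeddings`
    loss = loss_fn(embeddings, anchors)
    loss.backward()
    print(f"self-supervised loss: {loss.item():.4f}")

    # Enrollment path: positives pair anchors with enrollment embeddings
    enrollment = torch.randn(B, 2 * T, D)
    loss_enroll = loss_fn(embeddings, anchors, enrollment, anchors.clone())
    print(f"enrollment loss: {loss_enroll.item():.4f}")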