import torch
from torch.nn.functional import cross_entropy, softmax

from .configuration_bacformer import SPECIAL_TOKENS_DICT


def compute_contrastive_loss(
    protein_embeddings: torch.Tensor,
    last_hidden_state: torch.Tensor,
    special_tokens_mask: torch.Tensor,
) -> torch.Tensor:
    """Compute contrastive loss between protein embeddings and masked items."""
    # keep only protein embedding and masked items
    # ensure the batch size is 1, the model currently does not work with batch size > 1
    assert protein_embeddings.shape[0] == last_hidden_state.shape[0] == 1

    # subset to mask and protein embedding tokens
    special_tokens_mask = special_tokens_mask.squeeze(0)
    mask = (special_tokens_mask == SPECIAL_TOKENS_DICT["PROT_EMB"]) | (
        special_tokens_mask == SPECIAL_TOKENS_DICT["MASK"]
    )
    protein_embeddings = protein_embeddings.squeeze(0)[mask]
    last_hidden_state = last_hidden_state.squeeze(0)[mask]

    # normalize embeddings so the dot products below are cosine similarities
    last_hidden_state = last_hidden_state / last_hidden_state.norm(dim=1, keepdim=True)
    protein_embeddings = protein_embeddings / protein_embeddings.norm(dim=1, keepdim=True)

    # compute the similarity matrix; matching pairs lie on the diagonal
    similarity_matrix = torch.matmul(last_hidden_state, protein_embeddings.T)
    n_prots = protein_embeddings.shape[0]
    labels = torch.arange(n_prots).to(protein_embeddings.device)

    # cross-entropy over the similarity rows pulls each hidden state towards
    # its own protein embedding and away from the others
    loss = cross_entropy(similarity_matrix, labels)
    return loss


def top_k_filtering(logits: torch.Tensor, top_k: int = 50) -> torch.Tensor:
    """
    Keep only the top_k logits and set the rest to -inf.

    Args:
        logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
        top_k (int): The number of highest probability logits to keep.

    Returns
    -------
    torch.Tensor: Filtered logits where only the top k values remain, and all others are -inf.
    """
    if top_k <= 0:
        return logits

    # find the top_k values per row
    top_k = min(top_k, logits.size(-1))
    vals, idx = torch.topk(logits, top_k, dim=-1)

    # get the smallest logit within the top_k of each row
    min_vals = vals[:, -1].unsqueeze(-1)

    # mask (in place) all logits that are < this minimum value
    mask = logits < min_vals
    logits[mask] = float("-inf")
    return logits


def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """
    Keep the smallest set of logits whose cumulative probability >= top_p.

    Args:
        logits (torch.Tensor): Logits of shape (batch_size, vocab_size).
        top_p (float): Cumulative probability threshold.

    Returns
    -------
    torch.Tensor: Filtered logits where only tokens within the top_p cumulative
        probability mass are kept; the rest are set to -inf.
    """
    if top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)

    # identify where cumulative probability exceeds top_p
    sorted_indices_to_remove = cumulative_probs > top_p

    # shift the mask right so that at least one token is always kept
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # scatter the mask back into the original (unsorted) ordering
    for i in range(logits.size(0)):
        remove_indices = sorted_indices[i, sorted_indices_to_remove[i]]
        logits[i, remove_indices] = float("-inf")
    return logits


def create_4d_from_2d_attn_mask(attn_mask: torch.Tensor, num_attn_heads: int) -> torch.Tensor:
    """Helper function to reshape a 2D attn_mask to the 4D shape expected by the attention layers."""
    assert (
        len(attn_mask.shape) == 2
    ), f"Please provide attn_mask of shape (batch_size, seq_len), current shape {attn_mask.shape}"
    bs, seq_len = attn_mask.shape
    # (batch_size, seq_len) -> (batch_size, num_attn_heads, 1, seq_len)
    attn_mask = attn_mask.view(bs, 1, 1, seq_len)
    attn_mask = attn_mask.expand(-1, num_attn_heads, -1, -1)
    attn_mask = attn_mask.view(bs, num_attn_heads, -1, seq_len)
    return attn_mask