import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Literal, Optional, Union

import torch
from torch import nn
from torch.nn.functional import (
    binary_cross_entropy_with_logits,
    cross_entropy,
    gelu,
    mse_loss,
    scaled_dot_product_attention,
    softmax,
)
from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig
from .utils_bacformer import compute_contrastive_loss, create_4d_from_2d_attn_mask, top_k_filtering, top_p_filtering


@dataclass
class BacformerModelOutput(ModelOutput):
    """Base class for outputs of the Bacformer model."""

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    attentions: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None


# Taken from facebookresearch/llama/model.py
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape the rotary embeddings for broadcasting."""
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


# Taken from facebookresearch/llama/model.py
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to the query and key tensors."""
    # reshape xq and xk to match the complex representation
    xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
    xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)

    # reshape freqs_cos and freqs_sin for broadcasting
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # flatten last two dimensions
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


# Taken from facebookresearch/llama/model.py
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the freqs cis for rotary embeddings."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin
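

# --- Illustrative usage sketch (not part of the original model code) ---
# The rotary helpers above operate on query/key tensors shaped
# (batch, seq_len, num_heads, head_dim); the table returned by
# precompute_freqs_cis has shape (end, head_dim // 2) and is sliced to the
# current sequence length before being applied. The sizes below are
# assumptions chosen purely for demonstration.
def _example_rotary_embeddings() -> None:
    batch, seq_len, num_heads, head_dim = 2, 16, 4, 32
    xq = torch.randn(batch, seq_len, num_heads, head_dim)
    xk = torch.randn(batch, seq_len, num_heads, head_dim)
    freqs_cos, freqs_sin = precompute_freqs_cis(dim=head_dim, end=64)
    # slice the precomputed table to the current sequence length
    xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos[:seq_len, :], freqs_sin[:seq_len, :])
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape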


def scaled_dot_product_attention_w_attn_weights(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
) -> tuple[torch.Tensor, torch.Tensor]:
    """PyTorch-native implementation of scaled dot-product attention, modified to also return the attention weights."""
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # boolean mask: True marks positions that may be attended to
            mask_bias = torch.zeros_like(attn_mask, dtype=query.dtype)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            attn_bias += attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    attn_output = attn_weight @ value
    return attn_output, attn_weight


class RotarySelfAttention(nn.Module):
    """Rotary self-attention module."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_head = embed_dim // num_heads
        self.dropout_rate = dropout
        self.q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.att_proj_linear = nn.Linear(embed_dim, embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        is_causal: bool = False,
        return_attn_weights: bool = False,
    ):
        """Forward pass for the rotary self-attention module."""
        batch_size, seq_len, _ = x.shape
        xq, xk, xv = self.q(x), self.k(x), self.v(x)

        # Reshape for rotary embeddings
        xq = xq.view(batch_size, seq_len, self.num_heads, self.dim_head)
        xk = xk.view(batch_size, seq_len, self.num_heads, self.dim_head)
        xv = xv.view(batch_size, seq_len, self.num_heads, self.dim_head)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # Reshape for attention calculation: (b_sz, n_head, s_len, d_head)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        attn_weights = None
        if return_attn_weights:
            att, attn_weights = scaled_dot_product_attention_w_attn_weights(
                query=xq,
                key=xk,
                value=xv,
                attn_mask=attn_mask,
                dropout_p=self.dropout_rate if self.training else 0.0,
                is_causal=is_causal,
            )
        else:
            att = scaled_dot_product_attention(
                query=xq,
                key=xk,
                value=xv,
                attn_mask=attn_mask,
                dropout_p=self.dropout_rate if self.training else 0.0,
                is_causal=is_causal,
            )

        # Shape (b_sz, s_len, n_head, d_head)
        out = att.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, self.num_heads * self.dim_head)

        return self.att_proj_linear(out), attn_weights


class BacformerTransformerLayer(nn.Module):
    """Own implementation of a transformer layer which uses PyTorch-native attention but can also return attention weights"""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        dropout: float = 0.1,
        activation: Literal["gelu", "relu"] = "gelu",
    ):
        super().__init__()
        self.self_mha = RotarySelfAttention(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            dropout=dropout,
        )
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.activation = nn.GELU() if activation == "gelu" else nn.ReLU()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor = None,
        freqs_cos: torch.Tensor = None,
        freqs_sin: torch.Tensor = None,
        return_attn_weights: bool = False,
        is_causal: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Forward pass"""
        attn_outputs, attn_weights = self.self_mha(
            hidden_state,
            attn_mask=attention_mask,
            freqs_cos=freqs_cos,
            freqs_sin=freqs_sin,
            return_attn_weights=return_attn_weights,
            is_causal=is_causal,
        )
        x = self.norm1(hidden_state + self.dropout1(attn_outputs))
        ff_output = self.fc2(self.dropout2(self.activation(self.fc1(x))))
        x = self.norm2(x + self.dropout3(ff_output))
        return x, attn_weights
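

# --- Illustrative usage sketch (not part of the original model code) ---
# A minimal forward pass through a single BacformerTransformerLayer. All
# sizes below are assumptions for demonstration; the real model takes them
# from BacformerConfig.
def _example_transformer_layer() -> None:
    hidden_size, num_heads, seq_len = 64, 4, 10
    layer = BacformerTransformerLayer(
        hidden_size=hidden_size,
        intermediate_size=4 * hidden_size,
        num_attention_heads=num_heads,
    )
    freqs_cos, freqs_sin = precompute_freqs_cis(hidden_size // num_heads, seq_len)
    hidden_state = torch.randn(1, seq_len, hidden_size)
    out, attn_weights = layer(
        hidden_state=hidden_state,
        attention_mask=None,
        freqs_cos=freqs_cos,
        freqs_sin=freqs_sin,
        return_attn_weights=True,
    )
    assert out.shape == (1, seq_len, hidden_size)
    # per-layer attention weights: (batch, num_heads, seq_len, seq_len)
    assert attn_weights.shape == (1, num_heads, seq_len, seq_len)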


class BacformerTransformerEncoder(nn.Module):
    """Own implementation of a Transformer encoder which can return attention weights"""

    def __init__(
        self,
        num_hidden_layers: int,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        dropout: float = 0.1,
        activation: Literal["gelu", "relu"] = "gelu",
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                BacformerTransformerLayer(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    num_attention_heads=num_attention_heads,
                    dropout=dropout,
                    activation=activation,
                )
                for _ in range(num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor = None,
        freqs_cos: torch.Tensor = None,
        freqs_sin: torch.Tensor = None,
        return_attn_weights: bool = False,
        is_causal: bool = False,
    ) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
        """Forward pass"""
        attn_weights_arr = []
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_state, attn_weights = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_state,
                    attention_mask,
                    freqs_cos,
                    freqs_sin,
                    return_attn_weights,
                    is_causal,
                )
            else:
                hidden_state, attn_weights = layer(
                    hidden_state=hidden_state,
                    attention_mask=attention_mask,
                    freqs_cos=freqs_cos,
                    freqs_sin=freqs_sin,
                    return_attn_weights=return_attn_weights,
                    is_causal=is_causal,
                )
            # keep the attention weights from each layer
            attn_weights_arr.append(attn_weights)
        return hidden_state, attn_weights_arr


class BacformerEmbeddings(nn.Module):
    """Construct the protein embeddings from protein sequence, position embeddings and sequence type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.linear = nn.Linear(config.hidden_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            num_embeddings=config.max_token_type_embeddings + 1,
            embedding_dim=config.hidden_size,
            padding_idx=config.max_token_type_embeddings,
        )
        self.special_tokens_embeddings = nn.Embedding(
            num_embeddings=config.num_special_tokens,
            embedding_dim=config.hidden_size,
        )
        self.prot_emb_token_id = config.prot_emb_token_id
        self.pad_token_id = config.pad_token_id

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        protein_embeddings: torch.Tensor = None,
        special_tokens_mask: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        labels: torch.Tensor = None,  # used for causal protein family modeling
        property_ids: torch.Tensor = None,  # used for conditional fine-tuning for a desired property
    ) -> torch.Tensor:
        """Forward pass for protein embeddings."""
        bs, seq_length, dim = protein_embeddings.shape
        # pass the pooled ESM protein embeddings through a linear layer
        protein_embeddings = self.linear(protein_embeddings)
        # keep the projected protein embedding at protein positions and use the learned
        # special-token embeddings everywhere else
        protein_embeddings = torch.where(
            special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
            protein_embeddings,
            self.special_tokens_embeddings(special_tokens_mask),
        )

        if token_type_ids is not None:
            protein_embeddings += self.token_type_embeddings(token_type_ids)

        protein_embeddings = self.LayerNorm(protein_embeddings)
        protein_embeddings = self.dropout(protein_embeddings)
        return protein_embeddings
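

# --- Illustrative usage sketch (not part of the original model code) ---
# BacformerEmbeddings expects pooled per-protein embeddings of shape
# (batch, seq_len, hidden_size) plus a special_tokens_mask of shape
# (batch, seq_len), where positions holding real protein embeddings carry
# config.prot_emb_token_id and all other positions carry special token ids
# (e.g. CLS). The concrete sizes and mask layout below are assumptions.
def _example_bacformer_embeddings(config: BacformerConfig) -> torch.Tensor:
    embeddings = BacformerEmbeddings(config)
    bs, seq_len = 1, 5
    protein_embeddings = torch.randn(bs, seq_len, config.hidden_size)
    # position 0 is a CLS-like special token, the rest are protein embeddings
    special_tokens_mask = torch.full((bs, seq_len), config.prot_emb_token_id, dtype=torch.long)
    special_tokens_mask[:, 0] = SPECIAL_TOKENS_DICT["CLS"]
    token_type_ids = torch.zeros(bs, seq_len, dtype=torch.long)
    return embeddings(
        protein_embeddings=protein_embeddings,
        special_tokens_mask=special_tokens_mask,
        token_type_ids=token_type_ids,
    )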


class BacformerProteinFamilyEmbeddings(nn.Module):
    """Construct the protein embeddings from protein family tokens, special tokens and sequence type embeddings."""

    def __init__(
        self,
        config,
        protein_family_embeddings: torch.Tensor = None,
        token_type_embeddings: torch.Tensor = None,
        special_tokens_embeddings: torch.Tensor = None,
        n_conditional_properties: int = None,
    ):
        super().__init__()
        self.config = config

        if protein_family_embeddings is not None:
            self.protein_family_embeddings = nn.Embedding.from_pretrained(
                protein_family_embeddings,
                freeze=False,
                padding_idx=config.pad_token_id,
            )
        else:
            self.protein_family_embeddings = nn.Embedding(
                num_embeddings=config.protein_clusters_vocab_size + 1,
                embedding_dim=config.hidden_size,
                padding_idx=config.pad_token_id,
            )

        if token_type_embeddings is not None:
            self.token_type_embeddings = nn.Embedding.from_pretrained(
                token_type_embeddings,
                freeze=False,
                padding_idx=config.max_token_type_embeddings,
            )
        else:
            self.token_type_embeddings = nn.Embedding(
                num_embeddings=config.max_token_type_embeddings + 1,
                embedding_dim=config.hidden_size,
                padding_idx=config.max_token_type_embeddings,
            )

        if special_tokens_embeddings is not None:
            self.special_tokens_embeddings = nn.Embedding.from_pretrained(
                special_tokens_embeddings,
                freeze=False,
                padding_idx=config.pad_token_id,
            )
        else:
            self.special_tokens_embeddings = nn.Embedding(
                num_embeddings=config.num_special_tokens,
                embedding_dim=config.hidden_size,
                padding_idx=config.pad_token_id,
            )

        # add a layer for conditional properties
        if n_conditional_properties is not None:
            self.conditional_properties_layer = nn.Embedding(n_conditional_properties, config.hidden_size)

        self.prot_emb_token_id = config.prot_emb_token_id
        self.pad_token_id = config.pad_token_id

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        protein_embeddings: torch.Tensor = None,
        special_tokens_mask: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        labels: torch.Tensor = None,  # used for causal protein family modeling
        property_ids: torch.Tensor = None,  # used for conditional fine-tuning for a desired property
    ) -> torch.Tensor:
        """Forward pass for protein family embeddings."""
        # look up the protein family embeddings from the labels;
        # replace -100 with pad_token_id first
        labels[labels == -100] = self.pad_token_id
        protein_embeddings = self.protein_family_embeddings(labels)
        bs, seq_length, dim = protein_embeddings.shape

        protein_embeddings = torch.where(
            special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
            protein_embeddings,
            self.special_tokens_embeddings(special_tokens_mask),
        )

        if token_type_ids is not None:
            protein_embeddings += self.token_type_embeddings(token_type_ids)

        if property_ids is not None:
            # get the embeddings for the conditional properties
            property_embedding = self.conditional_properties_layer(property_ids).unsqueeze(1)
            # concatenate the protein embeddings with the conditional property embeddings;
            # the property embedding is inserted right after the CLS token
            protein_embeddings = torch.cat(
                [
                    protein_embeddings[:, :1, :],  # CLS token
                    property_embedding,  # conditional property embedding
                    protein_embeddings[:, 1:, :],  # protein embeddings
                ],
                dim=1,
            )

        protein_embeddings = self.LayerNorm(protein_embeddings)
        protein_embeddings = self.dropout(protein_embeddings)
        return protein_embeddings


class BacformerEncoder(nn.Module):
    """Bacformer encoder model"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = BacformerTransformerEncoder(
            num_hidden_layers=config.num_hidden_layers,
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            activation="gelu",
            dropout=config.attention_probs_dropout_prob,
        )

        # Note that config.max_position_embeddings is multiplied by 1.5 because the token limit for the
        # Bacformer family of models is 6000. Using this multiplier instead of hard-coding 6000 keeps the
        # rotary table size flexible with respect to token lengths during training or fine-tuning.
        freqs_cos, freqs_sin = precompute_freqs_cis(
            config.hidden_size // config.num_attention_heads, int(config.max_position_embeddings * 1.5)
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        return_attn_weights: Union[bool, None] = None,
        is_causal: bool = False,
    ) -> tuple[torch.Tensor, list[torch.Tensor | None]]:
        """Pass the input through the encoder layers in turn.

        Args:
            hidden_states: hidden states from the BacformerEmbeddings layer
            attention_mask: mask for the attention in the transformer
        """
        return_attn_weights = (
            return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
        )

        bs, seq_len, _ = hidden_states.shape
        last_hidden_state, attn_weights = self.encoder(
            hidden_state=hidden_states,
            attention_mask=attention_mask,
            freqs_cos=self.freqs_cos[:seq_len, :],
            freqs_sin=self.freqs_sin[:seq_len, :],
            return_attn_weights=return_attn_weights,
            is_causal=is_causal,
        )
        return last_hidden_state, attn_weights


class BacformerPreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models."""

    config_class = BacformerConfig
    base_model_prefix = "bacformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BacformerEmbeddings", "BacformerTransformerLayer"]

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class BacformerModel(BacformerPreTrainedModel):
    """Bacformer model."""

    def __init__(self, config: BacformerConfig, add_pooling_layer: bool = False):
        super().__init__(config)
        self.config = config

        self.embeddings = BacformerEmbeddings(config)
        self.encoder = BacformerEncoder(config)
        self.pooler = BacformerPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        protein_embeddings: torch.Tensor = None,
        special_tokens_mask: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        property_ids: torch.Tensor = None,
        return_attn_weights: bool = False,
        return_dict: Union[bool, None] = None,
        is_causal: bool = False,
    ) -> Optional[BacformerModelOutput]:
        """Forward method for the model."""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # get embeddings
        protein_embeddings = self.embeddings(
            protein_embeddings=protein_embeddings,
            labels=labels,
            special_tokens_mask=special_tokens_mask,
            token_type_ids=token_type_ids,
            property_ids=property_ids,
        )

        # create a 4D attention mask from the 2D mask if not doing causal GM
        if attention_mask is not None and not is_causal:
            attention_mask = create_4d_from_2d_attn_mask(
                attn_mask=attention_mask, num_attn_heads=self.config.num_attention_heads
            ).bool()

        last_hidden_state, attentions = self.encoder(
            hidden_states=protein_embeddings,
            attention_mask=attention_mask,
            return_attn_weights=return_attn_weights,
            is_causal=is_causal,
        )
        pooler_output = (
            self.pooler(hidden_states=last_hidden_state, padding_mask=attention_mask)
            if self.pooler is not None
            else None
        )

        if not return_dict:
            return (last_hidden_state, pooler_output, attentions)

        return BacformerModelOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
            attentions=attentions,
        )


class BacformerForCausalGM(BacformerPreTrainedModel):
    """Bacformer model with a causal genomic modeling head on top"""

    _tied_weights_keys = ["gm_head.decoder.weight"]

    def __init__(self, config: BacformerConfig):
        super().__init__(config)
        self.config = config

        self.bacformer = BacformerModel(config, add_pooling_layer=False)
        self.gm_head = BacformerGMHead(config)

        # Initialize weights
        self.init_weights()

    def forward(
        self,
        protein_embeddings: torch.Tensor,
        special_tokens_mask: torch.Tensor,
        labels: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        return_attn_weights: bool = None,
        return_dict: Union[bool, None] = None,
    ) -> Optional[BacformerModelOutput]:
        """Forward method for the model."""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        return_attn_weights = (
            return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
        )

        outputs = self.bacformer(
            protein_embeddings=protein_embeddings,
            special_tokens_mask=special_tokens_mask,
            token_type_ids=token_type_ids,
            attention_mask=None,  # the attention mechanism handles the causal mask
            return_attn_weights=return_attn_weights,
            return_dict=return_dict,
            is_causal=True,
        )
        last_hidden_state = outputs[0]
        prediction_scores = self.gm_head(last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(prediction_scores.device)
            # shift so that each position predicts the next protein
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1])
            labels = labels[:, 1:].contiguous().view(-1)
            loss = cross_entropy(shifted_prediction_scores, labels)

        if not return_dict:
            return (
                loss,
                prediction_scores,
            ) + outputs

        return BacformerModelOutput(
            loss=loss,
            logits=prediction_scores,
            last_hidden_state=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )
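

# --- Illustrative usage sketch (not part of the original model code) ---
# End-to-end pass through the base BacformerModel: pooled per-protein
# embeddings (e.g. mean-pooled ESM representations) go in, contextualised
# protein embeddings come out. Batch/sequence sizes and the mask layout
# below are assumptions for demonstration.
def _example_bacformer_model(config: BacformerConfig) -> torch.Tensor:
    model = BacformerModel(config, add_pooling_layer=False)
    bs, seq_len = 1, 8
    protein_embeddings = torch.randn(bs, seq_len, config.hidden_size)
    special_tokens_mask = torch.full((bs, seq_len), config.prot_emb_token_id, dtype=torch.long)
    special_tokens_mask[:, 0] = SPECIAL_TOKENS_DICT["CLS"]
    attention_mask = torch.ones(bs, seq_len, dtype=torch.long)
    outputs = model(
        protein_embeddings=protein_embeddings,
        special_tokens_mask=special_tokens_mask,
        attention_mask=attention_mask,
        return_dict=True,
    )
    # contextualised protein embeddings: (bs, seq_len, hidden_size)
    return outputs.last_hidden_state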


class BacformerForMaskedGM(BacformerPreTrainedModel):
    """Bacformer model with a masked genomic modeling head on top"""

    _tied_weights_keys = ["gm_head.decoder.weight"]

    def __init__(self, config: BacformerConfig):
        super().__init__(config)
        self.config = config

        self.bacformer = BacformerModel(config, add_pooling_layer=False)
        self.gm_head = BacformerGMHead(config)

        # Initialize weights
        self.init_weights()

    def forward(
        self,
        protein_embeddings: torch.Tensor,
        special_tokens_mask: torch.Tensor,
        labels: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        return_attn_weights: bool = None,
        return_dict: Union[bool, None] = None,
    ) -> Union[BacformerModelOutput, None]:
        """Forward method for the model."""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        return_attn_weights = (
            return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
        )

        outputs = self.bacformer(
            protein_embeddings=protein_embeddings,
            special_tokens_mask=special_tokens_mask,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_attn_weights=return_attn_weights,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        loss = None
        if labels is not None:
            # to speed up the forward pass, only score the masked positions;
            # labels are -100 everywhere except at masked protein positions
            last_hidden_state = last_hidden_state[labels != -100]
            prediction_scores = self.gm_head(last_hidden_state)
            labels = labels.to(prediction_scores.device)
            # only consider the masked tokens
            labels = labels[labels != -100]
            loss = cross_entropy(prediction_scores, labels)
        else:
            prediction_scores = self.gm_head(last_hidden_state)

        if not return_dict:
            return (
                loss,
                prediction_scores,
            ) + outputs

        return BacformerModelOutput(
            loss=loss,
            logits=prediction_scores,
            last_hidden_state=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )
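

# --- Illustrative sketch of the masked-GM label convention (assumption,
# inferred from the loss above, not taken from the original training code) ---
# labels hold a protein-cluster id at masked positions and -100 everywhere
# else, so the head is only evaluated on the masked tokens.
def _example_masked_gm_labels(seq_len: int = 8, mask_prob: float = 0.15) -> tuple[torch.Tensor, torch.Tensor]:
    protein_cluster_ids = torch.randint(0, 100, (1, seq_len))  # hypothetical cluster ids
    masked_positions = torch.rand(1, seq_len) < mask_prob
    labels = torch.full((1, seq_len), -100, dtype=torch.long)
    labels[masked_positions] = protein_cluster_ids[masked_positions]
    return masked_positions, labels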


class BacformerForCausalProteinFamilyModeling(BacformerPreTrainedModel):
    """Bacformer model for causal modeling of protein families.

    Uses protein family tokens rather than protein embeddings as inputs.
    """

    _tied_weights_keys = ["gm_head.decoder.weight"]

    def __init__(
        self,
        config: BacformerConfig,
        n_conditional_properties: int = None,
        initialise_from_non_pfm_model: bool = False,
    ):
        super().__init__(config)
        self.config = config
        self.cls_token_id = SPECIAL_TOKENS_DICT["CLS"]

        self.bacformer = BacformerModel(config, add_pooling_layer=False)
        self.gm_head = BacformerGMHead(config)

        if initialise_from_non_pfm_model:
            # Initialize weights
            self.init_weights()
            # overwrite the embeddings with the pretrained
            # protein family embeddings from the decoder of the GM head
            self.bacformer.embeddings = BacformerProteinFamilyEmbeddings(
                config,
                protein_family_embeddings=self.gm_head.decoder.weight,
                token_type_embeddings=self.bacformer.embeddings.token_type_embeddings.weight,
                special_tokens_embeddings=self.bacformer.embeddings.special_tokens_embeddings.weight,
                n_conditional_properties=n_conditional_properties,
            )
        else:
            self.bacformer.embeddings = BacformerProteinFamilyEmbeddings(
                config,
                n_conditional_properties=n_conditional_properties,
            )
            self.init_weights()

    def forward(
        self,
        labels: torch.Tensor = None,
        special_tokens_mask: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        property_ids: torch.Tensor = None,
        return_attn_weights: bool = None,
        return_dict: Union[bool, None] = None,
    ) -> Optional[BacformerModelOutput]:
        """Forward method for the model."""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        return_attn_weights = (
            return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
        )

        outputs = self.bacformer(
            protein_embeddings=None,
            labels=labels,
            special_tokens_mask=special_tokens_mask,
            token_type_ids=token_type_ids,
            property_ids=property_ids,
            return_attn_weights=return_attn_weights,
            return_dict=return_dict,
            is_causal=True,
        )
        last_hidden_state = outputs[0]
        prediction_scores = self.gm_head(last_hidden_state)

        loss = None
        if labels is not None:
            if property_ids is not None:
                # account for the property token with an ignore-index label
                labels = torch.cat(
                    [
                        torch.tensor([-100], dtype=torch.long).unsqueeze(0).to(labels.device),
                        labels,
                    ],
                    dim=1,
                )
            labels = labels.to(prediction_scores.device)
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1])
            labels = labels[:, 1:].contiguous().view(-1)
            loss = cross_entropy(shifted_prediction_scores, labels)

        if not return_dict:
            return (
                loss,
                prediction_scores,
            ) + outputs

        return BacformerModelOutput(
            loss=loss,
            logits=prediction_scores,
            last_hidden_state=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )

    def generate(
        self,
        protein_family_ids: torch.LongTensor,
        special_tokens_mask: torch.LongTensor = None,
        token_type_ids: torch.LongTensor = None,
        max_length: int = 6000,
        end_token_id: int = 50000,
        do_sample: bool = False,
        top_k: int = 50,
        top_p: float = 1.0,
        temperature: float = 1.0,
        property_ids: torch.LongTensor = None,
        return_last_hidden_states: bool = False,
    ):
        """Generate a sequence of tokens autoregressively from a given prompt.

        Args:
            protein_family_ids (torch.LongTensor): Tensor of shape (batch, seq_len) with token indices.
            max_length (int): Maximum length of the generated sequence (prompt + newly generated).
            end_token_id (int, optional): Token ID signifying end-of-sequence (END). If encountered, generation stops.
            do_sample (bool): Whether to sample from the probability distribution (True) or use greedy decoding (False).
            top_k (int): If >0, use top-k filtering in sampling mode.
            top_p (float): If <1.0, use nucleus (top-p) filtering in sampling mode.
            temperature (float): Softmax temperature for scaling logits. Higher => more random, lower => more deterministic.
            return_last_hidden_states (bool): If True, return the final hidden states as well.

        Returns
        -------
            torch.LongTensor: The generated token sequence of shape (batch, final_seq_len).
            (Optional) torch.FloatTensor: Final hidden states of shape (batch, final_seq_len, hidden_dim)
                if `return_last_hidden_states=True`.
        """
        # Default END token
        if end_token_id is None:
            end_token_id = getattr(self, "end_token_id", None)

        # Switch to eval mode and move the input to the correct device
        self.eval()
        device = next(self.parameters()).device
        protein_family_ids = protein_family_ids.to(device)

        # create a special tokens mask if not provided
        if special_tokens_mask is None:
            # add a CLS token at the beginning
            protein_family_ids = torch.cat(
                [torch.tensor([[-100]]).to(device), protein_family_ids],
                dim=1,
            )
            special_tokens_mask = [self.cls_token_id] + [self.config.prot_emb_token_id] * (
                protein_family_ids.shape[1] - 1
            )
            # shape (1, seq_len) to match protein_family_ids
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.long).unsqueeze(0).to(device)

        # create a token type mask if not provided
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(protein_family_ids)

        # Prepare the initial sequence and define max new tokens
        generated = protein_family_ids.clone()
        batch_size, prompt_length = generated.shape
        max_new_tokens = max_length - prompt_length
        if max_new_tokens <= 0:
            max_new_tokens = 0

        # Disable gradient calculations for generation
        with torch.no_grad():
            for _step in range(max_new_tokens):
                # Forward pass
                logits = self.forward(
                    labels=generated,
                    special_tokens_mask=special_tokens_mask,
                    # assume it's all on one chromosome
                    token_type_ids=token_type_ids,
                    property_ids=property_ids,
                    return_dict=True,
                ).logits

                # Focus on the last token's logits
                next_token_logits = logits[:, -1, :]  # (batch_size, vocab_size)

                # Apply temperature
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Sampling or greedy decoding?
                if do_sample:
                    # Top-k filter
                    next_token_logits = top_k_filtering(next_token_logits, top_k=top_k)
                    # Top-p filter
                    next_token_logits = top_p_filtering(next_token_logits, top_p=top_p)
                    probs = softmax(next_token_logits, dim=-1)
                    next_token_id = torch.multinomial(probs, num_samples=1)
                else:
                    # Greedy decoding
                    next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # Append the predicted token
                generated = torch.cat([generated, next_token_id], dim=1)
                special_tokens_mask = torch.cat(
                    [special_tokens_mask, torch.tensor([[self.config.prot_emb_token_id]]).to(generated.device)], dim=1
                )
                last_token_type_id = token_type_ids[:, -1].unsqueeze(1)
                token_type_ids = torch.cat([token_type_ids, last_token_type_id], dim=1)

                # Check for END in all sequences
                if end_token_id is not None:
                    if (next_token_id.squeeze(1) == end_token_id).all():
                        # If every sequence ended, break early
                        break

        if not return_last_hidden_states:
            return generated

        # Optionally compute the final hidden states
        last_hidden_state = self.forward(
            labels=generated,
            special_tokens_mask=special_tokens_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        ).last_hidden_state
        return generated, last_hidden_state
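

# --- Illustrative usage sketch (not part of the original model code) ---
# Autoregressive generation from a short prompt of protein family ids using
# nucleus sampling. The prompt values and sampling settings are assumptions;
# `end_token_id` keeps its default of 50000 from the signature above.
def _example_generate(model: BacformerForCausalProteinFamilyModeling) -> torch.LongTensor:
    prompt = torch.randint(0, 100, (1, 10))  # hypothetical protein family ids
    generated = model.generate(
        protein_family_ids=prompt,
        max_length=64,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature=1.0,
    )
    return generated  # (1, <= 64), including the prepended CLS position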


class BacformerForMaskedGMWithContrastiveLoss(BacformerPreTrainedModel):
    """Bacformer model with a masked genomic modeling head and an auxiliary contrastive loss"""

    _tied_weights_keys = ["gm_head.decoder.weight"]

    def __init__(self, config: BacformerConfig):
        super().__init__(config)
        self.config = config

        self.bacformer = BacformerModel(config, add_pooling_layer=False)
        self.gm_head = BacformerGMHead(config)

        # Initialize weights
        self.init_weights()

    def forward(
        self,
        protein_embeddings: torch.Tensor,
        special_tokens_mask: torch.Tensor,
        labels: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        return_attn_weights: bool = None,
        return_dict: Union[bool, None] = None,
    ) -> Union[BacformerModelOutput, None]:
        """Forward method for the model."""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        return_attn_weights = (
            return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
        )

        outputs = self.bacformer(
            protein_embeddings=protein_embeddings,
            special_tokens_mask=special_tokens_mask,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            return_attn_weights=return_attn_weights,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        loss = None
        if labels is not None:
            # contrastive loss between the input protein embeddings and the contextualised outputs
            contrastive_loss = compute_contrastive_loss(protein_embeddings, last_hidden_state, special_tokens_mask)

            # to speed up the forward pass, only score the masked positions
            last_hidden_state = last_hidden_state[labels != -100]
            prediction_scores = self.gm_head(last_hidden_state)
            labels = labels.to(prediction_scores.device)
            # only consider the masked tokens
            labels = labels[labels != -100]
            masked_loss = cross_entropy(prediction_scores, labels)

            loss = masked_loss + self.config.alpha_contrastive_loss * contrastive_loss
        else:
            prediction_scores = self.gm_head(last_hidden_state)

        if not return_dict:
            return (
                loss,
                prediction_scores,
            ) + outputs

        return BacformerModelOutput(
            loss=loss,
            logits=prediction_scores,
            last_hidden_state=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )


class BacformerForProteinClassification(BacformerPreTrainedModel):
    """Bacformer model with a classification head on top for protein classification tasks."""

    def __init__(self, config: BacformerConfig, benchmark_esm: bool = False):
        super().__init__(config)
        self.config = config
        self.benchmark_esm = benchmark_esm

        self.bacformer = BacformerModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        protein_embeddings: torch.Tensor,
        special_tokens_mask: torch.Tensor,
        labels: torch.Tensor = None,
        token_type_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        return_attn_weights: bool = None,
        return_dict: Union[bool, None] = None,
    ) -> Optional[BacformerModelOutput]:
        """Forward method for the model."""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        return_attn_weights = (
            return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights
        )

        if self.benchmark_esm:
            outputs = [protein_embeddings]
        else:
            outputs = self.bacformer(
                protein_embeddings=protein_embeddings,
                special_tokens_mask=special_tokens_mask,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                return_attn_weights=return_attn_weights,
                return_dict=return_dict,
            )
        last_hidden_state = outputs[0]
        last_hidden_state = self.dropout(last_hidden_state)
        logits = self.classifier(last_hidden_state)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type == "regression":
                loss = mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif (
                self.config.problem_type == "multi_label_classification"
                or self.config.problem_type == "binary_classification"
            ):
                # remove the -100 labels from loss computation
                mask = torch.ones_like(labels.view(-1)) - (labels.view(-1) == -100.0).float()
                loss = binary_cross_entropy_with_logits(
                    logits.view(-1), labels.view(-1).type_as(logits), reduction="none"
                )
                loss = (loss * mask).sum() / mask.sum()

        if not return_dict:
            return (
                loss,
                None,
                logits,
            )  # + outputs

        return BacformerModelOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=last_hidden_state,
            attentions=outputs.attentions,
        )
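

# --- Illustrative sketch (assumption, not from the original training code) ---
# For the "binary_classification" / "multi_label_classification" problem types
# above, per-protein labels are floats (0.0 / 1.0) with -100.0 marking
# positions (e.g. special tokens) that are excluded from the BCE loss mask.
def _example_protein_classification_labels(seq_len: int = 6) -> torch.Tensor:
    labels = torch.randint(0, 2, (1, seq_len)).float()
    labels[:, 0] = -100.0  # e.g. a CLS position, ignored by the loss mask
    return labels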
self.config.problem_type == "regression": loss = mse_loss(logits.view(-1), labels.view(-1)) elif self.config.problem_type == "binary_classification": loss = binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) elif self.config.problem_type == "single_label_classification": loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss = binary_cross_entropy_with_logits(logits, labels) if not return_dict: return ( loss, None, logits, ) return BacformerModelOutput( loss=loss, logits=logits, last_hidden_state=outputs.last_hidden_state, attentions=outputs.attentions, ) class BacformerForProteinProteinInteraction(BacformerPreTrainedModel): """Bacformer model with a protein-protein interaction head on top.""" def __init__(self, config: BacformerConfig, benchmark_esm: bool = False): super().__init__(config) self.config = config self.benchmark_esm = benchmark_esm print("Benchmark ESM:", self.benchmark_esm) self.return_attn_weights = config.return_attn_weights self.bacformer = BacformerModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dense = nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size), nn.GELU(), nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), nn.Dropout(0.2), ) self.ppi_head = BacformerProteinProteinInteractionHead( in_features=config.hidden_size, prot_emb_idx=config.prot_emb_token_id ) # Initialize weights and apply final processing self.post_init() def forward( self, protein_embeddings: torch.Tensor, special_tokens_mask: torch.Tensor, labels: torch.Tensor = None, token_type_ids: torch.Tensor = None, attention_mask: torch.Tensor = None, return_attn_weights: bool = None, return_dict: Union[bool, None] = None, ) -> Union[OrderedDict, None]: # TODO: change it from token classifier output """Forward method for the model.""" return_dict = return_dict if return_dict is not None else self.config.return_dict if self.benchmark_esm: last_hidden_state = protein_embeddings.squeeze(0)[1:-2, :] else: outputs = self.bacformer( protein_embeddings=protein_embeddings, special_tokens_mask=special_tokens_mask, token_type_ids=token_type_ids, attention_mask=attention_mask, return_attn_weights=False, return_dict=True, ) last_hidden_state = outputs.last_hidden_state.squeeze(0)[1:-2, :] assert labels.shape[0] == 1, "Batch size should be 1 for protein-protein interaction task" last_hidden_state = self.dense(self.dropout(last_hidden_state)) last_hidden_state = torch.cat([last_hidden_state[labels[:, 0]], last_hidden_state[labels[:, 1]]], dim=0).mean( dim=0 ) logits = self.ppi_head(last_hidden_state) loss = binary_cross_entropy_with_logits(logits, labels[:, 2].type_as(logits).squeeze(0)) if not return_dict: return ( loss, logits, ) return BacformerModelOutput( loss=loss, logits=logits, last_hidden_state=outputs.last_hidden_state, attentions=outputs.attentions, ) # Copied from transformers.models.bert.modeling_bert.BertPooler class BacformerPooler(nn.Module): """Pooler for Bacformer model.""" def __init__(self, config: BacformerConfig): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states: torch.Tensor, padding_mask: torch.Tensor = None) -> torch.Tensor: """Forward method for the pooler.""" # We "pool" the model by taking the mean of non-padding tokens padding_mask = padding_mask.to(hidden_states.device) if padding_mask is not None else None if 


# Copied from transformers.models.bert.modeling_bert.BertPooler
class BacformerPooler(nn.Module):
    """Pooler for the Bacformer model."""

    def __init__(self, config: BacformerConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor, padding_mask: torch.Tensor = None) -> torch.Tensor:
        """Forward method for the pooler."""
        # We "pool" the model by taking the mean of the non-padding tokens
        padding_mask = padding_mask.to(hidden_states.device) if padding_mask is not None else None
        if padding_mask is not None:
            mean_hidden_states = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum(
                1
            ).unsqueeze(1)
        else:
            mean_hidden_states = hidden_states.mean(dim=1)
        pooled_output = self.dense(mean_hidden_states)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BacformerGMHead(nn.Module):
    """Bacformer head for genomic modeling."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # add 1 to config.protein_clusters_vocab_size to account for the end token
        self.decoder = nn.Linear(config.hidden_size, config.protein_clusters_vocab_size + 1, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.protein_clusters_vocab_size + 1))

    def forward(self, features, **kwargs):
        """Forward method for the head."""
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)

        # project back to the number of labels, with bias
        x = self.decoder(x) + self.bias
        return x


class BacformerGenomeClassificationHead(nn.Module):
    """Head for genome-level classification tasks."""

    def __init__(self, config: BacformerConfig):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features: torch.Tensor, padding_mask: torch.Tensor, **kwargs):
        """Forward method for the head."""
        if padding_mask is not None:
            # mean over the non-padding positions
            x = torch.einsum("ijk,ij->ik", features, padding_mask) / padding_mask.sum(1).unsqueeze(1)
        else:
            x = features[:, 0, :]  # take the first token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class BacformerProteinProteinInteractionHead(nn.Module):
    """Head for the protein-protein interaction task at a genome level."""

    def __init__(self, in_features: int, prot_emb_idx: int = 4, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.prot_emb_idx = prot_emb_idx

        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(in_features, 1, bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward method for the head."""
        return self.linear(self.dropout(hidden_states)).squeeze(-1)
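

# --- Illustrative sketch (not part of the original model code) ---
# The einsum used by BacformerPooler and BacformerGenomeClassificationHead is
# a masked mean over the sequence dimension: einsum("ijk,ij->ik", h, m) sums
# hidden states at positions where the (float) mask is 1, and dividing by
# m.sum(1) gives the mean over non-padding tokens. Shapes are assumptions.
def _example_masked_mean_pooling() -> None:
    hidden_states = torch.randn(2, 5, 8)
    padding_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float)
    pooled = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum(1).unsqueeze(1)
    # equivalent explicit computation
    manual = (hidden_states * padding_mask.unsqueeze(-1)).sum(1) / padding_mask.sum(1).unsqueeze(1)
    assert torch.allclose(pooled, manual, atol=1e-6)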