from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutput
from xformers.ops import SwiGLU


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py.

    Args:
        dim (int): Dimension of the frequency tensor (per-head dimension).
        end (int): End index for precomputing frequencies (number of positions).
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Complex tensor of shape (end, dim // 2) holding exp(i * position * frequency).
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb_real(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pure-real rotary embeddings (no complex ops, so the graph stays ONNX-exportable).

    Adjacent channel pairs (0, 1), (2, 3), ... are rotated by the per-position angle, i.e.
    (even, odd) -> (even * cos - odd * sin, even * sin + odd * cos).

    Args:
        xq, xk: (B, seq, n_heads, dim) query / key tensors.
        freqs_cis: (cos, sin), each of shape (B or 1, seq, dim/2).

    Returns:
        The rotated (xq, xk), cast back to the dtypes of the respective inputs.
    """
    cos, sin = freqs_cis
    # (B, seq, dim/2) -> (B, seq, 1, dim/2) so they broadcast over the head axis.
    cos = cos.unsqueeze(2)
    sin = sin.unsqueeze(2)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Rotate each (even, odd) channel pair, then re-interleave into the last dim.
        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        rotated = torch.stack([x_even * cos - x_odd * sin, x_even * sin + x_odd * cos], dim=-1).flatten(-2)
        return rotated.type_as(x)

    return _rotate(xq), _rotate(xk)


class NeoBERTConfig(PretrainedConfig):
    """Configuration for NeoBERT. `dim_head` is derived as hidden_size // num_attention_heads."""

    model_type = "neobert"

    # All config parameters must have a default value.
    def __init__(
        self,
        hidden_size: int = 768,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        norm_eps: float = 1e-06,
        vocab_size: int = 30522,
        pad_token_id: int = 0,
        max_length: int = 1024,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        if hidden_size % num_attention_heads != 0:
            raise ValueError("Hidden size must be divisible by the number of heads.")
        self.dim_head = hidden_size // num_attention_heads
        self.intermediate_size = intermediate_size
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range
        self.norm_eps = norm_eps
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        # Number of precomputed RoPE positions (rows of the freqs_cos / freqs_sin buffers).
        self.max_length = max_length
        self.kwargs = kwargs


class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block.

    x -> x + Attention(RMSNorm(x))   (rotary embeddings on q/k)
    x -> x + SwiGLU(RMSNorm(x))
    """

    def __init__(self, config: NeoBERTConfig):
        super().__init__()

        self.config = config

        # Attention: fused QKV projection and output projection, both bias-free.
        self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
        self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)

        # Feedforward network: hidden width is 2/3 of intermediate_size, rounded up to a multiple of 8.
        multiple_of = 8
        intermediate_size = int(2 * config.intermediate_size / 3)
        intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
        self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)

        # Layer norms
        self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
        self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
        output_attentions: bool,
    ):
        """Apply the block.

        Args:
            x: (B, L, hidden_size) hidden states.
            attention_mask: None, or a mask broadcastable to (B, n_heads, L, L) where nonzero/True
                means "attend to this key".
            freqs_cis: (cos, sin) RoPE tables, each (B or 1, L, dim_head / 2).
            output_attentions: if True, use eager attention and return the probabilities.

        Returns:
            (x, attn_weights): updated hidden states, and (B, n_heads, L, L) attention
            probabilities when output_attentions is True, otherwise None.
        """
        attn_output, attn_weights = self._att_block(
            self.attention_norm(x),
            attention_mask,
            freqs_cis,
            output_attentions,
        )

        # Residual around attention, then residual around the feed-forward network.
        x = x + attn_output
        x = x + self.ffn(self.ffn_norm(x))

        return x, attn_weights

    def _att_block(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
        output_attentions: bool,
    ):
        """Multi-head self-attention with RoPE; returns (projected output, attention probs or None)."""
        batch_size, seq_len, _ = x.shape
        n_heads, dim_head = self.config.num_attention_heads, self.config.dim_head

        # The fused projection is viewed per head as [q | k | v] blocks of dim_head channels each.
        xq, xk, xv = self.qkv(x).view(batch_size, seq_len, n_heads, dim_head * 3).chunk(3, axis=-1)
        xq, xk = apply_rotary_emb_real(xq, xk, freqs_cis)

        # Boolean mask, True = attend (the convention of SDPA's boolean attn_mask).
        if attention_mask is not None and attention_mask.dtype != torch.bool:
            attention_mask = attention_mask.bool()

        # (B, L, H, D) -> (B, H, L, D)
        q, k, v = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)

        attn_weights = None
        if output_attentions:
            # Eager attention, needed to expose the attention probabilities.
            attn_weights = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
            if attention_mask is not None:
                # Exclude masked keys from the softmax. (Multiplying logits by a 0/1 mask would
                # only zero them, leaving padded keys with non-zero probability.)
                attn_weights = attn_weights.masked_fill(~attention_mask, torch.finfo(attn_weights.dtype).min)
            attn_weights = attn_weights.softmax(-1)
            attn = attn_weights @ v
        else:
            # Fused SDPA otherwise; attn_mask=None means full attention.
            attn = scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                attn_mask=attention_mask,
                dropout_p=0,
            )

        # (B, H, L, D) -> (B, L, H * D)
        attn = attn.transpose(1, 2).reshape(batch_size, seq_len, n_heads * dim_head)
        return self.wo(attn), attn_weights


class NeoBERTPreTrainedModel(PreTrainedModel):
    """Base class wiring NeoBERTConfig into the transformers PreTrainedModel machinery."""

    config_class = NeoBERTConfig
    base_model_prefix = "model"
    _supports_cache_class = True

    def _init_weights(self, module):
        # Uniform init: Linear layers use decoder_init_range, embeddings use embedding_init_range.
        if isinstance(module, nn.Linear):
            module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
        elif isinstance(module, nn.Embedding):
            module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)


class NeoBERT(NeoBERTPreTrainedModel):
    """NeoBERT encoder: token embedding -> num_hidden_layers EncoderBlocks (RoPE) -> final RMSNorm."""

    config_class = NeoBERTConfig

    def __init__(self, config: NeoBERTConfig):
        super().__init__(config)

        self.config = config

        self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Ensures freqs_cis is moved to the same devices as the model. Non-persistent buffers are not saved in
        # the state_dict. cos/sin are stored separately (real tensors) to keep the graph ONNX-exportable.
        freqs_cis = precompute_freqs_cis(config.dim_head, config.max_length)
        self.register_buffer("freqs_cos", freqs_cis.real, persistent=False)
        self.register_buffer("freqs_sin", freqs_cis.imag, persistent=False)

        self.transformer_encoder = nn.ModuleList(EncoderBlock(config) for _ in range(config.num_hidden_layers))

        self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        **kwargs,
    ):
        """Encode a batch.

        Args:
            input_ids: (B, L) token ids. Exactly one of input_ids / inputs_embeds must be given.
            attention_mask: optional (B, L) mask; nonzero/True marks tokens that may be attended to.
            position_ids: optional (B, L) indices into the RoPE tables; defaults to 0..L-1.
            inputs_embeds: (B, L, hidden_size) precomputed embeddings.
            output_hidden_states: also return the output of every encoder layer (before the final norm).
            output_attentions: also return per-layer attention probabilities (uses eager attention).

        Returns:
            BaseModelOutput with last_hidden_state of shape (B, L, hidden_size).

        Raises:
            ValueError: if not exactly one of input_ids / inputs_embeds is given, or the mask is not 2-D.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        hidden_states, attentions = [], []

        # Embedding
        x = self.encoder(input_ids) if input_ids is not None else inputs_embeds
        seq_len = x.shape[1]

        # Padding mask: (B, L) -> boolean (B, 1, 1, L). It broadcasts over heads and query positions,
        # so no (B, H, L, L) copy is materialised.
        if attention_mask is not None:
            if attention_mask.dim() != 2:
                raise ValueError(
                    f"attention_mask must have shape (batch, seq_len); got {tuple(attention_mask.shape)}"
                )
            attention_mask = attention_mask[:, None, None, :].bool()

        # RoPE tables: (B, L, dim_head / 2) from explicit positions, else (1, L, dim_head / 2).
        if position_ids is not None:
            freqs_cos = self.freqs_cos[position_ids]
            freqs_sin = self.freqs_sin[position_ids]
        else:
            freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0)
            freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0)

        # Transformer encoder
        for layer in self.transformer_encoder:
            x, attn = layer(x, attention_mask, (freqs_cos, freqs_sin), output_attentions)
            if output_hidden_states:
                hidden_states.append(x)
            if output_attentions:
                attentions.append(attn)

        # Final normalization layer
        x = self.layer_norm(x)

        # Return the output of the last hidden layer
        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states if output_hidden_states else None,
            attentions=attentions if output_attentions else None,
        )


if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "chandar-lab/NeoBERT"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = NeoBERT.from_pretrained(model_name)

    # Tokenize input text
    text = [
        "NeoBERT is the most efficient model of its kind!",
        "This is really cool",
    ]
    inputs = tokenizer(text, padding=True, return_tensors="pt")

    # Generate embeddings
    with torch.no_grad():
        pytorch_outputs = model(**inputs)

    # Export to ONNX
    torch.onnx.export(
        model,
        (inputs['input_ids'], inputs['attention_mask']),
        f="model.onnx",
        export_params=True,
        opset_version=20,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask'],
        output_names=['last_hidden_state'],
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'},
        },
        dynamo=True,
    )

    # Validate
    import onnxruntime as ort

    ort_session = ort.InferenceSession("model.onnx")
    ort_inputs = {
        "input_ids": inputs['input_ids'].numpy(),
        "attention_mask": inputs['attention_mask'].numpy(),
    }
    ort_outputs = ort_session.run(None, ort_inputs)

    # Compare absolute differences; a signed max would let large negative deviations pass.
    max_abs_diff = abs(pytorch_outputs.last_hidden_state.numpy() - ort_outputs[0]).max()
    assert max_abs_diff < 1e-3, f"ONNX output deviates from PyTorch by {max_abs_diff}"