"""ONNX export script for NeoBERT (NeoBERT-ONNX / export.py)."""
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
from transformers import (
PreTrainedModel,
PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutput
from xformers.ops import SwiGLU
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py.
Args:
dim (int): Dimension of the frequency tensor.
end (int): End index for precomputing frequencies.
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
return torch.polar(torch.ones_like(freqs), freqs)
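# Worked example (illustrative only, not executed by the export): with the default
# config the head dimension is 768 / 12 = 64 and max_length is 1024, so
# precompute_freqs_cis(64, 1024) returns a complex64 tensor of shape (1024, 32) whose
# entry [t, i] is exp(1j * t * theta ** (-2 * i / 64)), i.e. the standard RoPE rotation
# angle t / theta ** (2 * i / dim) expressed as a unit complex number.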
def apply_rotary_emb_real(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Pure-real rotary embeddings.
xq, xk: (B, seq, n_heads, dim)
freqs_cis: (cos, sin), each of shape (B, seq, dim/2)
"""
cos, sin = freqs_cis
# make (B, seq, 1, dim/2) so they broadcast to (B, seq, n_heads, dim/2)
cos = cos.unsqueeze(2)
sin = sin.unsqueeze(2)
# split even/odd dims
xq_even = xq[..., 0::2]
xq_odd = xq[..., 1::2]
xk_even = xk[..., 0::2]
xk_odd = xk[..., 1::2]
# apply the rotation formula:
q_rot_even = xq_even * cos - xq_odd * sin
q_rot_odd = xq_even * sin + xq_odd * cos
k_rot_even = xk_even * cos - xk_odd * sin
k_rot_odd = xk_even * sin + xk_odd * cos
# interleave even/odd back into last dim
xq_rot = torch.stack([q_rot_even, q_rot_odd], dim=-1).flatten(-2)
xk_rot = torch.stack([k_rot_even, k_rot_odd], dim=-1).flatten(-2)
return xq_rot.type_as(xq), xk_rot.type_as(xk)
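# Equivalence sketch (for illustration only; the names below are hypothetical and not
# used by the export): the real-valued rotation above matches the complex formulation
# in the Llama reference code. Viewing the last dimension of xq as dim/2 (even, odd)
# pairs and multiplying by a complex freqs_cis of shape (B, seq, dim/2) gives the same
# result:
#
#     xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
#     xq_rot = torch.view_as_real(xq_c * freqs_cis_complex.unsqueeze(2)).flatten(-2)
#
# The pure-real cos/sin form is used here presumably because complex tensors do not
# export cleanly to ONNX.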
class NeoBERTConfig(PretrainedConfig):
model_type = "neobert"
# All config parameters must have a default value.
def __init__(
self,
hidden_size: int = 768,
num_hidden_layers: int = 28,
num_attention_heads: int = 12,
intermediate_size: int = 3072,
embedding_init_range: float = 0.02,
decoder_init_range: float = 0.02,
norm_eps: float = 1e-06,
vocab_size: int = 30522,
pad_token_id: int = 0,
max_length: int = 1024,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if hidden_size % num_attention_heads != 0:
raise ValueError("Hidden size must be divisible by the number of heads.")
self.dim_head = hidden_size // num_attention_heads
self.intermediate_size = intermediate_size
self.embedding_init_range = embedding_init_range
self.decoder_init_range = decoder_init_range
self.norm_eps = norm_eps
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.kwargs = kwargs
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(self, config: NeoBERTConfig):
super().__init__()
self.config = config
# Attention
self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)
# Feedforward network
multiple_of = 8
intermediate_size = int(2 * config.intermediate_size / 3)
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
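        # Worked example with the default config: 2 * 3072 / 3 = 2048, which is already a
        # multiple of 8, so the SwiGLU hidden dimension is 2048. The 2/3 scaling keeps the
        # FFN parameter count comparable to a standard GeLU MLP, since SwiGLU adds a
        # second input projection (the gate).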
self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)
# Layer norms
self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
output_attentions: bool,
):
# Attention
attn_output, attn_weights = self._att_block(
self.attention_norm(x), attention_mask, freqs_cis, output_attentions,
)
# Residual
x = x + attn_output
# Feed-forward
x = x + self.ffn(self.ffn_norm(x))
return x, attn_weights
def _att_block(
self,
x: torch.Tensor,
attention_mask: torch.Tensor,
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
output_attentions: bool,
):
batch_size, seq_len, _ = x.shape
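        # Project to Q, K, V in a single matmul; after the view/chunk below each of
        # xq, xk, xv has shape (batch, seq, num_attention_heads, dim_head).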
        xq, xk, xv = (
            self.qkv(x)
            .view(batch_size, seq_len, self.config.num_attention_heads, self.config.dim_head * 3)
            .chunk(3, dim=-1)
        )
xq, xk = apply_rotary_emb_real(xq, xk, freqs_cis)
# Attn block
attn_weights = None
# Eager attention if attention weights are needed in the output
if output_attentions:
attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
            if attention_mask is not None:
                # The mask is binary (1 = attend, 0 = padding): set padded positions to
                # -inf so they receive zero attention weight after the softmax.
                attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))
            attn_weights = attn_weights.softmax(-1)
attn = attn_weights @ xv.permute(0, 2, 1, 3)
attn = attn.transpose(1, 2)
# Fall back to SDPA otherwise
else:
attn = scaled_dot_product_attention(
query=xq.transpose(1, 2),
key=xk.transpose(1, 2),
value=xv.transpose(1, 2),
                attn_mask=attention_mask.bool() if attention_mask is not None else None,
dropout_p=0,
).transpose(1, 2)
return self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.config.dim_head)), attn_weights
class NeoBERTPreTrainedModel(PreTrainedModel):
config_class = NeoBERTConfig
base_model_prefix = "model"
_supports_cache_class = True
def _init_weights(self, module):
if isinstance(module, nn.Linear):
module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
elif isinstance(module, nn.Embedding):
module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)
class NeoBERT(NeoBERTPreTrainedModel):
config_class = NeoBERTConfig
def __init__(self, config: NeoBERTConfig):
super().__init__(config)
self.config = config
self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # Ensures freqs_cis is moved to the same device as the model. Non-persistent buffers are not saved in the state_dict.
freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
self.register_buffer("freqs_cos", freqs_cis.real, persistent=False)
self.register_buffer("freqs_sin", freqs_cis.imag, persistent=False)
self.transformer_encoder = nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.transformer_encoder.append(EncoderBlock(config))
self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
output_hidden_states: bool = False,
output_attentions: bool = False,
**kwargs,
):
# Initialize
hidden_states, attentions = [], []
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
# Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
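        # The mask stays binary here (1 = attend, 0 = padding); each attention layer
        # converts it to a boolean mask (SDPA path) or to -inf logits (eager path).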
# RoPE
freqs_cos = (
self.freqs_cos[position_ids]
if position_ids is not None
else self.freqs_cos[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
)
freqs_sin = (
self.freqs_sin[position_ids]
if position_ids is not None
else self.freqs_sin[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
)
# Embedding
x = self.encoder(input_ids) if input_ids is not None else inputs_embeds
# Transformer encoder
for layer in self.transformer_encoder:
x, attn = layer(x, attention_mask, (freqs_cos, freqs_sin), output_attentions)
if output_hidden_states:
hidden_states.append(x)
if output_attentions:
attentions.append(attn)
# Final normalization layer
x = self.layer_norm(x)
# Return the output of the last hidden layer
return BaseModelOutput(
last_hidden_state=x,
hidden_states=hidden_states if output_hidden_states else None,
attentions=attentions if output_attentions else None,
)
if __name__ == "__main__":
from transformers import AutoTokenizer
model_name = "chandar-lab/NeoBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = NeoBERT.from_pretrained(model_name)
# Tokenize input text
text = [
"NeoBERT is the most efficient model of its kind!",
"This is really cool",
]
inputs = tokenizer(text, padding=True, return_tensors="pt")
# Generate embeddings
with torch.no_grad():
pytorch_outputs = model(**inputs)
# Export to ONNX
torch.onnx.export(
model,
(inputs['input_ids'], inputs['attention_mask']),
f="model.onnx",
export_params=True,
opset_version=20,
do_constant_folding=True,
        input_names=['input_ids', 'attention_mask'],
        output_names=['last_hidden_state'],
        dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence_length'},
'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'},
},
dynamo=True,
)
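    # Note: dynamo=True selects the TorchDynamo-based exporter; together with opset 20
    # this needs a fairly recent PyTorch release (an assumption, roughly 2.5 or newer).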
# Validate
import onnxruntime as ort
ort_session = ort.InferenceSession("model.onnx")
ort_inputs = {
"input_ids": inputs['input_ids'].numpy(),
"attention_mask": inputs['attention_mask'].numpy(),
}
ort_outputs = ort_session.run(None, ort_inputs)
    assert abs(pytorch_outputs.last_hidden_state.numpy() - ort_outputs[0]).max() < 1e-3
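
    # Optional extra check (a sketch; the sample sentence is arbitrary): run the exported
    # model on unseen text of a different length to exercise the dynamic sequence_length
    # axis, and confirm the output shape is (batch, sequence_length, hidden_size).
    extra = tokenizer(["A second, previously unseen sentence."], padding=True, return_tensors="pt")
    (extra_hidden,) = ort_session.run(
        None,
        {"input_ids": extra["input_ids"].numpy(), "attention_mask": extra["attention_mask"].numpy()},
    )
    print("ONNX last_hidden_state shape:", extra_hidden.shape)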