from typing import Optional, Tuple

import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention

from transformers import (
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutput

from xformers.ops import SwiGLU


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py.

    Args:
        dim (int): Dimension of the frequency tensor.
        end (int): End index for precomputing frequencies.
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)


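# Usage sketch (illustrative, not part of the original module): with the default NeoBERT
# sizes defined below (hidden_size=768, num_attention_heads=12, max_length=1024), the
# per-head dimension is 64, so the model precomputes
#
#     freqs_cis = precompute_freqs_cis(64, 1024)      # complex64, shape (1024, 32)
#     cos, sin = freqs_cis.real, freqs_cis.imag       # each float32, shape (1024, 32)
#
# and registers cos/sin as buffers; apply_rotary_emb_real below consumes that
# (cos, sin) pair rather than the complex tensor itself.

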
def apply_rotary_emb_real(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pure-real rotary embeddings.

    xq, xk: (B, seq, n_heads, dim)
    freqs_cis: (cos, sin), each of shape (B, seq, dim/2) or broadcastable to it (e.g. (1, seq, dim/2))
    """
    cos, sin = freqs_cis

    # Broadcast over the head dimension: (B, seq, dim/2) -> (B, seq, 1, dim/2).
    cos = cos.unsqueeze(2)
    sin = sin.unsqueeze(2)

    # Split the channels into (even, odd) pairs.
    xq_even = xq[..., 0::2]
    xq_odd = xq[..., 1::2]
    xk_even = xk[..., 0::2]
    xk_odd = xk[..., 1::2]

    # Rotate each pair by its position-dependent angle.
    q_rot_even = xq_even * cos - xq_odd * sin
    q_rot_odd = xq_even * sin + xq_odd * cos
    k_rot_even = xk_even * cos - xk_odd * sin
    k_rot_odd = xk_even * sin + xk_odd * cos

    # Re-interleave the rotated pairs back into the original channel layout.
    xq_rot = torch.stack([q_rot_even, q_rot_odd], dim=-1).flatten(-2)
    xk_rot = torch.stack([k_rot_even, k_rot_odd], dim=-1).flatten(-2)

    return xq_rot.type_as(xq), xk_rot.type_as(xk)


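# Illustrative call (assumed shapes, mirroring NeoBERT.forward below): for a batch of
# B sequences of length S with 12 heads of size 64,
#
#     cos = freqs_cos[:S].unsqueeze(0)                    # (1, S, 32), broadcasts over B
#     sin = freqs_sin[:S].unsqueeze(0)                    # (1, S, 32)
#     xq, xk = apply_rotary_emb_real(xq, xk, (cos, sin))  # both stay (B, S, 12, 64)
#
# The rotations are norm-preserving, so queries and keys keep their scale while their
# relative angles encode position.

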
class NeoBERTConfig(PretrainedConfig):
    model_type = "neobert"

    def __init__(
        self,
        hidden_size: int = 768,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        norm_eps: float = 1e-06,
        vocab_size: int = 30522,
        pad_token_id: int = 0,
        max_length: int = 1024,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        if hidden_size % num_attention_heads != 0:
            raise ValueError("Hidden size must be divisible by the number of heads.")
        self.dim_head = hidden_size // num_attention_heads
        self.intermediate_size = intermediate_size
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range
        self.norm_eps = norm_eps
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.kwargs = kwargs


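# Illustrative usage (not in the original file):
#
#     config = NeoBERTConfig()          # defaults above: hidden_size=768, 12 heads
#     config.dim_head                   # 768 // 12 == 64, the per-head dimension
#
# A hidden_size that is not a multiple of num_attention_heads raises a ValueError.

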
class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(self, config: NeoBERTConfig):
        super().__init__()

        self.config = config

        # Attention projections: fused QKV projection and output projection.
        self.qkv = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size * 3, bias=False)
        self.wo = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=False)

        # SwiGLU uses two input projections, so shrink the hidden width to 2/3 of the
        # nominal intermediate size and round up to a multiple of 8
        # (e.g. 3072 -> 2048 with the default config).
        multiple_of = 8
        intermediate_size = int(2 * config.intermediate_size / 3)
        intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
        self.ffn = SwiGLU(config.hidden_size, intermediate_size, config.hidden_size, bias=False)

        # Pre-normalization (RMSNorm) for the attention and feed-forward sub-layers.
        self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
        self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
        output_attentions: bool,
    ):
        # Attention sub-layer with a residual connection.
        attn_output, attn_weights = self._att_block(
            self.attention_norm(x), attention_mask, freqs_cis, output_attentions,
        )
        x = x + attn_output

        # Feed-forward sub-layer with a residual connection.
        x = x + self.ffn(self.ffn_norm(x))

        return x, attn_weights

    def _att_block(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
        output_attentions: bool,
    ):
        batch_size, seq_len, _ = x.shape

        # Project to queries, keys and values: (B, S, hidden) -> three (B, S, n_heads, dim_head).
        xq, xk, xv = (
            self.qkv(x)
            .view(batch_size, seq_len, self.config.num_attention_heads, self.config.dim_head * 3)
            .chunk(3, dim=-1)
        )

        # Rotary position embeddings on queries and keys.
        xq, xk = apply_rotary_emb_real(xq, xk, freqs_cis)

        attn_weights = None

        if output_attentions:
            # Explicit attention so the weights can be returned.
            attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
            if attention_mask is not None:
                # Mask out padded key positions before the softmax.
                attn_weights = attn_weights.masked_fill(~attention_mask.bool(), float("-inf"))
            attn_weights = attn_weights.softmax(-1)
            attn = attn_weights @ xv.permute(0, 2, 1, 3)
            attn = attn.transpose(1, 2)
        else:
            # Fused attention kernel; attn_mask is a boolean "keep" mask.
            attn = scaled_dot_product_attention(
                query=xq.transpose(1, 2),
                key=xk.transpose(1, 2),
                value=xv.transpose(1, 2),
                attn_mask=attention_mask.bool() if attention_mask is not None else None,
                dropout_p=0,
            ).transpose(1, 2)

        return self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.config.dim_head)), attn_weights


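# Shape walk-through (illustrative, default config): for input x of shape (B, S, 768),
# self.qkv(x) gives (B, S, 2304), the view reshapes it to (B, S, 12, 192), and the chunk
# along the last dimension yields three tensors of shape (B, S, 12, 64) for queries, keys
# and values. After attention, the heads are flattened back to (B, S, 12 * 64) = (B, S, 768)
# before the output projection self.wo.

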
class NeoBERTPreTrainedModel(PreTrainedModel):
    config_class = NeoBERTConfig
    base_model_prefix = "model"
    _supports_cache_class = True

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.uniform_(-self.config.decoder_init_range, self.config.decoder_init_range)
        elif isinstance(module, nn.Embedding):
            module.weight.data.uniform_(-self.config.embedding_init_range, self.config.embedding_init_range)


class NeoBERT(NeoBERTPreTrainedModel):
    config_class = NeoBERTConfig

    def __init__(self, config: NeoBERTConfig):
        super().__init__(config)

        self.config = config

        # Token embeddings.
        self.encoder = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Real-valued RoPE tables, stored as non-persistent buffers.
        freqs_cis = precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_length)
        self.register_buffer("freqs_cos", freqs_cis.real, persistent=False)
        self.register_buffer("freqs_sin", freqs_cis.imag, persistent=False)

        # Stack of transformer encoder blocks.
        self.transformer_encoder = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.transformer_encoder.append(EncoderBlock(config))

        self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        **kwargs,
    ):
        hidden_states, attentions = [], []

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Expand the (B, S) padding mask (1 = attend, 0 = pad) to (B, n_heads, S, S).
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)

        # Slice the RoPE tables, either at explicit position_ids or at positions 0..seq_len-1.
        freqs_cos = (
            self.freqs_cos[position_ids]
            if position_ids is not None
            else self.freqs_cos[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
        )
        freqs_sin = (
            self.freqs_sin[position_ids]
            if position_ids is not None
            else self.freqs_sin[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
        )

        x = self.encoder(input_ids) if input_ids is not None else inputs_embeds

        for layer in self.transformer_encoder:
            x, attn = layer(x, attention_mask, (freqs_cos, freqs_sin), output_attentions)
            if output_hidden_states:
                hidden_states.append(x)
            if output_attentions:
                attentions.append(attn)

        x = self.layer_norm(x)

        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states if output_hidden_states else None,
            attentions=attentions if output_attentions else None,
        )


if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "chandar-lab/NeoBERT"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = NeoBERT.from_pretrained(model_name)

    text = [
        "NeoBERT is the most efficient model of its kind!",
        "This is really cool",
    ]
    inputs = tokenizer(text, padding=True, return_tensors="pt")

    with torch.no_grad():
        pytorch_outputs = model(**inputs)

    # Export to ONNX with dynamic batch and sequence dimensions.
    torch.onnx.export(
        model,
        (inputs["input_ids"], inputs["attention_mask"]),
        f="model.onnx",
        export_params=True,
        opset_version=20,
        do_constant_folding=True,
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
        },
        dynamo=True,
    )

    # Check that ONNX Runtime reproduces the PyTorch outputs.
    import onnxruntime as ort

    ort_session = ort.InferenceSession("model.onnx")
    ort_inputs = {
        "input_ids": inputs["input_ids"].numpy(),
        "attention_mask": inputs["attention_mask"].numpy(),
    }
    ort_outputs = ort_session.run(None, ort_inputs)

    # Compare on the maximum absolute difference, not the signed maximum.
    assert abs(pytorch_outputs.last_hidden_state.numpy() - ort_outputs[0]).max() < 1e-3