import torch
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel

class BitNetConfig(PretrainedConfig):
    model_type = "bitnet"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        dropout=0.1,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.dropout = dropout
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )


class BitNetForCausalLM(PreTrainedModel):
    config_class = BitNetConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned positional embeddings; without them the encoder layers have no
        # notion of token order (and max_position_embeddings would go unused).
        self.embed_positions = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                dropout=config.dropout,
                activation=config.hidden_act,
                batch_first=True,  # inputs are (batch, seq_len, hidden_size)
            )
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Standard HF hook: runs _init_weights over all submodules.
        self.post_init()

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        hidden_states = self.embed_tokens(input_ids) + self.embed_positions(positions)

        # Causal (upper-triangular) mask so each position only attends to itself
        # and earlier tokens, as required for a causal LM.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=input_ids.device),
            diagonal=1,
        )
        # PyTorch expects True at padded positions, the opposite of the HF 1/0 convention.
        padding_mask = (attention_mask == 0) if attention_mask is not None else None

        for layer in self.layers:
            hidden_states = layer(
                hidden_states, src_mask=causal_mask, src_key_padding_mask=padding_mask
            )
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that position t predicts token t + 1; labels set to -100
            # are ignored by CrossEntropyLoss by default.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
            )

        return {"logits": logits} if loss is None else {"loss": loss, "logits": logits}

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids, **kwargs}