from typing import Optional, Tuple

import torch
from torch import nn

from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast

from configuration_spect1 import SpecT1Config
|
|
class SpecT1MTPLayers(nn.Module):
    """Single multi-token-prediction (MTP) block: fuses the current token
    embeddings with the previous hidden states, then applies self-attention
    and an MLP, each as a pre-norm residual sub-block."""

    def __init__(self, config: SpecT1Config):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
        self.token_layernorm = nn.LayerNorm(config.hidden_size)
        self.hidden_layernorm = nn.LayerNorm(config.hidden_size)
        self.final_layernorm = nn.LayerNorm(config.hidden_size)
        # Projects the concatenated [previous hidden state, token embedding]
        # pair back down to hidden_size.
        self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(
        self,
        input_embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cache_position=None,
        **kwargs,
    ) -> torch.Tensor:
        # position_ids, past_key_values, output_attentions, use_cache,
        # position_embedding and cache_position are accepted for interface
        # compatibility but are not used by this layer.
        input_embeds = self.token_layernorm(input_embeds)
        previous_hidden_states = self.hidden_layernorm(hidden_states)
        # Fuse the token embeddings with the hidden states from the previous step.
        hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))

        # Self-attention sub-block with pre-norm and residual connection.
        # Note: attention_mask is forwarded unchanged, so it must already be in
        # the attn_mask format nn.MultiheadAttention expects (shape (L, S) or
        # (batch * num_heads, L, S)), not a Hugging Face-style padding mask.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attn_mask=attention_mask)
        hidden_states = residual + attn_output

        # Feed-forward sub-block with pre-norm and residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        mlp_output = self.mlp(hidden_states)
        hidden_states = residual + mlp_output

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
|
|
class SpecT1Model(nn.Module):
    """Stack of MTP layers. Every layer receives both the original token
    embeddings and the hidden states produced by the previous layer."""

    config_class = SpecT1Config

    def __init__(self, config: SpecT1Config):
        super().__init__()
        self.config = config
        self.mtp_layers = nn.ModuleList(
            [SpecT1MTPLayers(config) for _ in range(config.num_nextn_predict_layers)]
        )

    def forward(
        self,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = input_embeds
        for layer in self.mtp_layers:
            hidden_states = layer(
                input_embeds=input_embeds,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **kwargs,
            )
        return hidden_states
|
|
class SpecT1ForCausalLM(PreTrainedModel):
    """Causal LM head on top of the MTP stack. The model operates on
    precomputed embeddings, so inputs_embeds is required."""

    config_class = SpecT1Config

    def __init__(self, config: SpecT1Config):
        super().__init__(config)
        self.config = config
        self.model = SpecT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided for SpecT1ForCausalLM")
        hidden_states = self.model(
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Labels are used as-is; no causal shift is applied here.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # Follow the Transformers tuple convention: loss (if present) comes first.
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
            past_key_values=None,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Generation must be driven by embeddings; input_ids alone are not enough.
        if inputs_embeds is None:
            raise ValueError("SpecT1ForCausalLM requires inputs_embeds for generation")
        return {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }
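

# A minimal smoke test of the forward pass, intended as a sketch rather than a
# reference usage. It assumes SpecT1Config accepts these attributes as keyword
# arguments; the attribute names (hidden_size, intermediate_size,
# num_attention_heads, attention_dropout, num_nextn_predict_layers, vocab_size)
# are taken from how the config is used above, not from any published API.
if __name__ == "__main__":
    config = SpecT1Config(
        hidden_size=64,
        intermediate_size=128,
        num_attention_heads=4,
        attention_dropout=0.0,
        num_nextn_predict_layers=2,
        vocab_size=1000,
    )
    model = SpecT1ForCausalLM(config)
    # The model consumes precomputed embeddings rather than token ids.
    embeds = torch.randn(2, 8, config.hidden_size)
    labels = torch.randint(0, config.vocab_size, (2, 8))
    out = model(inputs_embeds=embeds, labels=labels)
    print(out.logits.shape)  # expected: torch.Size([2, 8, 1000])
    print(out.loss)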