# Spec-T1-RL-7B / modeling_spect1.py
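"""Multi-token prediction (MTP) layers for Spec-T1-RL-7B.

Defines `SpecT1MTPLayers` (a block that fuses incoming token embeddings with
hidden states from the previous step), `SpecT1Model` (a stack of such blocks),
and `SpecT1ForCausalLM` (a `PreTrainedModel` wrapper that projects the fused
hidden states to vocabulary logits). The wrapper has no token embedding table,
so `inputs_embeds` must always be supplied by the caller, for example from the
base Spec-T1 model.

Minimal usage sketch (assumes a valid `SpecT1Config` can be constructed from
its defaults; the embeddings below are random placeholders):

    config = SpecT1Config()
    model = SpecT1ForCausalLM(config)
    embeds = torch.randn(1, 8, config.hidden_size)
    outputs = model(inputs_embeds=embeds)  # outputs.logits: (1, 8, vocab_size)
"""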
from typing import Optional, Tuple
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from configuration_spect1 import SpecT1Config
class SpecT1MTPLayers(nn.Module):
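    """A single MTP block: normalizes and fuses the incoming token embeddings
    with the previous hidden states via a linear projection, then applies
    self-attention and an MLP, each with a residual connection."""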
def __init__(self, config: SpecT1Config):
super().__init__()
self.input_layernorm = nn.LayerNorm(config.hidden_size)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
self.token_layernorm = nn.LayerNorm(config.hidden_size)
self.hidden_layernorm = nn.LayerNorm(config.hidden_size)
self.final_layernorm = nn.LayerNorm(config.hidden_size)
self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
self.self_attn = nn.MultiheadAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
batch_first=True
)
self.mlp = nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size),
nn.ReLU(),
nn.Linear(config.intermediate_size, config.hidden_size)
)
def forward(
self,
input_embeds: torch.Tensor,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache_position=None,
**kwargs
) -> torch.Tensor:
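        # Normalize the token embeddings and the previous hidden states
        # separately, then fuse them with a projection back to hidden_size.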
input_embeds = self.token_layernorm(input_embeds)
previous_hidden_states = self.hidden_layernorm(hidden_states)
hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
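        # Note: the mask is forwarded unchanged to nn.MultiheadAttention, which
        # expects an `attn_mask` of shape (L, S) or (batch * num_heads, L, S);
        # callers must convert HF-style (batch, seq_len) padding masks themselves.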
attn_output, _ = self.self_attn(hidden_states, hidden_states, hidden_states, attn_mask=attention_mask)
hidden_states = residual + attn_output
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(hidden_states)
hidden_states = residual + mlp_output
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
class SpecT1Model(nn.Module):
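    """Stack of `num_nextn_predict_layers` MTP blocks. Each block re-reads the
    original input embeddings and refines the running hidden states."""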
config_class = SpecT1Config
def __init__(self, config: SpecT1Config):
super().__init__()
self.config = config
self.mtp_layers = nn.ModuleList([
SpecT1MTPLayers(config) for _ in range(config.num_nextn_predict_layers)
])
def forward(
self,
input_embeds: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
hidden_states = input_embeds
for layer in self.mtp_layers:
hidden_states = layer(
input_embeds=input_embeds,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
**kwargs
)
return hidden_states
class SpecT1ForCausalLM(PreTrainedModel):
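    """Causal-LM wrapper around `SpecT1Model`. There is no input embedding
    table: `inputs_embeds` must be supplied by the caller, and the fused hidden
    states are projected to vocabulary logits by `lm_head`."""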
config_class = SpecT1Config
def __init__(self, config: SpecT1Config):
super().__init__(config)
self.config = config
self.model = SpecT1Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def forward(
self,
        input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = True,
**kwargs
) -> torch.Tensor:
if inputs_embeds is None:
raise ValueError("inputs_embeds must be provided for SpecT1ForCausalLM")
hidden_states = self.model(
input_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
**kwargs
)
logits = self.lm_head(hidden_states)
loss = None
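        # No causal shift is applied here: labels are assumed to be already
        # aligned with the MTP prediction offset by the caller.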
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        if not return_dict:
            # transformers convention: return the loss first when it exists.
            output = (logits,)
            return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
hidden_states=None,
attentions=None,
past_key_values=None
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
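        # This head consumes pre-computed embeddings rather than token ids, so
        # generation requires the caller to provide `inputs_embeds` at every
        # step (presumably produced by the base model).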
if inputs_embeds is None:
raise ValueError("SpecT1ForCausalLM requires inputs_embeds for generation")
return {
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache", True)
}