from typing import Optional, Tuple

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast

from configuration_spect1 import SpecT1Config


class SpecT1MTPLayers(nn.Module):
    """A single multi-token-prediction (MTP) layer: fuses the token embeddings
    with the previous hidden states, then applies self-attention and an MLP."""

    def __init__(self, config: SpecT1Config):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
        self.token_layernorm = nn.LayerNorm(config.hidden_size)
        self.hidden_layernorm = nn.LayerNorm(config.hidden_size)
        self.final_layernorm = nn.LayerNorm(config.hidden_size)
        # Projects the concatenated [previous hidden states, token embeddings] back to hidden_size.
        self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            batch_first=True,
        )
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(
        self,
        input_embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cache_position=None,
        **kwargs,
    ) -> torch.Tensor:
        # Normalize the token embeddings and the incoming hidden states separately,
        # then fuse them with a linear projection.
        input_embeds = self.token_layernorm(input_embeds)
        previous_hidden_states = self.hidden_layernorm(hidden_states)
        hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))

        # Self-attention block with a residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states, hidden_states, hidden_states, attn_mask=attention_mask
        )
        hidden_states = residual + attn_output

        # Feed-forward block with a residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        mlp_output = self.mlp(hidden_states)
        hidden_states = residual + mlp_output

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class SpecT1Model(nn.Module):
    config_class = SpecT1Config

    def __init__(self, config: SpecT1Config):
        super().__init__()
        self.config = config
        self.mtp_layers = nn.ModuleList(
            [SpecT1MTPLayers(config) for _ in range(config.num_nextn_predict_layers)]
        )

    def forward(
        self,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Each MTP layer re-attends to the original token embeddings together with
        # the hidden states produced by the previous layer.
        hidden_states = input_embeds
        for layer in self.mtp_layers:
            hidden_states = layer(
                input_embeds=input_embeds,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                **kwargs,
            )
        return hidden_states


class SpecT1ForCausalLM(PreTrainedModel):
    config_class = SpecT1Config

    def __init__(self, config: SpecT1Config):
        super().__init__(config)
        self.config = config
        self.model = SpecT1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided for SpecT1ForCausalLM")

        hidden_states = self.model(
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Compute loss for training (optional)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            return (logits, loss) if loss is not None else (logits,)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
            past_key_values=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        if inputs_embeds is None:
            # This head has no embedding table; embeddings must be supplied by the caller.
            raise ValueError("SpecT1ForCausalLM requires inputs_embeds for generation")
        return {
            "inputs_embeds": inputs_embeds,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }
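

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # SpecT1Config accepts the keyword arguments used below (hidden_size,
    # intermediate_size, num_attention_heads, attention_dropout,
    # num_nextn_predict_layers, vocab_size); the sizes are illustrative only.
    config = SpecT1Config(
        hidden_size=64,
        intermediate_size=256,
        num_attention_heads=4,
        attention_dropout=0.0,
        num_nextn_predict_layers=2,
        vocab_size=1000,
    )
    model = SpecT1ForCausalLM(config)

    # This head consumes embeddings produced elsewhere (e.g. by the target model
    # in a speculative-decoding setup), so random embeddings stand in here.
    batch, seq_len = 2, 8
    inputs_embeds = torch.randn(batch, seq_len, config.hidden_size)
    outputs = model(inputs_embeds=inputs_embeds)
    print(outputs.logits.shape)  # expected: (2, 8, config.vocab_size)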