import math
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        self.out = nn.Linear(config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.attention_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_length = hidden_states.shape[:2]

        # Project queries, keys, and values
        query_states = self.query(hidden_states)
        key_states = self.key(hidden_states)
        value_states = self.value(hidden_states)

        # Reshape for multi-head attention: (batch, num_heads, seq, head_size)
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_size).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, self.num_heads, self.head_size).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, self.num_heads, self.head_size).transpose(1, 2)

        # Calculate scaled dot-product attention scores
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.head_size)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Apply attention to values
        context_layer = torch.matmul(attention_probs, value_states)
        context_layer = context_layer.transpose(1, 2).contiguous()

        # Reshape back to (batch, seq, hidden_size)
        context_layer = context_layer.view(batch_size, seq_length, self.hidden_size)
        context_layer = self.out(context_layer)

        return context_layer, attention_probs


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dense_4h_to_h(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self-attention (pre-norm)
        attention_layernorm_out = self.input_layernorm(hidden_states)
        attention_output, attention_probs = self.attention(
            attention_layernorm_out,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        attention_output = self.dropout(attention_output)

        # Residual connection
        attention_output = attention_output + hidden_states

        # MLP (pre-norm)
        mlp_layernorm_out = self.post_attention_layernorm(attention_output)
        mlp_output = self.mlp(mlp_layernorm_out)

        # Residual connection
        layer_output = mlp_output + attention_output

        return layer_output, attention_probs


class OpenPeerLLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token and position embeddings
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Transformer layers
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])

        # Final layer norm
        self.final_layernorm = nn.LayerNorm(config.hidden_size)

        # Output head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.init_weights()

    def init_weights(self):
        """Initialize weights with small random values."""
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for different layer types."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        batch_size, seq_length = input_ids.shape

        # Create position IDs
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get token and position embeddings
        inputs_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Combine embeddings
        hidden_states = inputs_embeds + position_embeds

        # Build the causal mask (1 = visible, 0 = masked), merge it with the
        # optional padding mask, then convert to an additive mask that can be
        # added directly to the attention scores.
        causal_mask = torch.tril(torch.ones(seq_length, seq_length, device=input_ids.device))[None, None, :, :]
        if attention_mask is not None:
            combined_mask = causal_mask * attention_mask[:, None, None, :].to(causal_mask.dtype)
        else:
            combined_mask = causal_mask
        attention_mask = (1.0 - combined_mask.to(hidden_states.dtype)) * torch.finfo(hidden_states.dtype).min

        # Process through transformer layers
        all_attentions = []
        for layer in self.layers:
            hidden_states, attention_probs = layer(hidden_states, attention_mask)
            all_attentions.append(attention_probs)

        # Final layer norm
        hidden_states = self.final_layernorm(hidden_states)

        # Get logits
        logits = self.lm_head(hidden_states)

        # Calculate loss if labels are provided, shifting so that the logits
        # at position i predict the token at position i + 1
        loss = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return {
            "loss": loss,
            "logits": logits,
            "hidden_states": hidden_states,
            "attentions": all_attentions,
        }
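
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). DemoConfig is
# a hypothetical stand-in that only defines the attributes the classes above
# read; any real configuration object exposing the same fields would work.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class DemoConfig:
        vocab_size: int = 1000
        hidden_size: int = 128
        intermediate_size: int = 512
        num_attention_heads: int = 4
        num_hidden_layers: int = 2
        max_position_embeddings: int = 64
        attention_dropout: float = 0.1
        hidden_dropout: float = 0.1

    config = DemoConfig()
    model = OpenPeerLLM(config)

    # Dummy batch: 2 sequences of 16 token IDs, no padding.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    print("loss:", outputs["loss"].item())
    print("logits shape:", tuple(outputs["logits"].shape))  # (2, 16, vocab_size)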