# SmolLM2-135m / model.py
import torch
import torch.nn as nn
import math
from typing import Optional
import torch.nn.functional as F
# This Llama-style model is based on the paper: https://arxiv.org/pdf/2302.13971.pdf
# Model Architecture: static/llamaModel.jpg
# It is a decoder-only transformer with rotary position embeddings (RoPE) and a
# SwiGLU feed-forward block. It uses RMSNorm for normalization.
# Other good reads: https://pub.towardsai.net/llama-explained-a70e71e706e9
def precompute_rotary_emb(dim: int, max_seq_len: int, base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
"""
Precompute the rotary position embeddings
Args:
dim: Dimension of the embeddings
max_seq_len: Maximum sequence length
base: Base for the angle calculations
Returns:
Tuple of (sin, cos) tensors of shape (max_seq_len, dim//2)
"""
# Create position indices tensor
position = torch.arange(max_seq_len).unsqueeze(1) # (seq_len, 1)
# Create dimension indices tensor
div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(base) / dim)) # (dim//2)
# Compute angles
angles = position * div_term # (seq_len, dim//2)
# Return sin and cos
return torch.sin(angles), torch.cos(angles)
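
# Quick sanity check (illustrative sketch, not executed by the model code): with
# head_dim=64 and an 8-token context, sin and cos each have shape (8, 32); at
# position 0 every angle is 0, so sin[0] is all zeros and cos[0] is all ones.
# sin, cos = precompute_rotary_emb(64, 8)
# assert sin.shape == cos.shape == (8, 32)
# assert torch.allclose(cos[0], torch.ones(32)) and torch.allclose(sin[0], torch.zeros(32))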
def apply_rotary_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
"""
Apply rotary position embeddings to the input tensor
Args:
x: Input tensor of shape (batch_size, seq_len, num_heads, head_dim)
sin: Sine tensor of shape (seq_len, head_dim//2)
cos: Cosine tensor of shape (seq_len, head_dim//2)
Returns:
Tensor with rotary position embeddings applied
"""
# Reshape x to split last dimension in half
x_reshape = x.float().reshape(*x.shape[:-1], -1, 2)
# Extract even and odd dimensions
x1, x2 = x_reshape[..., 0], x_reshape[..., 1]
# Reshape sin and cos for broadcasting
sin = sin.view(1, sin.shape[0], 1, sin.shape[1]) # (1, seq_len, 1, dim//2)
cos = cos.view(1, cos.shape[0], 1, cos.shape[1]) # (1, seq_len, 1, dim//2)
# Apply rotation using the rotation matrix multiplication
result = torch.stack([
x1 * cos - x2 * sin,
x2 * cos + x1 * sin
], dim=-1)
    return result.flatten(-2).type_as(x)  # Flatten pairs back to head_dim; restore input dtype
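
# Property check (illustrative): RoPE rotates each (even, odd) coordinate pair, so it
# preserves the norm of every head vector while encoding position in the phase.
# x = torch.randn(2, 8, 4, 64)            # (batch, seq_len, num_heads, head_dim)
# sin, cos = precompute_rotary_emb(64, 8)
# x_rot = apply_rotary_emb(x, sin, cos)
# assert torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5)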
class LlamaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, num_kv_heads: Optional[int] = None, max_position_embeddings=2048):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_dim = dim // num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
# self.q_proj = nn.Linear(dim, dim, bias=False)
# self.k_proj = nn.Linear(dim, dim, bias=False)
# self.v_proj = nn.Linear(dim, dim, bias=False)
# Adjust projections for GQA
self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False) # (B, T, D) -> (B, T, D) or (B, T, H * D/H)
self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False) # (B, T, D) -> (B, T, H_kv * D/H)
self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False) # (B, T, D) -> (B, T, H_kv * D/H)
self.o_proj = nn.Linear(dim, dim, bias=False)
# self.o_proj.NANGPT_SCALE_INIT = 1 TODO do we need weight initialization scaling?
# Cache attributes
self.k_cache = None
self.v_cache = None
self.cache_seq_len = 0
        # Precompute sin and cos for all positions
        # (stored as plain tensors; they are moved to the input's device on each forward)
        self.sin, self.cos = precompute_rotary_emb(self.head_dim, max_position_embeddings)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
batch_size, seq_len, _ = x.shape
# Project inputs
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# Get rotary embeddings for the new tokens
# sin = self.sin[self.cache_seq_len:self.cache_seq_len + seq_len].to(x.device)
# cos = self.cos[self.cache_seq_len:self.cache_seq_len + seq_len].to(x.device)
sin = self.sin[:seq_len].to(x.device)
cos = self.cos[:seq_len].to(x.device)
# Apply rotary embeddings
q = apply_rotary_emb(q, sin, cos)
k = apply_rotary_emb(k, sin, cos)
# Handle KV caching
# if use_cache:
# if self.k_cache is None:
# # Initialize cache if empty
# self.k_cache = k
# self.v_cache = v
# else:
# # Concatenate new KV with cached KV
# self.k_cache = torch.cat([self.k_cache, k], dim=1)
# self.v_cache = torch.cat([self.v_cache, v], dim=1)
# # Use concatenated KV pairs
# k = self.k_cache
# v = self.v_cache
# # Update cache sequence length
# self.cache_seq_len += seq_len
        # Reshape for attention computation: (B, T, H, head_dim) -> (B, H, T, head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Handle GQA (Grouped Query Attention): replicate each KV head for its group of
        # query heads, going (B, H_kv, T, head_dim) -> (B, H_kv, n_rep, T, head_dim)
        # -> (B, H, T, head_dim), so the standard attention below works unchanged
        if self.num_queries_per_kv > 1:
            k = k.unsqueeze(2).expand(-1, -1, self.num_queries_per_kv, -1, -1)
            v = v.unsqueeze(2).expand(-1, -1, self.num_queries_per_kv, -1, -1)
            k = k.reshape(batch_size, self.num_heads, -1, self.head_dim)
            v = v.reshape(batch_size, self.num_heads, -1, self.head_dim)
# Compute attention
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        # Possible speed-up: Flash Attention via F.scaled_dot_product_attention (the
        # computation stays in GPU SRAM instead of materializing the full score matrix
        # in HBM). Since k and v are already expanded to num_heads above, it could
        # replace the manual attention block directly, e.g.:
        # out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
return self.o_proj(out)
def clear_cache(self):
self.k_cache = None
self.v_cache = None
self.cache_seq_len = 0
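
# Illustrative GQA check (hedged example; dim=576, num_heads=9, num_kv_heads=3 mirror
# a SmolLM2-135M-style setting, but any divisible combination works): head_dim is 64
# and each KV head is shared by num_queries_per_kv = 3 query heads, so the K/V
# projections (and any KV cache) are a third the size of full multi-head attention.
# attn = LlamaAttention(dim=576, num_heads=9, num_kv_heads=3)
# x = torch.randn(2, 16, 576)
# assert attn(x).shape == (2, 16, 576)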
class LlamaFFN(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
self.gate = nn.Linear(dim, hidden_dim, bias=False)
self.up = nn.Linear(dim, hidden_dim, bias=False)
self.down = nn.Linear(hidden_dim, dim, bias=False)
# self.down.NANGPT_SCALE_INIT = 1 # TODO do we need weight initialization scaling - Optimization ?
        self.act_fn = nn.SiLU()  # SiLU; gated with `up` below, this forms SwiGLU
def forward(self, x):
return self.down(self.act_fn(self.gate(x)) * self.up(x))
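
# The forward pass above is SwiGLU: FFN(x) = W_down(SiLU(W_gate x) * W_up x), with
# SiLU(z) = z * sigmoid(z). Illustrative shape check (hypothetical sizes):
# ffn = LlamaFFN(dim=64, hidden_dim=256)
# assert ffn(torch.randn(2, 8, 64)).shape == (2, 8, 64)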
class LlamaBlock(nn.Module):
def __init__(self, config):
        # config.nn_embed (the model dimension) is the size of the block's input and output
super().__init__()
self.attention = LlamaAttention(
config.nn_embed,
config.num_attention_heads,
config.num_key_value_heads,
config.max_sequence_len
)
self.feed_forward = LlamaFFN(config.nn_embed, config.ffn_intermediate_size)
self.attention_norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
self.ffn_norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        # Pre-norm residuals, as in Llama: normalize, transform, then add back to the stream
        x = x + self.attention(self.attention_norm(x), mask, use_cache)
        x = x + self.feed_forward(self.ffn_norm(x))
        return x
class SmolLM2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # kept for _init_weights (scaled-init branch)
        # Token embedding (positional information is injected via RoPE inside the attention layers)
        self.embedding = nn.Embedding(config.vocab_size, config.nn_embed)
        # num_hidden_layers transformer blocks (each with an attention and a feed-forward sublayer)
self.layers = nn.ModuleList([
LlamaBlock(config) for _ in range(config.num_hidden_layers)
])
self.norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
        # Final projection returning logits of shape (batch_size, seq_len, vocab_size)
self.lm_head = nn.Linear(config.nn_embed, config.vocab_size, bias=False)
        # Optimization: weight tying between lm_head and the embedding
self.lm_head.weight = self.embedding.weight
# Initialize weights
self.apply(self._init_weights)
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False, targets: Optional[torch.Tensor] = None):
        if mask is None:
            mask = self.create_causal_mask(x.shape[1], device=x.device)
x = self.embedding(x)
for layer in self.layers:
x = layer(x, mask, use_cache)
x = self.norm(x)
logits = self.lm_head(x)
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
return logits, loss
return logits
# Linear layers (attention projections, FFN layers, lm_head) are initialized from N(0, 0.02)
# Embedding layer is initialized from N(0, 0.02)
# All RMSNorm weights are initialized to 1.0
def _init_weights(self, module):
if isinstance(module, nn.Linear):
std = 0.02
            if hasattr(module, 'NANGPT_SCALE_INIT'):
                std *= (2 * self.config.num_hidden_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.RMSNorm):
torch.nn.init.ones_(module.weight)
elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def clear_cache(self):
"""Clear KV cache in all attention layers"""
for layer in self.layers:
layer.attention.clear_cache()
def create_causal_mask(self, seq_len, device):
"""Creates a causal attention mask where each position can only attend to previous positions"""
# Create lower triangular matrix (including diagonal)
# mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# mask = torch.triu(torch.ones(1, 1, seq_len, seq_len), diagonal=1).bool()
# # Invert and convert to float
# return (~mask).float()
        # Lower-triangular float mask: 1.0 where attention is allowed, 0.0 elsewhere
        return torch.tril(torch.ones(seq_len, seq_len, device=device)).view(1, 1, seq_len, seq_len)
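
    # Example: create_causal_mask(3, device)[0, 0] is
    #   [[1., 0., 0.],
    #    [1., 1., 0.],
    #    [1., 1., 1.]]
    # so token i attends only to tokens 0..i once the zeros are filled with -inf.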
@torch.no_grad()
def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 20,
temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
"""
Generate text using the model
Args:
input_ids: Starting token ids (B, T)
max_new_tokens: Number of tokens to generate
temperature: Controls randomness (1.0 = neutral, <1.0 = more deterministic, >1.0 = more random)
top_k: Number of highest probability tokens to consider for sampling
Returns:
Generated token ids (B, T+max_new_tokens)
"""
batch_size, seq_len = input_ids.shape
        # Clear any existing KV cache
self.clear_cache()
# Create a new tensor to store the generated tokens
input_ids = torch.cat([input_ids, torch.zeros((batch_size, max_new_tokens),
dtype=torch.long, device=input_ids.device)], dim=1)
# Generate tokens one at a time
for idx in range(max_new_tokens):
# print(f"Generating token {idx+1} of {max_new_tokens}")
# Get the current sequence length including cached tokens
current_seq_len = seq_len + idx
next_mask = self.create_causal_mask(current_seq_len, device=input_ids.device)
# Create mask that includes both the current input and cached tokens
# if idx == 0:
# # First iteration - create mask for the full input sequence
# next_mask = self.create_causal_mask(current_seq_len, device=input_ids.device)
# else:
# # Subsequent iterations - create mask for the new token attending to all previous tokens
# next_mask = torch.ones((1, 1, 1, current_seq_len), device=input_ids.device)
            # Recompute logits over the full sequence so far (KV caching is disabled in this path)
            logits = self(input_ids[:, :current_seq_len], next_mask, use_cache=False)
# Get the last token's logits
next_token_logits = logits[:, -1, :] / temperature
# Apply top-k filtering
top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
probs = F.softmax(top_k_logits, dim=-1)
# Sample from the filtered distribution
next_token = top_k_indices[
torch.arange(batch_size, device=input_ids.device),
torch.multinomial(probs, num_samples=1).squeeze(1)
]
# Update input_ids with the new token
input_ids[:, current_seq_len] = next_token
return input_ids
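
# Minimal usage sketch (hedged): a smoke test with a tiny hypothetical config. These
# values are illustrative only, not the real SmolLM2-135M hyperparameters, which live
# in the repo's config file.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class TinyConfig:
        vocab_size: int = 256
        nn_embed: int = 64
        num_hidden_layers: int = 2
        num_attention_heads: int = 4
        num_key_value_heads: int = 2
        max_sequence_len: int = 128
        ffn_intermediate_size: int = 256
        rms_norm_eps: float = 1e-5

    cfg = TinyConfig()
    model = SmolLM2(cfg)
    ids = torch.randint(0, cfg.vocab_size, (1, 8))
    out = model.generate(ids, max_new_tokens=4)
    print(out.shape)  # expected: torch.Size([1, 12])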