import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# This Llama model is based on the paper: https://arxiv.org/pdf/2302.13971.pdf
# Model Architecture: static/llamaModel.jpg
# It is a transformer model with rotary position embeddings (RoPE) and the SwiGLU
# activation function. It uses RMSNorm for normalization.
# Other good reads: https://pub.towardsai.net/llama-explained-a70e71e706e9
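# Quick reference: RoPE rotates each (even, odd) feature pair (x_{2i}, x_{2i+1}) at
# position p by the angle theta_i = p * base^(-2i/dim), i.e. a plain 2D rotation per
# pair, which makes attention scores depend on relative token positions.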


def precompute_rotary_emb(dim: int, max_seq_len: int, base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute the rotary position embeddings.

    Args:
        dim: Dimension of the embeddings
        max_seq_len: Maximum sequence length
        base: Base for the angle calculations
    Returns:
        Tuple of (sin, cos) tensors of shape (max_seq_len, dim//2)
    """
    # Create position indices tensor
    position = torch.arange(max_seq_len).unsqueeze(1)  # (seq_len, 1)
    # Create dimension indices tensor
    div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(base) / dim))  # (dim//2)
    # Compute angles
    angles = position * div_term  # (seq_len, dim//2)
    # Return sin and cos
    return torch.sin(angles), torch.cos(angles)
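
# Example (sanity check, illustrative sizes): precompute_rotary_emb(64, 8) returns
# sin and cos each of shape (8, 32); row 0 corresponds to position 0, so sin[0] is
# all zeros and cos[0] is all ones.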


def apply_rotary_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary position embeddings to the input tensor.

    Args:
        x: Input tensor of shape (batch_size, seq_len, num_heads, head_dim)
        sin: Sine tensor of shape (seq_len, head_dim//2)
        cos: Cosine tensor of shape (seq_len, head_dim//2)
    Returns:
        Tensor with rotary position embeddings applied
    """
    # Reshape x to split the last dimension into (even, odd) pairs
    x_reshape = x.float().reshape(*x.shape[:-1], -1, 2)
    # Extract even and odd dimensions
    x1, x2 = x_reshape[..., 0], x_reshape[..., 1]
    # Reshape sin and cos for broadcasting
    sin = sin.view(1, sin.shape[0], 1, sin.shape[1])  # (1, seq_len, 1, dim//2)
    cos = cos.view(1, cos.shape[0], 1, cos.shape[1])  # (1, seq_len, 1, dim//2)
    # Apply rotation using the rotation matrix multiplication
    result = torch.stack([
        x1 * cos - x2 * sin,
        x2 * cos + x1 * sin
    ], dim=-1)
    return result.flatten(-2)  # Flatten the last two dimensions back to head_dim
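
# Rough usage sketch (illustrative shapes, not part of the original module):
#   q = torch.randn(1, 8, 4, 64)               # (batch, seq_len, num_heads, head_dim)
#   sin, cos = precompute_rotary_emb(64, 8)    # head_dim=64, seq_len=8
#   q_rot = apply_rotary_emb(q, sin, cos)      # same shape; each (even, odd) pair is rotated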


class LlamaAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, num_kv_heads: Optional[int] = None, max_position_embeddings=2048):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # self.q_proj = nn.Linear(dim, dim, bias=False)
        # self.k_proj = nn.Linear(dim, dim, bias=False)
        # self.v_proj = nn.Linear(dim, dim, bias=False)
        # Adjust projections for GQA: queries keep num_heads, keys/values use num_kv_heads
        self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, H * D/H)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, H_kv * D/H)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, H_kv * D/H)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        # self.o_proj.NANGPT_SCALE_INIT = 1  # TODO: do we need weight initialization scaling?
        # KV-cache attributes
        self.k_cache = None
        self.v_cache = None
        self.cache_seq_len = 0
        # Precompute sin and cos for all positions
        self.sin, self.cos = precompute_rotary_emb(self.head_dim, max_position_embeddings)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        batch_size, seq_len, _ = x.shape
        # Project inputs
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        # Get rotary embeddings for the new tokens
        # sin = self.sin[self.cache_seq_len:self.cache_seq_len + seq_len].to(x.device)
        # cos = self.cos[self.cache_seq_len:self.cache_seq_len + seq_len].to(x.device)
        sin = self.sin[:seq_len].to(x.device)
        cos = self.cos[:seq_len].to(x.device)
        # Apply rotary embeddings
        q = apply_rotary_emb(q, sin, cos)
        k = apply_rotary_emb(k, sin, cos)
        # Handle KV caching
        # if use_cache:
        #     if self.k_cache is None:
        #         # Initialize cache if empty
        #         self.k_cache = k
        #         self.v_cache = v
        #     else:
        #         # Concatenate new KV with cached KV
        #         self.k_cache = torch.cat([self.k_cache, k], dim=1)
        #         self.v_cache = torch.cat([self.v_cache, v], dim=1)
        #     # Use concatenated KV pairs
        #     k = self.k_cache
        #     v = self.v_cache
        #     # Update cache sequence length
        #     self.cache_seq_len += seq_len
        # Reshape for attention computation: (B, T, H, D/H) -> (B, H, T, D/H)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Handle GQA (Grouped Query Attention): repeat each KV head for its group of query heads
        if self.num_queries_per_kv > 1:
            k = k.unsqueeze(2).expand(-1, -1, self.num_queries_per_kv, -1, -1)
            v = v.unsqueeze(2).expand(-1, -1, self.num_queries_per_kv, -1, -1)
            k = k.reshape(batch_size, self.num_heads, -1, self.head_dim)
            v = v.reshape(batch_size, self.num_heads, -1, self.head_dim)
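            # Shape walk-through (illustrative sizes): with num_heads=8, num_kv_heads=2 and
            # head_dim=64, k goes (B, 2, T, 64) -> (B, 2, 1, T, 64) -> expand to
            # (B, 2, 4, T, 64) -> reshape to (B, 8, T, 64), so each group of 4 query heads
            # shares one KV head (equivalent to repeat_interleave on the head dimension).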
        # Compute attention
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        # Speed up - Flash Attention (computation stays in GPU SRAM rather than HBM).
        # TODO: not sure how to apply this with grouped query attention?
        # out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out)

    def clear_cache(self):
        self.k_cache = None
        self.v_cache = None
        self.cache_seq_len = 0


class LlamaFFN(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        # self.down.NANGPT_SCALE_INIT = 1  # TODO: do we need weight initialization scaling - optimization?
        self.act_fn = nn.SiLU()  # SiLU; combined with the gate * up product below this forms SwiGLU

    def forward(self, x):
        return self.down(self.act_fn(self.gate(x)) * self.up(x))
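
    # In formula form (matching the forward pass above):
    #   FFN(x) = W_down( SiLU(W_gate x) * (W_up x) ),  with SiLU(z) = z * sigmoid(z)
    # i.e. the SwiGLU variant Llama uses instead of a plain two-layer ReLU/GELU MLP.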


class LlamaBlock(nn.Module):
    def __init__(self, config):
        # nn_embed (dim) is the dimension of the input to the block
        super().__init__()
        self.attention = LlamaAttention(
            config.nn_embed,
            config.num_attention_heads,
            config.num_key_value_heads,
            config.max_sequence_len
        )
        self.feed_forward = LlamaFFN(config.nn_embed, config.ffn_intermediate_size)
        self.attention_norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
        self.ffn_norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        # Pre-norm residual connections around attention and feed-forward
        x = x + self.attention(self.attention_norm(x), mask, use_cache)
        x = x + self.feed_forward(self.ffn_norm(x))
        return x


class SmolLM2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # kept for use in _init_weights
        # Token embedding (positional information is injected via RoPE inside the attention layers)
        self.embedding = nn.Embedding(config.vocab_size, config.nn_embed)
        # num_hidden_layers blocks (each block has an attention and a feed-forward layer)
        self.layers = nn.ModuleList([
            LlamaBlock(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
        # Final layer returning logits of size (batch_size, seq_len, vocab_size)
        self.lm_head = nn.Linear(config.nn_embed, config.vocab_size, bias=False)
        # Optimization: weight sharing between lm_head and embedding
        self.lm_head.weight = self.embedding.weight
        # Initialize weights
        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False, targets: Optional[torch.Tensor] = None):
        if mask is None:
            mask = self.create_causal_mask(x.shape[1], device=x.device)
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, mask, use_cache)
        x = self.norm(x)
        logits = self.lm_head(x)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
            return logits, loss
        return logits

    # Linear layers (attention projections, FFN layers, lm_head) are initialized from N(0, 0.02)
    # The embedding layer is initialized from N(0, 0.02)
    # All RMSNorm weights are initialized to 1.0
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANGPT_SCALE_INIT'):
                # Scale residual projections down by sqrt(2 * num_layers), GPT-2 style
                std *= (2 * self.config.num_hidden_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.RMSNorm):
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def clear_cache(self):
        """Clear the KV cache in all attention layers"""
        for layer in self.layers:
            layer.attention.clear_cache()

    def create_causal_mask(self, seq_len, device):
        """Creates a causal attention mask where each position can only attend to itself and previous positions"""
        # Create lower triangular matrix (including the diagonal)
        # mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        # mask = torch.triu(torch.ones(1, 1, seq_len, seq_len), diagonal=1).bool()
        # # Invert and convert to float
        # return (~mask).float()
        return torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len).to(device)
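
    # Example (illustrative): create_causal_mask(3, 'cpu') yields a (1, 1, 3, 3) tensor
    #   [[1, 0, 0],
    #    [1, 1, 0],
    #    [1, 1, 1]]
    # whose zeros are replaced with -inf in the attention scores before the softmax.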

    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 20,
                 temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
        """
        Generate text using the model.

        Args:
            input_ids: Starting token ids (B, T)
            max_new_tokens: Number of tokens to generate
            temperature: Controls randomness (1.0 = neutral, <1.0 = more deterministic, >1.0 = more random)
            top_k: Number of highest-probability tokens to consider for sampling
        Returns:
            Generated token ids (B, T + max_new_tokens)
        """
        batch_size, seq_len = input_ids.shape
        # Clear any existing KV cache
        self.clear_cache()
        # Extend input_ids with empty slots for the tokens to be generated
        input_ids = torch.cat([input_ids, torch.zeros((batch_size, max_new_tokens),
                                                      dtype=torch.long, device=input_ids.device)], dim=1)
        # Generate tokens one at a time
        for idx in range(max_new_tokens):
            # print(f"Generating token {idx+1} of {max_new_tokens}")
            # Get the current sequence length including previously generated tokens
            current_seq_len = seq_len + idx
            next_mask = self.create_causal_mask(current_seq_len, device=input_ids.device)
            # Create mask that includes both the current input and cached tokens
            # if idx == 0:
            #     # First iteration - create mask for the full input sequence
            #     next_mask = self.create_causal_mask(current_seq_len, device=input_ids.device)
            # else:
            #     # Subsequent iterations - create mask for the new token attending to all previous tokens
            #     next_mask = torch.ones((1, 1, 1, current_seq_len), device=input_ids.device)
            # Re-process the full sequence generated so far (no KV cache used here)
            logits = self(input_ids[:, :current_seq_len], next_mask, use_cache=False)
            # Get the last token's logits and apply temperature
            next_token_logits = logits[:, -1, :] / temperature
            # Apply top-k filtering
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            probs = F.softmax(top_k_logits, dim=-1)
            # Sample from the filtered distribution
            next_token = top_k_indices[
                torch.arange(batch_size, device=input_ids.device),
                torch.multinomial(probs, num_samples=1).squeeze(1)
            ]
            # Write the new token into the preallocated slot
            input_ids[:, current_seq_len] = next_token
        return input_ids
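

# Minimal smoke-test sketch (not part of the original module). The config values below
# are illustrative assumptions only, not the real SmolLM2 hyperparameters.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class DemoConfig:
        vocab_size: int = 1000
        nn_embed: int = 64
        num_hidden_layers: int = 2
        num_attention_heads: int = 4
        num_key_value_heads: int = 2
        max_sequence_len: int = 128
        ffn_intermediate_size: int = 256
        rms_norm_eps: float = 1e-5

    config = DemoConfig()
    model = SmolLM2(config)

    # Forward pass with targets returns (logits, loss)
    tokens = torch.randint(0, config.vocab_size, (2, 16))
    logits, loss = model(tokens, targets=tokens)
    print(logits.shape, loss.item())  # torch.Size([2, 16, 1000]) and a scalar loss

    # Sample a few new tokens from a short prompt
    generated = model.generate(tokens[:, :4], max_new_tokens=8, temperature=0.8, top_k=20)
    print(generated.shape)  # torch.Size([2, 12])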