import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.utils import LlamaRotaryEmbedding, repeat_kv
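
# Notes on the imported helpers (assumed behavior, inferred from how they are used below):
#   - LlamaRotaryEmbedding(dim, max_seq_len) is called as rotary_embedding(q, k) and is expected
#     to return (q, k) with rotary position embeddings applied and shapes unchanged.
#   - repeat_kv(kv, n_rep) is expected to repeat each key/value head n_rep times, turning
#     (B, n_kv_heads, T, head_dim) into (B, n_kv_heads * n_rep, T, head_dim).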


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization: y = x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the root mean square over the last dimension, then rescale
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class Attention(nn.Module):
    """Multi-head attention module with support for GQA (Grouped Query Attention)."""

    def __init__(self, config):
        super().__init__()
        self.emb_dim = config.emb_dim
        self.n_q_heads = config.n_q_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.emb_dim // self.n_q_heads
        self.n_rep = self.n_q_heads // self.n_kv_heads

        # Projections for Q, K, V & O
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        self.k_proj = nn.Linear(self.emb_dim, self.head_dim * self.n_kv_heads, bias=False)
        self.v_proj = nn.Linear(self.emb_dim, self.head_dim * self.n_kv_heads, bias=False)
        self.o_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)

        # Rotary position embeddings
        self.rotary_embedding = LlamaRotaryEmbedding(
            dim=self.head_dim, max_seq_len=config.max_seq_len
        )

        # Dropout layers
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask, shaped (1, 1, max_seq_len, max_seq_len) to broadcast over batch and heads
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)).view(
                1, 1, config.max_seq_len, config.max_seq_len
            ),
        )

    def forward(self, x):
        B, T, C = x.size()  # batch_size, seq_len, emb_dim

        # Project Q, K, V
        q = self.q_proj(x)  # (B, T, emb_dim)
        k = self.k_proj(x)  # (B, T, n_kv_heads * head_dim)
        v = self.v_proj(x)  # (B, T, n_kv_heads * head_dim)

        # Split into heads
        q = q.view(B, T, self.n_q_heads, self.head_dim)   # (B, T, n_q_heads, head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)

        # Move the head dimension in front of the sequence dimension
        q = q.transpose(1, 2)  # (B, n_q_heads, T, head_dim)
        k = k.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)
        v = v.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)

        # Apply rotary embeddings
        q, k = self.rotary_embedding(q, k)

        # Repeat K and V so every query head has a matching key/value head (GQA)
        k = repeat_kv(k, self.n_rep)  # (B, n_q_heads, T, head_dim)
        v = repeat_kv(v, self.n_rep)  # (B, n_q_heads, T, head_dim)

        # Scaled dot-product attention with causal masking
        scale = 1.0 / math.sqrt(self.head_dim)
        att = (q @ k.transpose(-2, -1)) * scale  # (B, n_q_heads, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # Apply attention to values
        y = att @ v  # (B, n_q_heads, T, head_dim)

        # Merge heads and project back to the embedding dimension
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, emb_dim)
        y = self.o_proj(y)
        y = self.resid_dropout(y)
        return y
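
# Shape example (using the reference values emb_dim=576, n_q_heads=9, n_kv_heads=3 from the
# commented-out config at the bottom of this file): head_dim = 576 // 9 = 64 and
# n_rep = 9 // 3 = 3, so K/V are projected to 3 heads of size 64 and repeated 3x to match
# the 9 query heads; the module maps (B, T, 576) inputs to (B, T, 576) outputs.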


class FeedForward(nn.Module):
    """Gated feed-forward module (SwiGLU-style): down_proj(SiLU(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        # Gate and up-projections map from the embedding size to the intermediate size
        self.gate_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        # Down projection brings the dimension back to the embedding size
        self.down_proj = nn.Linear(config.intermediate_size, config.emb_dim, bias=False)
        # SiLU activation function
        self.act_fn = F.silu
        # Dropout layer
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Apply the gate projection (with SiLU) and the up projection
        gate_output = self.act_fn(self.gate_proj(x))
        up_output = self.up_proj(x)
        # Element-wise multiplication of gate and up projections
        intermediate_output = gate_output * up_output
        # Project back to the embedding size
        output = self.down_proj(intermediate_output)
        output = self.dropout(output)
        return output


class TransformerBlock(nn.Module):
    """Transformer block: pre-norm residual attention followed by a pre-norm residual feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.input_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.attention_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)

    def forward(self, x):
        x = x + self.attention(self.input_layernorm(x))
        x = x + self.feed_forward(self.attention_layernorm(x))
        return x


class SmolLM(nn.Module):
    """Small language model built from a stack of transformer blocks."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.emb_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )
        self.layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        # Weight sharing between the token embedding and the output head
        self.lm_head.weight = self.wte.weight
        self.apply(self._init_weights)

    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, RMSNorm):
            # The model uses RMSNorm (no bias); reset its scale to 1.0
            module.weight.data.fill_(1.0)

    def forward(self, x):
        x = self.wte(x)  # (B, T) -> (B, T, emb_dim)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.layernorm(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits


# Example configuration:
# from dataclasses import dataclass
#
# @dataclass
# class Config:
#     vocab_size: int = 49152
#     emb_dim: int = 576
#     intermediate_size: int = 1536
#     num_layers: int = 10
#     n_q_heads: int = 9
#     n_kv_heads: int = 3
#     max_seq_len: int = 8192
#     dropout: float = 0.1
#     rms_norm_eps: float = 1e-05
#     init_std: float = 0.041666666666666664
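

if __name__ == "__main__":
    # Minimal smoke-test sketch. It assumes `src.utils` provides the LlamaRotaryEmbedding and
    # repeat_kv helpers imported above, and a SimpleNamespace stands in for the commented-out
    # Config dataclass. max_seq_len is reduced here to keep the per-layer causal-mask buffers small.
    from types import SimpleNamespace

    config = SimpleNamespace(
        vocab_size=49152,
        emb_dim=576,
        intermediate_size=1536,
        num_layers=10,
        n_q_heads=9,
        n_kv_heads=3,
        max_seq_len=256,  # reduced from 8192 for a lightweight test
        dropout=0.1,
        rms_norm_eps=1e-05,
        init_std=0.041666666666666664,
    )

    model = SmolLM(config)
    print(f"Total trainable parameters: {model.total_params():,}")

    # Forward pass on a batch of random token ids
    tokens = torch.randint(0, config.vocab_size, (2, 16))  # (batch_size=2, seq_len=16)
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 49152])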