import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from src.utils import LlamaRotaryEmbedding, repeat_kv


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Root Mean Square Layer Normalization: scale x by 1 / RMS(x)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


class Attention(nn.Module):
    """Multi-head attention module with support for GQA (Grouped Query Attention)."""

    def __init__(self, config):
        super().__init__()
        self.emb_dim = config.emb_dim
        self.n_q_heads = config.n_q_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.emb_dim // self.n_q_heads
        self.n_rep = self.n_q_heads // self.n_kv_heads
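        # With the commented-out Config at the bottom of this file
        # (emb_dim=576, n_q_heads=9, n_kv_heads=3) this gives head_dim=64 and
        # n_rep=3, i.e. each key/value head is shared by 3 query heads (GQA).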

        # Projections for Q, K, V & O
        self.q_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)
        self.k_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.v_proj = nn.Linear(
            self.emb_dim, self.head_dim * self.n_kv_heads, bias=False
        )
        self.o_proj = nn.Linear(self.emb_dim, self.emb_dim, bias=False)

        # Initialize rotary embeddings
        self.rotary_embedding = LlamaRotaryEmbedding(
            dim=self.head_dim, max_seq_len=config.max_seq_len
        )

        # Dropout layers
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)).view(
                1, 1, config.max_seq_len, config.max_seq_len
            ),
        )

    def forward(self, x):
        B, T, C = x.size()  # batch_size, seq_len, emb_dim

        # Project Q, K, V
        q = self.q_proj(x)  # (B, T, emb_dim)
        k = self.k_proj(x)  # (B, T, n_kv_heads * head_dim)
        v = self.v_proj(x)  # (B, T, n_kv_heads * head_dim)

        # Reshape Q, K, V
        q = q.view(B, T, self.n_q_heads, self.head_dim)  # (B, T, n_q_heads, head_dim)
        k = k.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)
        v = v.view(B, T, self.n_kv_heads, self.head_dim)  # (B, T, n_kv_heads, head_dim)

        # Reshape for attention computation
        q = q.transpose(1, 2)  # (B, n_q_heads, T, head_dim)
        k = k.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)
        v = v.transpose(1, 2)  # (B, n_kv_heads, T, head_dim)

        # Apply rotary embeddings
        q, k = self.rotary_embedding(q, k)

        # Repeat K and V for GQA
        k = repeat_kv(k, self.n_rep)  # (B, n_q_heads, T, head_dim)
        v = repeat_kv(v, self.n_rep)  # (B, n_q_heads, T, head_dim)

        # Compute attention scores
        scale = 1.0 / math.sqrt(self.head_dim)
        att = (q @ k.transpose(-2, -1)) * scale  # (B, n_q_heads, T, T)
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        # Apply attention to values
        y = att @ v  # (B, n_q_heads, T, head_dim)

        # Reshape and project output
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, emb_dim)
        y = self.o_proj(y)
        y = self.resid_dropout(y)

        return y


class FeedForward(nn.Module):
    """Feed-forward module with SiLU activation."""

    def __init__(self, config):
        super().__init__()
        # Gate and up-projections project from hidden_size to intermediate_size
        self.gate_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.emb_dim, config.intermediate_size, bias=False)

        # Down projection brings the dimension back to hidden_size
        self.down_proj = nn.Linear(config.intermediate_size, config.emb_dim, bias=False)

        # SiLU activation function
        self.act_fn = F.silu

        # Dropout layer
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Apply gate and up projections
        gate_output = self.act_fn(self.gate_proj(x))  # SiLU activation
        up_output = self.up_proj(x)

        # Element-wise multiplication of gate and up projections
        intermediate_output = gate_output * up_output

        # Project back to hidden size
        output = self.down_proj(intermediate_output)
        output = self.dropout(output)

        return output


class TransformerBlock(nn.Module):
    """Transformer block with attention and feed-forward modules."""

    def __init__(self, config):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.input_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.attention_layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)

    def forward(self, x):
        x = x + self.attention(self.input_layernorm(x))
        x = x + self.feed_forward(self.attention_layernorm(x))

        return x


class SmolLM(nn.Module):
    """Small language model with transformer blocks."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.emb_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )
        self.layernorm = RMSNorm(config.emb_dim, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)

        # Initialize all submodules (the final norm must exist before apply)
        self.apply(self._init_weights)

        # Weight sharing: tie the LM head to the token embedding matrix
        self.lm_head.weight = self.wte.weight

    def total_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, RMSNorm):
            # RMSNorm has no bias; reset its scale to ones
            module.weight.data.fill_(1.0)

    def forward(self, x):
        x = self.wte(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.layernorm(x)
        logits = self.lm_head(x)
        return logits


# @dataclass
# class Config:
#     vocab_size: int = 49152
#     emb_dim: int = 576
#     intermediate_size: int = 1536
#     num_layers: int = 10
#     n_q_heads: int = 9
#     n_kv_heads: int = 3
#     max_seq_len: int = 8192
#     dropout: float = 0.1
#     rms_norm_eps: float = 1e-05
#     init_std: float = 0.041666666666666664
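

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the repo root is on PYTHONPATH so that
# `src.utils` resolves, and `_SmokeConfig` below is a hypothetical stand-in
# mirroring the commented-out Config above (num_layers and max_seq_len are
# reduced to keep the smoke test cheap).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _SmokeConfig:
        vocab_size: int = 49152
        emb_dim: int = 576
        intermediate_size: int = 1536
        num_layers: int = 2           # reduced from 10
        n_q_heads: int = 9
        n_kv_heads: int = 3
        max_seq_len: int = 256        # reduced from 8192
        dropout: float = 0.1
        rms_norm_eps: float = 1e-05
        init_std: float = 0.041666666666666664

    config = _SmokeConfig()
    model = SmolLM(config)
    print(f"Total parameters: {model.total_params():,}")

    # Forward pass on a random batch of token ids: (batch=2, seq_len=16)
    tokens = torch.randint(0, config.vocab_size, (2, 16))
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 49152])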