gitesh-grover committed
Commit 960a17b · verified · 1 Parent(s): 8b4bb02

Upload 6 files

Files changed (6)
  1. README.md +226 -7
  2. app.py +61 -0
  3. config.py +40 -0
  4. model.py +306 -0
  5. requirements.txt +8 -0
  6. utils.py +26 -0
README.md CHANGED
@@ -1,13 +1,232 @@
  ---
- title: SmolLM2 135m
- emoji: 🌖
- colorFrom: yellow
- colorTo: gray
  sdk: gradio
- sdk_version: 5.15.0
  app_file: app.py
  pinned: false
- short_description: Demo SmolLM2-135m model trained for only 5k steps
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: SmolLM2 135M Text Generation Demo
+ emoji: 📚
+ colorFrom: blue
+ colorTo: red
  sdk: gradio
+ sdk_version: 3.50.2
  app_file: app.py
  pinned: false
  ---

+ # SmolLM2 Text Generation Demo
+
+ This is a simple text generation demo using the SmolLM2 language model with a Gradio interface.
+
+ ## Description
+
+ This application provides a web interface for text generation using the SmolLM2 language model. Users can enter a prompt and adjust the generation parameters to control the output.
+
+ ## Features
+
+ - Interactive web interface built with Gradio
+ - Adjustable generation parameters:
+   - Maximum new tokens (1-150)
+   - Temperature (0.1-2.0)
+   - Top-K sampling (1-100)
+ - Real-time text generation
+
+ ## Usage
+
+ 1. Enter your prompt in the text input field
+ 2. Adjust the generation parameters (optional):
+    - **Max New Tokens**: controls the length of the generated text
+    - **Temperature**: controls randomness (higher = more creative, lower = more focused)
+    - **Top-K**: controls the diversity of word choices
+ 3. Click Submit to generate text (see the programmatic sketch below for the equivalent Python calls)
+
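+ If you prefer to skip the UI, the same generation path used by `app.py` can be called directly. A minimal sketch, assuming the trained checkpoint `checkpoints/model_final.pt` is present locally:
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer
+
+ from config import Config
+ from model import SmolLM2
+ from utils import get_device
+
+ config = Config()
+ device = get_device(config.seed)
+
+ # Load the trained weights on CPU first, then move to the selected device
+ model = SmolLM2(config)
+ model.load_state_dict(torch.load(config.checkpoints_path + "/model_final.pt", map_location="cpu"))
+ model.to(device)
+ model.eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
+ input_ids = tokenizer.encode("What is Gravity?", return_tensors="pt").to(device)
+
+ with torch.no_grad():
+     output_ids = model.generate(input_ids, max_new_tokens=30, temperature=0.8, top_k=50)
+ print(tokenizer.decode(output_ids[0].cpu(), skip_special_tokens=True))
+ ```
+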
+ ## Installation
+
+ 1. Clone the repository
+ 2. Install the dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Run the application
+ ```bash
+ python app.py
+ ```
+ The interface will be available at `http://localhost:7860`
+
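+ Once the app is running, it can also be queried from Python with the `gradio_client` package (not listed in `requirements.txt`; install it separately). A minimal sketch, assuming the default `/predict` endpoint that `gr.Interface` exposes:
+
+ ```python
+ from gradio_client import Client
+
+ client = Client("http://localhost:7860")
+ # Arguments follow the Interface inputs: prompt, max new tokens, temperature, top-k
+ result = client.predict("What is Gravity?", 30, 0.8, 50, api_name="/predict")
+ print(result)
+ ```
+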
+ ## Train the model
+ ```bash
+ python train.py
+ ```
+
+ # Model details
+ SmolLM2 is a 135M-parameter, Llama-style language model; the checkpoint served here was trained for only 5k steps as a demo. The model uses the HuggingFaceTB/cosmo2-tokenizer tokenizer, loaded via Hugging Face's transformers library (see `config.py`).
+
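+ For reference, the tokenizer can be inspected on its own; the model name below comes from `tokenizer_name_or_path` in `config.py`:
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
+ print(tokenizer.vocab_size)               # should match Config.vocab_size = 49152
+ print(tokenizer.encode("What is Gravity?"))
+ ```
+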
+ ## Llama 2 Architecture
+
+ ![Llama 2 Architecture](./static/llamaModel.jpg)
+ Read https://pub.towardsai.net/llama-explained-a70e71e706e9 for more details.
+
+ # Compare custom SmolLM2-135M with HuggingFaceTB/SmolLM2-135M
+ HuggingFaceTB/SmolLM2-135M:
+ ```bash
+ LlamaForCausalLM(
+   (model): LlamaModel(
+     (embed_tokens): Embedding(49152, 576)
+     (layers): ModuleList(
+       (0-29): 30 x LlamaDecoderLayer(
+         (self_attn): LlamaAttention(
+           (q_proj): Linear(in_features=576, out_features=576, bias=False)
+           (k_proj): Linear(in_features=576, out_features=192, bias=False)
+           (v_proj): Linear(in_features=576, out_features=192, bias=False)
+           (o_proj): Linear(in_features=576, out_features=576, bias=False)
+         )
+         (mlp): LlamaMLP(
+           (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
+           (up_proj): Linear(in_features=576, out_features=1536, bias=False)
+           (down_proj): Linear(in_features=1536, out_features=576, bias=False)
+           (act_fn): SiLU()
+         )
+         (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+         (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
+       )
+     )
+     (norm): LlamaRMSNorm((576,), eps=1e-05)
+     (rotary_emb): LlamaRotaryEmbedding()
+   )
+   (lm_head): Linear(in_features=576, out_features=49152, bias=False)
+ )
+ ```
+
+ Custom SmolLM2-135M:
+ ```bash
+ SmolLM2(
+   (embedding): Embedding(49152, 576)
+   (layers): ModuleList(
+     (0-29): 30 x LlamaBlock(
+       (attention): LlamaAttention(
+         (q_proj): Linear(in_features=576, out_features=576, bias=False)
+         (k_proj): Linear(in_features=576, out_features=192, bias=False)
+         (v_proj): Linear(in_features=576, out_features=192, bias=False)
+         (o_proj): Linear(in_features=576, out_features=576, bias=False)
+       )
+       (feed_forward): LlamaFFN(
+         (gate): Linear(in_features=576, out_features=1536, bias=False)
+         (up): Linear(in_features=576, out_features=1536, bias=False)
+         (down): Linear(in_features=1536, out_features=576, bias=False)
+         (act_fn): SiLU()
+       )
+       (attention_norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)
+       (ffn_norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)
+     )
+   )
+   (norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)
+   (lm_head): Linear(in_features=576, out_features=49152, bias=False)
+ )
+ ```
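+
+ Both printouts can be reproduced with a few lines (the reference model download assumes network access to the Hugging Face Hub):
+
+ ```python
+ from transformers import AutoModelForCausalLM
+
+ from config import Config
+ from model import SmolLM2
+
+ print(AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M"))  # reference model
+ print(SmolLM2(Config()))                                                   # custom model
+ ```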
+
+ # Training Logs
+ ## Training with 5000 steps (without checkpoint)
+ ```bash
+ (venv) gitesh.grover@Giteshs-MacBook-Pro ai-era-assignment13 % python train.py
+
+ Resolving data files: 100%|██████████| 104/104 [00:00<00:00, 720.56it/s]
+ Resolving data files: 100%|██████████| 104/104 [00:00<00:00, 562123.22it/s]
+ Epoch: 0, Step: 0, Batch: 0, Loss: 10.9101, Time: 1.44s, Token/s: 2842.75
+ Saved checkpoint at step 0
+ What is Gravity? thymopenedi something aneur checklist fertiliserlete hiding Watching [[GuardinnamonGuard thym thym something multilinguali runway astronlighten runwayinnamon nastylighten disadvant snout plumquest
+ Epoch: 0, Step: 1, Batch: 1, Loss: 10.6729, Time: 2.00s, Token/s: 2044.98
+ Epoch: 0, Step: 2, Batch: 2, Loss: 9.2034, Time: 1.16s, Token/s: 3517.56
+ Epoch: 0, Step: 3, Batch: 3, Loss: 8.5723, Time: 1.09s, Token/s: 3766.14
+ Epoch: 0, Step: 4, Batch: 4, Loss: 8.1478, Time: 1.07s, Token/s: 3845.85
+ :
+ :
+ Epoch: 0, Step: 500, Batch: 500, Loss: 5.9723, Time: 1.07s, Token/s: 3825.45
+ Saved checkpoint at step 500
+ What is Gravity? We call us to use, I can create a `e` function to do to add a few to calculate their lives.
+ * An the need
+ Epoch: 0, Step: 501, Batch: 501, Loss: 6.0491, Time: 1.58s, Token/s: 2595.98
+ :
+ :
+ Epoch: 0, Step: 998, Batch: 998, Loss: 5.8647, Time: 1.25s, Token/s: 3289.61
+ Epoch: 0, Step: 999, Batch: 999, Loss: 6.0096, Time: 1.10s, Token/s: 3726.16
+ Epoch: 0, Step: 1000, Batch: 1000, Loss: 6.4388, Time: 1.09s, Token/s: 3763.74
+ Saved checkpoint at step 1000
+ What is Gravity? These tales of sharing a beautiful blend of the art, where will understand these questions where remain.
+
+ III. **4.g., the Individuals
+ :
+ :
+ Epoch: 0, Step: 1498, Batch: 1498, Loss: 7.3296, Time: 1.06s, Token/s: 3878.60
+ Epoch: 0, Step: 1499, Batch: 1499, Loss: 6.0611, Time: 1.06s, Token/s: 3864.26
+ Epoch: 0, Step: 1500, Batch: 1500, Loss: 6.1140, Time: 1.08s, Token/s: 3789.80
+ Saved checkpoint at step 1500
+ What is Gravity?
+
+ Now imagine don't forget, "It have been the game?" But there are just as an 'L', does not can he noticed,
+
+ :
+ :
+ :
+ :
+
+ Epoch: 0, Step: 3498, Batch: 3498, Loss: 5.7145, Time: 1.07s, Token/s: 3830.33
+ Epoch: 0, Step: 3499, Batch: 3499, Loss: 5.7578, Time: 1.09s, Token/s: 3767.61
+ Epoch: 0, Step: 3500, Batch: 3500, Loss: 6.0798, Time: 1.07s, Token/s: 3811.98
+ Saved checkpoint at step 3500
+ What is Gravity? Let's how a "P"? You might need to play and a new environment that makes it up a big planet of the whole piece of the information
+ Epoch: 0, Step: 3501, Batch: 3501, Loss: 5.8375, Time: 1.47s, Token/s: 2790.70
+ Epoch: 0, Step: 3502, Batch: 3502, Loss: 6.3435, Time: 1.07s, Token/s: 3838.95
+ Epoch: 0, Step: 3503, Batch: 3503, Loss: 5.8192, Time: 1.05s, Token/s: 3901.14
+
+ :
+ :
+ Epoch: 0, Step: 4496, Batch: 4496, Loss: 5.5488, Time: 1.06s, Token/s: 3862.06
+ Epoch: 0, Step: 4497, Batch: 4497, Loss: 5.8281, Time: 1.07s, Token/s: 3821.71
+ Epoch: 0, Step: 4498, Batch: 4498, Loss: 5.5703, Time: 1.07s, Token/s: 3844.92
+ Epoch: 0, Step: 4499, Batch: 4499, Loss: 6.0630, Time: 1.06s, Token/s: 3854.04
+ Epoch: 0, Step: 4500, Batch: 4500, Loss: 5.5889, Time: 1.06s, Token/s: 3860.19
+ Saved checkpoint at step 4500
+ What is Gravity?
+
+ V. **Additional 2: Prepare a Power
+
+ * **I and the Eaught of Life
+
+ Before our exploration, understanding
+ :
+ :
+ Epoch: 0, Step: 4996, Batch: 4996, Loss: 6.1501, Time: 1.06s, Token/s: 3865.19
+ Epoch: 0, Step: 4997, Batch: 4997, Loss: 5.9107, Time: 1.05s, Token/s: 3884.67
+ Epoch: 0, Step: 4998, Batch: 4998, Loss: 5.7005, Time: 1.07s, Token/s: 3834.26
+ Epoch: 0, Step: 4999, Batch: 4999, Loss: 5.8820, Time: 1.07s, Token/s: 3814.07
+ Saved final checkpoint
+ What is Gravity? You would be a better big way, there are people have just like!
+
+ As they saw out to the world in the world or making a
+ Training complete
+ ```
+
+ ## Training with Additional 50 steps (with checkpoint)
+ ```bash
+ Loading checkpoint from checkpoints/checkpoint_final.pt
+ Resuming from epoch 0 at step 5000 with loss 5.881985664367676
+ Resolving data files: 100%|██████████| 104/104 [00:00<00:00, 313.79it/s]
+ Resolving data files: 100%|██████████| 104/104 [00:00<00:00, 462574.35it/s]
+ Epoch: 0, Step: 5000, Batch: 0, Loss: 5.6473, Time: 2.69s, Token/s: 1520.90
+ Saved checkpoint at step 5000
+ What is Gravity? Well, remember, there's where those who do something as part of art and animals, family around us. For instance, there's like! But
+ Epoch: 0, Step: 5001, Batch: 1, Loss: 6.1124, Time: 1.54s, Token/s: 2660.36
+ Epoch: 0, Step: 5002, Batch: 2, Loss: 5.8381, Time: 1.11s, Token/s: 3680.22
+ :
+ :
+ Epoch: 0, Step: 5044, Batch: 44, Loss: 6.1118, Time: 1.09s, Token/s: 3749.53
+ Epoch: 0, Step: 5045, Batch: 45, Loss: 5.8618, Time: 1.11s, Token/s: 3676.88
+ Epoch: 0, Step: 5046, Batch: 46, Loss: 5.8893, Time: 1.08s, Token/s: 3784.70
+ Epoch: 0, Step: 5047, Batch: 47, Loss: 5.7507, Time: 1.10s, Token/s: 3729.83
+ Epoch: 0, Step: 5048, Batch: 48, Loss: 5.6882, Time: 1.10s, Token/s: 3715.38
+ Epoch: 0, Step: 5049, Batch: 49, Loss: 5.7396, Time: 1.09s, Token/s: 3745.38
+ Saved final checkpoint
+ What is Gravity? Have you would be wondering what life, you don't just how to do? She needed, they have had to know that "but these things has
+ Training complete
+ ```
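+
+ `train.py` itself is not part of this upload, so the exact checkpoint format is an assumption; the save/resume pattern implied by the logs above would look roughly like this (hypothetical keys `model`, `step`, `epoch`, `loss`):
+
+ ```python
+ import torch
+
+ # Saving (inside the training loop)
+ torch.save({"model": model.state_dict(), "step": step, "epoch": epoch, "loss": loss.item()},
+            f"checkpoints/checkpoint_{step}.pt")
+
+ # Resuming
+ ckpt = torch.load("checkpoints/checkpoint_final.pt", map_location="cpu")
+ model.load_state_dict(ckpt["model"])
+ start_step = ckpt["step"] + 1
+ print(f"Resuming from epoch {ckpt['epoch']} at step {start_step} with loss {ckpt['loss']}")
+ ```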
app.py ADDED
@@ -0,0 +1,61 @@
+ import gradio as gr
+ import torch
+ from model import SmolLM2
+ from transformers import AutoTokenizer
+ from config import Config
+ from utils import get_device
+
+ # Initialize model and tokenizer
+ config = Config()
+ device = get_device(config.seed)
+ print("device: ", device)
+
+ def load_model():
+     model = SmolLM2(config)
+     # Load model weights to CPU first
+     model.load_state_dict(torch.load(config.checkpoints_path + "/model_final.pt", map_location=torch.device("cpu")))
+     model.to(device)
+     model.eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
+     return model, tokenizer
+
+ model, tokenizer = load_model()  # Load the model and tokenizer once at startup
+
+ def generate_text(input_text, max_new_tokens=100, temperature=0.8, top_k=50):
+     """
+     Generate text based on the input prompt
+     """
+     # Tokenize input
+     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+     # Generate
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids=input_ids,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_k=top_k
+         )
+
+     # Move output back to CPU before decoding
+     output_ids = output_ids.cpu()
+     generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return generated_text
+
+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs=[
+         gr.Textbox(label="Input Text", placeholder="Enter your prompt here..."),
+         gr.Slider(minimum=1, maximum=150, value=30, step=1, label="Max New Tokens"),
+         gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
+         gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K"),
+     ],
+     outputs=gr.Textbox(label="Generated Text"),
+     title="SmolLM2 Text Generation",
+     description="Enter a prompt and the model will generate text based on it.",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
config.py ADDED
@@ -0,0 +1,40 @@
+ from dataclasses import dataclass
+
+ @dataclass
+ class Config:
+     seed: int = 49
+     vocab_size: int = 49152  # it should match the vocab size of the tokenizer
+     num_hidden_layers: int = 30  # number of layers
+     num_attention_heads: int = 9  # number of heads
+     num_key_value_heads: int = 3  # number of key and value heads
+     nn_embed: int = 576  # embedding dimension or hidden_size
+     max_sequence_len: int = 2048  # max token sequence length (for pos embedding) # Block size
+     ffn_intermediate_size: int = 1536
+     rms_norm_eps: float = 1.0e-05
+     nn_top_k: int = 50  # top k for the model
+     nn_temperature: float = 1.0  # temperature for the model
+     tokenizer_name_or_path: str = "HuggingFaceTB/cosmo2-tokenizer"
+     checkpoint_interval: int = 1000
+     checkpoints_path = "checkpoints"
+     # init_method_std: 0.041666666666666664
+     nn_train_tok_seq: int = 65  # Actual training token sequence block size: 64 + 1, as we are shifting the targets by 1
+     # nn_mlp_expansion: int = 4  # Expansion in the MLP layer
+     batch_size: int = 64
+     # train_tok_size: int = 32
+     # saved_model_path = 'data/model_tf.pth'
+     # train_input_file = 'data/input.txt'
+     optimizer_learning_rate_scheduler_learning_rate: float = 0.003
+     optimizer_learning_rate_scheduler_lr_decay_starting_step: int = 1600000
+     optimizer_learning_rate_scheduler_lr_decay_steps: int = 400000
+     optimizer_learning_rate_scheduler_lr_decay_style: str = "linear"
+     optimizer_learning_rate_scheduler_lr_warmup_steps: int = 2000
+     optimizer_learning_rate_scheduler_lr_warmup_style: str = "linear"
+     optimizer_learning_rate_scheduler_min_decay_lr: float = 0
+     optimizer_factory_adam_beta1: float = 0.9
+     optimizer_factory_adam_beta2: float = 0.95
+     optimizer_factory_adam_eps: float = 1.0e-08
+     optimizer_factory_name: str = "adamW"
+     optimizer_factory_torch_adam_is_fused: bool = True
+     optimizer_weight_decay: float = 0.01
+     optimizer_zero_stage: int = 0
+     optimizer_clip_grad: float = 1.0
model.py ADDED
@@ -0,0 +1,306 @@
+ import torch
+ import torch.nn as nn
+ import math
+ from typing import Optional
+ import torch.nn.functional as F
+
+ # This llama model is based on the paper: https://arxiv.org/pdf/2302.13971.pdf
+ # Model Architecture: static/llamaModel.jpg
+ # It is a transformer model with rotary position embeddings (RoPE) and SwiGLU
+ # activation function. It uses RMSNorm for normalization.
+ # Other good reads: https://pub.towardsai.net/llama-explained-a70e71e706e9
+
+ def precompute_rotary_emb(dim: int, max_seq_len: int, base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Precompute the rotary position embeddings
+     Args:
+         dim: Dimension of the embeddings
+         max_seq_len: Maximum sequence length
+         base: Base for the angle calculations
+     Returns:
+         Tuple of (sin, cos) tensors of shape (max_seq_len, dim//2)
+     """
+     # Create position indices tensor
+     position = torch.arange(max_seq_len).unsqueeze(1)  # (seq_len, 1)
+     # Create dimension indices tensor
+     div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(base) / dim))  # (dim//2)
+     # Compute angles
+     angles = position * div_term  # (seq_len, dim//2)
+     # Return sin and cos
+     return torch.sin(angles), torch.cos(angles)
+
+ def apply_rotary_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
+     """
+     Apply rotary position embeddings to the input tensor
+     Args:
+         x: Input tensor of shape (batch_size, seq_len, num_heads, head_dim)
+         sin: Sine tensor of shape (seq_len, head_dim//2)
+         cos: Cosine tensor of shape (seq_len, head_dim//2)
+     Returns:
+         Tensor with rotary position embeddings applied
+     """
+     # Reshape x to split last dimension in half
+     x_reshape = x.float().reshape(*x.shape[:-1], -1, 2)
+     # Extract even and odd dimensions
+     x1, x2 = x_reshape[..., 0], x_reshape[..., 1]
+
+     # Reshape sin and cos for broadcasting
+     sin = sin.view(1, sin.shape[0], 1, sin.shape[1])  # (1, seq_len, 1, dim//2)
+     cos = cos.view(1, cos.shape[0], 1, cos.shape[1])  # (1, seq_len, 1, dim//2)
+
+     # Apply rotation using the rotation matrix multiplication
+     result = torch.stack([
+         x1 * cos - x2 * sin,
+         x2 * cos + x1 * sin
+     ], dim=-1)
+
+     return result.flatten(-2)  # Flatten last 2 dimensions
+
+ class LlamaAttention(nn.Module):
+     def __init__(self, dim: int, num_heads: int, num_kv_heads: Optional[int] = None, max_position_embeddings=2048):
+         super().__init__()
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
+         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+         self.head_dim = dim // num_heads
+         self.scale = 1.0 / math.sqrt(self.head_dim)
+
+         # self.q_proj = nn.Linear(dim, dim, bias=False)
+         # self.k_proj = nn.Linear(dim, dim, bias=False)
+         # self.v_proj = nn.Linear(dim, dim, bias=False)
+         # Adjust projections for GQA
+         self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, D) or (B, T, H * D/H)
+         self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, H_kv * D/H)
+         self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, H_kv * D/H)
+         self.o_proj = nn.Linear(dim, dim, bias=False)
+         # self.o_proj.NANGPT_SCALE_INIT = 1  # TODO: do we need weight initialization scaling?
+
+         # Cache attributes
+         self.k_cache = None
+         self.v_cache = None
+         self.cache_seq_len = 0
+
+         # Precompute sin and cos for all positions
+         self.sin, self.cos = precompute_rotary_emb(self.head_dim, max_position_embeddings)
+
+     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
+         batch_size, seq_len, _ = x.shape
+
+         # Project inputs
+         q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
+         k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
+         v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
+
+         # Get rotary embeddings for the new tokens
+         # sin = self.sin[self.cache_seq_len:self.cache_seq_len + seq_len].to(x.device)
+         # cos = self.cos[self.cache_seq_len:self.cache_seq_len + seq_len].to(x.device)
+         sin = self.sin[:seq_len].to(x.device)
+         cos = self.cos[:seq_len].to(x.device)
+
+         # Apply rotary embeddings
+         q = apply_rotary_emb(q, sin, cos)
+         k = apply_rotary_emb(k, sin, cos)
+
+         # Handle KV caching
+         # if use_cache:
+         #     if self.k_cache is None:
+         #         # Initialize cache if empty
+         #         self.k_cache = k
+         #         self.v_cache = v
+         #     else:
+         #         # Concatenate new KV with cached KV
+         #         self.k_cache = torch.cat([self.k_cache, k], dim=1)
+         #         self.v_cache = torch.cat([self.v_cache, v], dim=1)
+         #
+         #     # Use concatenated KV pairs
+         #     k = self.k_cache
+         #     v = self.v_cache
+         #
+         #     # Update cache sequence length
+         #     self.cache_seq_len += seq_len
+
+         # Reshape for attention computation
+         q = q.transpose(1, 2)
+         k = k.transpose(1, 2)
+         v = v.transpose(1, 2)
+
+         # Handle GQA (Grouped Query Attention)
+         if self.num_queries_per_kv > 1:
+             k = k.unsqueeze(2).expand(-1, -1, self.num_queries_per_kv, -1, -1)
+             v = v.unsqueeze(2).expand(-1, -1, self.num_queries_per_kv, -1, -1)
+             k = k.reshape(batch_size, self.num_heads, -1, self.head_dim)
+             v = v.reshape(batch_size, self.num_heads, -1, self.head_dim)
+
+         # Compute attention
+         scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, float('-inf'))
+
+         attn = F.softmax(scores, dim=-1)
+         out = torch.matmul(attn, v)
+         out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
+
+         # Speed up - Flash Attention (computation happens in GPU SRAM, not GPU RAM). TODO: not sure how to apply this with grouped query attention
+         # out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+
+         return self.o_proj(out)
+
+     def clear_cache(self):
+         self.k_cache = None
+         self.v_cache = None
+         self.cache_seq_len = 0
+
+ class LlamaFFN(nn.Module):
+     def __init__(self, dim: int, hidden_dim: int):
+         super().__init__()
+         self.gate = nn.Linear(dim, hidden_dim, bias=False)
+         self.up = nn.Linear(dim, hidden_dim, bias=False)
+         self.down = nn.Linear(hidden_dim, dim, bias=False)
+         # self.down.NANGPT_SCALE_INIT = 1  # TODO: do we need weight initialization scaling - Optimization?
+         self.act_fn = nn.SiLU()  # SwiGLU activation function
+
+     def forward(self, x):
+         return self.down(self.act_fn(self.gate(x)) * self.up(x))
+
+ class LlamaBlock(nn.Module):
+     def __init__(self, config):
+         # nn_embed or dim is the dimension of the input to the block
+         super().__init__()
+         self.attention = LlamaAttention(
+             config.nn_embed,
+             config.num_attention_heads,
+             config.num_key_value_heads,
+             config.max_sequence_len
+         )
+         self.feed_forward = LlamaFFN(config.nn_embed, config.ffn_intermediate_size)
+         self.attention_norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
+         self.ffn_norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
+
+     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
+         x = x + self.attention(self.attention_norm(x), mask, use_cache)
+         x = x + self.feed_forward(self.ffn_norm(x))
+         return x
+
+ class SmolLM2(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         # Normal Embedding (position embedding will be part of the Attention layer)
+         self.embedding = nn.Embedding(config.vocab_size, config.nn_embed)
+
+         # total num_hidden_layers blocks (each block has an attention and a feed-forward layer)
+         self.layers = nn.ModuleList([
+             LlamaBlock(config) for _ in range(config.num_hidden_layers)
+         ])
+         self.norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
+         # final layer returning the logits of size (batch_size, vocab_size)
+         self.lm_head = nn.Linear(config.nn_embed, config.vocab_size, bias=False)
+
+         # Optimization: weight sharing between lm_head and embedding
+         self.lm_head.weight = self.embedding.weight
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False, targets: Optional[torch.Tensor] = None):
+         if mask is None:
+             mask = self.create_causal_mask(x.shape[1], device=x.device)
+         x = self.embedding(x)
+         for layer in self.layers:
+             x = layer(x, mask, use_cache)
+         x = self.norm(x)
+         logits = self.lm_head(x)
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
+             return logits, loss
+         return logits
+
+     # Linear layers (attention projections, FFN layers, lm_head) are initialized from N(0, 0.02)
+     # Embedding layer is initialized from N(0, 0.02)
+     # All RMSNorm weights are initialized to 1.0
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANGPT_SCALE_INIT'):
+                 std *= (2 * len(self.layers)) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.RMSNorm):
+             torch.nn.init.ones_(module.weight)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def clear_cache(self):
+         """Clear KV cache in all attention layers"""
+         for layer in self.layers:
+             layer.attention.clear_cache()
+
+     def create_causal_mask(self, seq_len, device):
+         """Creates a causal attention mask where each position can only attend to previous positions"""
+         # Create lower triangular matrix (including diagonal)
+         # mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
+         # mask = torch.triu(torch.ones(1, 1, seq_len, seq_len), diagonal=1).bool()
+         # # Invert and convert to float
+         # return (~mask).float()
+         return torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len).to(device)
+
+     @torch.no_grad()
+     def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 20,
+                  temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
+         """
+         Generate text using the model
+         Args:
+             input_ids: Starting token ids (B, T)
+             max_new_tokens: Number of tokens to generate
+             temperature: Controls randomness (1.0 = neutral, <1.0 = more deterministic, >1.0 = more random)
+             top_k: Number of highest probability tokens to consider for sampling
+         Returns:
+             Generated token ids (B, T+max_new_tokens)
+         """
+         batch_size, seq_len = input_ids.shape
+
+         # Clear any existing KV cache
+         self.clear_cache()
+
+         # Create a new tensor to store the generated tokens
+         input_ids = torch.cat([input_ids, torch.zeros((batch_size, max_new_tokens),
+                                dtype=torch.long, device=input_ids.device)], dim=1)
+
+         # Generate tokens one at a time
+         for idx in range(max_new_tokens):
+             # print(f"Generating token {idx+1} of {max_new_tokens}")
+
+             # Get the current sequence length including cached tokens
+             current_seq_len = seq_len + idx
+
+             next_mask = self.create_causal_mask(current_seq_len, device=input_ids.device)
+
+             # Create mask that includes both the current input and cached tokens
+             # if idx == 0:
+             #     # First iteration - create mask for the full input sequence
+             #     next_mask = self.create_causal_mask(current_seq_len, device=input_ids.device)
+             # else:
+             #     # Subsequent iterations - create mask for the new token attending to all previous tokens
+             #     next_mask = torch.ones((1, 1, 1, current_seq_len), device=input_ids.device)
+
+             # Process including the new tokens
+             logits = self(input_ids[:, :current_seq_len], next_mask, use_cache=False)
+
+             # Get the last token's logits
+             next_token_logits = logits[:, -1, :] / temperature
+
+             # Apply top-k filtering
+             top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
+             probs = F.softmax(top_k_logits, dim=-1)
+
+             # Sample from the filtered distribution
+             next_token = top_k_indices[
+                 torch.arange(batch_size, device=input_ids.device),
+                 torch.multinomial(probs, num_samples=1).squeeze(1)
+             ]
+
+             # Update input_ids with the new token
+             input_ids[:, current_seq_len] = next_token
+
+         return input_ids
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.0.0
+ transformers>=4.30.0
+ datasets>=2.12.0
+ numpy>=1.24.0
+ tqdm>=4.65.0
+ huggingface-hub>=0.16.0
+ tokenizers>=0.13.0
+ gradio>=4.0.0
utils.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+
+ def get_device(seed = 1):
+     # Seed is to generate the same random data for each run
+     # For reproducibility
+     torch.manual_seed(seed)
+
+     # Set device
+     device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+
+     if torch.cuda.is_available():
+         print(f"[INFO] GPU: {torch.cuda.get_device_name(0)}")
+         print(f"[INFO] CUDA Version: {torch.version.cuda}\n")
+         torch.cuda.manual_seed(seed)
+
+     if not torch.backends.mps.is_available():
+         if not torch.backends.mps.is_built():
+             print("MPS not available because the current PyTorch install was not "
+                   "built with MPS enabled.")
+         else:
+             print("MPS not available because the current MacOS version is not 12.3+ "
+                   "and/or you do not have an MPS-enabled device on this machine.")
+     else:
+         torch.mps.manual_seed(seed)
+
+     return device