Upload 6 files
README.md
CHANGED
```diff
@@ -1,13 +1,232 @@
 ---
-title: SmolLM2
-emoji:
-colorFrom:
-colorTo:
+title: SmolLM2 135M Text Generation Demo
+emoji: 📚
+colorFrom: blue
+colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 3.50.2
 app_file: app.py
 pinned: false
-short_description: Demo SmolLM2-135m model trained for only 5k steps
 ---
```

# SmolLM2 Text Generation Demo

This is a simple text generation demo using the SmolLM2 language model with a Gradio interface.

## Description

This application provides a web interface for text generation using the SmolLM2 language model. Users can input a prompt and adjust various generation parameters to control the output.

## Features

- Interactive web interface built with Gradio
- Adjustable generation parameters:
  - Maximum new tokens (1-150)
  - Temperature (0.1-2.0)
  - Top-K sampling (1-100)
- Real-time text generation

## Usage

1. Enter your prompt in the text input field
2. Adjust the generation parameters (optional):
   - **Max New Tokens**: Controls the length of the generated text
   - **Temperature**: Controls randomness (higher = more creative, lower = more focused)
   - **Top-K**: Controls diversity of word choices
3. Click submit to generate text

## Installation

1. Clone the repository
2. Install dependencies:

   ```bash
   pip install -r requirements.txt
   ```

## Run the application

```bash
python app.py
```

The interface will be available at `http://localhost:7860`.

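Once the app is running you can also query it programmatically. The sketch below is illustrative only: it assumes the `gradio_client` package is installed (it is not listed in `requirements.txt`) and that the interface is exposed under the default `/predict` endpoint that `gr.Interface` creates.

```python
# Sketch: call the locally running Gradio demo from Python.
# Assumes app.py is already running at http://localhost:7860.
from gradio_client import Client

client = Client("http://localhost:7860")
result = client.predict(
    "What is Gravity?",   # Input Text
    30,                   # Max New Tokens
    0.8,                  # Temperature
    50,                   # Top-K
    api_name="/predict",  # default endpoint name for gr.Interface
)
print(result)
```
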
## Train the model

```bash
python train.py
```

# Model details

SmolLM2 here is a custom, from-scratch implementation of a small (~135M-parameter) Llama-style decoder-only language model (see `model.py`), trained for only about 5k steps as a demo. It uses the `HuggingFaceTB/cosmo2-tokenizer` tokenizer (49,152-token vocabulary) from Hugging Face's transformers library, with 30 decoder layers, a hidden size of 576, grouped-query attention (9 query heads, 3 key/value heads), rotary position embeddings, a SwiGLU-style feed-forward block, and RMSNorm.

## Llama 2 Architecture

![Llama model architecture](static/llamaModel.jpg)

Read https://pub.towardsai.net/llama-explained-a70e71e706e9 for more details.

# Compare Custom SmolLM2-135 with HuggingFaceTB/SmolLM2-135M

HuggingFaceTB/SmolLM2-135M

```bash
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
```

Custom SmolLM2-135

```bash
SmolLM2(
  (embedding): Embedding(49152, 576)
  (layers): ModuleList(
    (0-29): 30 x LlamaBlock(
      (attention): LlamaAttention(
        (q_proj): Linear(in_features=576, out_features=576, bias=False)
        (k_proj): Linear(in_features=576, out_features=192, bias=False)
        (v_proj): Linear(in_features=576, out_features=192, bias=False)
        (o_proj): Linear(in_features=576, out_features=576, bias=False)
      )
      (feed_forward): LlamaFFN(
        (gate): Linear(in_features=576, out_features=1536, bias=False)
        (up): Linear(in_features=576, out_features=1536, bias=False)
        (down): Linear(in_features=1536, out_features=576, bias=False)
        (act_fn): SiLU()
      )
      (attention_norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)
      (ffn_norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)
    )
  )
  (norm): RMSNorm((576,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
```

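The two module dumps above can be reproduced by simply printing each model. A minimal sketch, assuming this repo's `config.py`/`model.py` and network access to the Hugging Face Hub:

```python
# Sketch: print both architectures for side-by-side comparison.
from transformers import AutoModelForCausalLM
from config import Config
from model import SmolLM2

# Reference model from the Hugging Face Hub
hf_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
print(hf_model)

# Custom implementation from this repo
custom_model = SmolLM2(Config())
print(custom_model)
```
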
# Training Logs

## Training with 5000 steps (without checkpoint)

```bash
(venv) gitesh.grover@Giteshs-MacBook-Pro ai-era-assignment13 % python train.py

Resolving data files: 100%|████████████████████| 104/104 [00:00<00:00, 720.56it/s]
Resolving data files: 100%|████████████████████| 104/104 [00:00<00:00, 562123.22it/s]
Epoch: 0, Step: 0, Batch: 0, Loss: 10.9101, Time: 1.44s, Token/s: 2842.75
Saved checkpoint at step 0
What is Gravity? thymopenedi something aneur checklist fertiliserlete hiding Watching [[GuardinnamonGuard thym thym something multilinguali runway astronlighten runwayinnamon nastylighten disadvant snout plumquest
Epoch: 0, Step: 1, Batch: 1, Loss: 10.6729, Time: 2.00s, Token/s: 2044.98
Epoch: 0, Step: 2, Batch: 2, Loss: 9.2034, Time: 1.16s, Token/s: 3517.56
Epoch: 0, Step: 3, Batch: 3, Loss: 8.5723, Time: 1.09s, Token/s: 3766.14
Epoch: 0, Step: 4, Batch: 4, Loss: 8.1478, Time: 1.07s, Token/s: 3845.85
:
:
Epoch: 0, Step: 500, Batch: 500, Loss: 5.9723, Time: 1.07s, Token/s: 3825.45
Saved checkpoint at step 500
What is Gravity? We call us to use, I can create a `e` function to do to add a few to calculate their lives.
* An the need
Epoch: 0, Step: 501, Batch: 501, Loss: 6.0491, Time: 1.58s, Token/s: 2595.98
:
:
Epoch: 0, Step: 998, Batch: 998, Loss: 5.8647, Time: 1.25s, Token/s: 3289.61
Epoch: 0, Step: 999, Batch: 999, Loss: 6.0096, Time: 1.10s, Token/s: 3726.16
Epoch: 0, Step: 1000, Batch: 1000, Loss: 6.4388, Time: 1.09s, Token/s: 3763.74
Saved checkpoint at step 1000
What is Gravity? These tales of sharing a beautiful blend of the art, where will understand these questions where remain.

III. **4.g., the Individuals
:
:
Epoch: 0, Step: 1498, Batch: 1498, Loss: 7.3296, Time: 1.06s, Token/s: 3878.60
Epoch: 0, Step: 1499, Batch: 1499, Loss: 6.0611, Time: 1.06s, Token/s: 3864.26
Epoch: 0, Step: 1500, Batch: 1500, Loss: 6.1140, Time: 1.08s, Token/s: 3789.80
Saved checkpoint at step 1500
What is Gravity?

Now imagine don't forget, "It have been the game?" But there are just as an 'L', does not can he noticed,

:
:
:
:

Epoch: 0, Step: 3498, Batch: 3498, Loss: 5.7145, Time: 1.07s, Token/s: 3830.33
Epoch: 0, Step: 3499, Batch: 3499, Loss: 5.7578, Time: 1.09s, Token/s: 3767.61
Epoch: 0, Step: 3500, Batch: 3500, Loss: 6.0798, Time: 1.07s, Token/s: 3811.98
Saved checkpoint at step 3500
What is Gravity? Let's how a "P"? You might need to play and a new environment that makes it up a big planet of the whole piece of the information
Epoch: 0, Step: 3501, Batch: 3501, Loss: 5.8375, Time: 1.47s, Token/s: 2790.70
Epoch: 0, Step: 3502, Batch: 3502, Loss: 6.3435, Time: 1.07s, Token/s: 3838.95
Epoch: 0, Step: 3503, Batch: 3503, Loss: 5.8192, Time: 1.05s, Token/s: 3901.14

:
:
Epoch: 0, Step: 4496, Batch: 4496, Loss: 5.5488, Time: 1.06s, Token/s: 3862.06
Epoch: 0, Step: 4497, Batch: 4497, Loss: 5.8281, Time: 1.07s, Token/s: 3821.71
Epoch: 0, Step: 4498, Batch: 4498, Loss: 5.5703, Time: 1.07s, Token/s: 3844.92
Epoch: 0, Step: 4499, Batch: 4499, Loss: 6.0630, Time: 1.06s, Token/s: 3854.04
Epoch: 0, Step: 4500, Batch: 4500, Loss: 5.5889, Time: 1.06s, Token/s: 3860.19
Saved checkpoint at step 4500
What is Gravity?

V. **Additional 2: Prepare a Power

* **I and the Eaught of Life

Before our exploration, understanding
:
:
Epoch: 0, Step: 4996, Batch: 4996, Loss: 6.1501, Time: 1.06s, Token/s: 3865.19
Epoch: 0, Step: 4997, Batch: 4997, Loss: 5.9107, Time: 1.05s, Token/s: 3884.67
Epoch: 0, Step: 4998, Batch: 4998, Loss: 5.7005, Time: 1.07s, Token/s: 3834.26
Epoch: 0, Step: 4999, Batch: 4999, Loss: 5.8820, Time: 1.07s, Token/s: 3814.07
Saved final checkpoint
What is Gravity? You would be a better big way, there are people have just like!

As they saw out to the world in the world or making a
Training complete
```

## Training with Additional 50 steps (with checkpoint)

```bash
Loading checkpoint from checkpoints/checkpoint_final.pt
Resuming from epoch 0 at step 5000 with loss 5.881985664367676
Resolving data files: 100%|████████████████████| 104/104 [00:00<00:00, 313.79it/s]
Resolving data files: 100%|████████████████████| 104/104 [00:00<00:00, 462574.35it/s]
Epoch: 0, Step: 5000, Batch: 0, Loss: 5.6473, Time: 2.69s, Token/s: 1520.90
Saved checkpoint at step 5000
What is Gravity? Well, remember, there's where those who do something as part of art and animals, family around us. For instance, there's like! But
Epoch: 0, Step: 5001, Batch: 1, Loss: 6.1124, Time: 1.54s, Token/s: 2660.36
Epoch: 0, Step: 5002, Batch: 2, Loss: 5.8381, Time: 1.11s, Token/s: 3680.22
:
:
Epoch: 0, Step: 5044, Batch: 44, Loss: 6.1118, Time: 1.09s, Token/s: 3749.53
Epoch: 0, Step: 5045, Batch: 45, Loss: 5.8618, Time: 1.11s, Token/s: 3676.88
Epoch: 0, Step: 5046, Batch: 46, Loss: 5.8893, Time: 1.08s, Token/s: 3784.70
Epoch: 0, Step: 5047, Batch: 47, Loss: 5.7507, Time: 1.10s, Token/s: 3729.83
Epoch: 0, Step: 5048, Batch: 48, Loss: 5.6882, Time: 1.10s, Token/s: 3715.38
Epoch: 0, Step: 5049, Batch: 49, Loss: 5.7396, Time: 1.09s, Token/s: 3745.38
Saved final checkpoint
What is Gravity? Have you would be wondering what life, you don't just how to do? She needed, they have had to know that "but these things has
Training complete
```

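The resume run above loads `checkpoints/checkpoint_final.pt` and continues from step 5000. `train.py` itself is not part of this upload, so the sketch below is only a guess at what such a save/resume round trip could look like; every key name in the checkpoint dict is an assumption, not the script's actual format.

```python
# Hypothetical sketch of a checkpoint save/resume round trip consistent with the logs above.
# train.py is not shown in this repo listing; field names here are assumptions.
import torch

def save_checkpoint(path, model, optimizer, epoch, step, loss):
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
        "loss": loss,
    }, path)

def load_checkpoint(path, model, optimizer):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    print(f"Resuming from epoch {ckpt['epoch']} at step {ckpt['step']} with loss {ckpt['loss']}")
    return ckpt["epoch"], ckpt["step"]
```
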
app.py
ADDED
@@ -0,0 +1,61 @@
```python
import gradio as gr
import torch
from model import SmolLM2
from transformers import AutoTokenizer
from config import Config
from utils import get_device

# Initialize model and tokenizer
config = Config()
device = get_device(config.seed)
print("device: ", device)

def load_model():
    model = SmolLM2(config)
    # Load model weights to CPU first
    model.load_state_dict(torch.load(config.checkpoints_path + "/model_final.pt", map_location=torch.device("cpu")))
    model.to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
    return model, tokenizer

# Load the model and tokenizer once at startup
model, tokenizer = load_model()

def generate_text(input_text, max_new_tokens=100, temperature=0.8, top_k=50):
    """
    Generate text based on the input prompt
    """
    # Tokenize input
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        )

    # Move output back to CPU before decoding
    output_ids = output_ids.cpu()
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text

# Create Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input Text", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=1, maximum=150, value=30, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2 Text Generation",
    description="Enter a prompt and the model will generate text based on it.",
)

if __name__ == "__main__":
    demo.launch()
```

config.py
ADDED
@@ -0,0 +1,40 @@
```python
from dataclasses import dataclass

@dataclass
class Config:
    seed: int = 49
    vocab_size: int = 49152  # it should match the vocab size of the tokenizer
    num_hidden_layers: int = 30  # number of layers
    num_attention_heads: int = 9  # number of heads
    num_key_value_heads: int = 3  # number of key and value heads
    nn_embed: int = 576  # embedding dimension or hidden_size
    max_sequence_len: int = 2048  # max token sequence length (for pos embedding) # Block size
    ffn_intermediate_size: int = 1536
    rms_norm_eps: float = 1.0e-05
    nn_top_k: int = 50  # top k for the model
    nn_temperature: float = 1.0  # temperature for the model
    tokenizer_name_or_path: str = "HuggingFaceTB/cosmo2-tokenizer"
    checkpoint_interval: int = 1000
    checkpoints_path = "checkpoints"
    # init_method_std: 0.041666666666666664
    nn_train_tok_seq: int = 65  # Actual training token sequence block size 64 + 1 as we are shifting the targets by 1
    # nn_mlp_expansion: int = 4  # Expansion in the MLP layer
    batch_size: int = 64
    # train_tok_size: int = 32
    # saved_model_path = 'data/model_tf.pth'
    # train_input_file = 'data/input.txt'
    optimizer_learning_rate_scheduler_learning_rate: float = 0.003
    optimizer_learning_rate_scheduler_lr_decay_starting_step: int = 1600000
    optimizer_learning_rate_scheduler_lr_decay_steps: int = 400000
    optimizer_learning_rate_scheduler_lr_decay_style: str = "linear"
    optimizer_learning_rate_scheduler_lr_warmup_steps: int = 2000
    optimizer_learning_rate_scheduler_lr_warmup_style: str = "linear"
    optimizer_learning_rate_scheduler_min_decay_lr: float = 0
    optimizer_factory_adam_beta1: float = 0.9
    optimizer_factory_adam_beta2: float = 0.95
    optimizer_factory_adam_eps: float = 1.0e-08
    optimizer_factory_name: str = "adamW"
    optimizer_factory_torch_adam_is_fused: bool = True
    optimizer_weight_decay: float = 0.01
    optimizer_zero_stage: int = 0
    optimizer_clip_grad: float = 1.0
```

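As a quick sanity check on these hyperparameters, the snippet below (illustrative only, not a file in the repo) builds the custom model from this config and counts its parameters, which should come out to roughly 135M with the tied embedding/LM head.

```python
# Sketch: verify the configured architecture is roughly 135M parameters.
from config import Config
from model import SmolLM2

model = SmolLM2(Config())
# parameters() deduplicates the tied embedding / lm_head weight
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # expected: roughly 135M
```
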
model.py
ADDED
@@ -0,0 +1,306 @@
```python
import torch
import torch.nn as nn
import math
from typing import Optional
import torch.nn.functional as F

# This llama model is based on the paper: https://arxiv.org/pdf/2302.13971.pdf
# Model Architecture: static/llamaModel.jpg
# It is a transformer model with rotary position embeddings (RoPE) and SwiGLU
# activation function. It uses RMSNorm for normalization.
# Other Good reads: https://pub.towardsai.net/llama-explained-a70e71e706e9

def precompute_rotary_emb(dim: int, max_seq_len: int, base: int = 10000) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute the rotary position embeddings
    Args:
        dim: Dimension of the embeddings
        max_seq_len: Maximum sequence length
        base: Base for the angle calculations
    Returns:
        Tuple of (sin, cos) tensors of shape (max_seq_len, dim//2)
    """
    # Create position indices tensor
    position = torch.arange(max_seq_len).unsqueeze(1)  # (seq_len, 1)
    # Create dimension indices tensor
    div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(base) / dim))  # (dim//2)
    # Compute angles
    angles = position * div_term  # (seq_len, dim//2)
    # Return sin and cos
    return torch.sin(angles), torch.cos(angles)

def apply_rotary_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """
    Apply rotary position embeddings to the input tensor
    Args:
        x: Input tensor of shape (batch_size, seq_len, num_heads, head_dim)
        sin: Sine tensor of shape (seq_len, head_dim//2)
        cos: Cosine tensor of shape (seq_len, head_dim//2)
    Returns:
        Tensor with rotary position embeddings applied
    """
    # Reshape x to split last dimension in half
    x_reshape = x.float().reshape(*x.shape[:-1], -1, 2)
    # Extract even and odd dimensions
    x1, x2 = x_reshape[..., 0], x_reshape[..., 1]

    # Reshape sin and cos for broadcasting
    sin = sin.view(1, sin.shape[0], 1, sin.shape[1])  # (1, seq_len, 1, dim//2)
    cos = cos.view(1, cos.shape[0], 1, cos.shape[1])  # (1, seq_len, 1, dim//2)

    # Apply rotation using the rotation matrix multiplication
    result = torch.stack([
        x1 * cos - x2 * sin,
        x2 * cos + x1 * sin
    ], dim=-1)

    return result.flatten(-2)  # Flatten last 2 dimensions

class LlamaAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, num_kv_heads: Optional[int] = None, max_position_embeddings=2048):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_dim = dim // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # self.q_proj = nn.Linear(dim, dim, bias=False)
        # self.k_proj = nn.Linear(dim, dim, bias=False)
        # self.v_proj = nn.Linear(dim, dim, bias=False)
        # Adjust projections for GQA
        self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, D) or (B, T, H * D/H)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, H_kv * D/H)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)  # (B, T, D) -> (B, T, H_kv * D/H)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        # self.o_proj.NANGPT_SCALE_INIT = 1  # TODO do we need weight initialization scaling?

        # Cache attributes
        self.k_cache = None
        self.v_cache = None
        self.cache_seq_len = 0

        # Precompute sin and cos for all positions
        self.sin, self.cos = precompute_rotary_emb(self.head_dim, max_position_embeddings)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        batch_size, seq_len, _ = x.shape

        # Project inputs
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # Get rotary embeddings for the new tokens
        # sin = self.sin[self.cache_seq_len:self.cache_seq_len + seq_len].to(x.device)
        # cos = self.cos[self.cache_seq_len:self.cache_seq_len + seq_len].to(x.device)
        sin = self.sin[:seq_len].to(x.device)
        cos = self.cos[:seq_len].to(x.device)

        # Apply rotary embeddings
        q = apply_rotary_emb(q, sin, cos)
        k = apply_rotary_emb(k, sin, cos)

        # Handle KV caching
        # if use_cache:
        #     if self.k_cache is None:
        #         # Initialize cache if empty
        #         self.k_cache = k
        #         self.v_cache = v
        #     else:
        #         # Concatenate new KV with cached KV
        #         self.k_cache = torch.cat([self.k_cache, k], dim=1)
        #         self.v_cache = torch.cat([self.v_cache, v], dim=1)
        #
        #     # Use concatenated KV pairs
        #     k = self.k_cache
        #     v = self.v_cache
        #
        #     # Update cache sequence length
        #     self.cache_seq_len += seq_len

        # Reshape for attention computation
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Handle GQA (Grouped Query Attention)
        if self.num_queries_per_kv > 1:
            k = k.unsqueeze(2).expand(-1, -1, self.num_queries_per_kv, -1, -1)
            v = v.unsqueeze(2).expand(-1, -1, self.num_queries_per_kv, -1, -1)
            k = k.reshape(batch_size, self.num_heads, -1, self.head_dim)
            v = v.reshape(batch_size, self.num_heads, -1, self.head_dim)

        # Compute attention
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        # Speed up - Flash Attention (calculation happens in GPU SRAM and not GPU RAM) TODO Not sure how to apply this in grouped-query attention?
        # out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        return self.o_proj(out)

    def clear_cache(self):
        self.k_cache = None
        self.v_cache = None
        self.cache_seq_len = 0

class LlamaFFN(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        # self.down.NANGPT_SCALE_INIT = 1  # TODO do we need weight initialization scaling - Optimization?
        self.act_fn = nn.SiLU()  # SiLU; combined with the gate/up projections this forms a SwiGLU-style FFN

    def forward(self, x):
        return self.down(self.act_fn(self.gate(x)) * self.up(x))

class LlamaBlock(nn.Module):
    def __init__(self, config):
        # nn_embed or dim is the dimension of the input to the block
        super().__init__()
        self.attention = LlamaAttention(
            config.nn_embed,
            config.num_attention_heads,
            config.num_key_value_heads,
            config.max_sequence_len
        )
        self.feed_forward = LlamaFFN(config.nn_embed, config.ffn_intermediate_size)
        self.attention_norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
        self.ffn_norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False):
        x = x + self.attention(self.attention_norm(x), mask, use_cache)
        x = x + self.feed_forward(self.ffn_norm(x))
        return x

class SmolLM2(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config  # kept so _init_weights can read layer count if scaling is enabled
        # Normal Embedding (position embedding will be part of Attention layer)
        self.embedding = nn.Embedding(config.vocab_size, config.nn_embed)

        # total num_hidden_layers Blocks (Each block has attention and feedforward layer)
        self.layers = nn.ModuleList([
            LlamaBlock(config) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.RMSNorm(config.nn_embed, eps=config.rms_norm_eps)
        # final layer returning the logits of size (batch_size, vocab_size)
        self.lm_head = nn.Linear(config.nn_embed, config.vocab_size, bias=False)

        # Optimization: weight sharing between lm_head and embedding
        self.lm_head.weight = self.embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, use_cache: bool = False, targets: Optional[torch.Tensor] = None):
        if mask is None:
            mask = self.create_causal_mask(x.shape[1], device=x.device)
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, mask, use_cache)
        x = self.norm(x)
        logits = self.lm_head(x)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1))
            return logits, loss
        return logits

    # Linear layers (attention projections, FFN layers, lm_head) are initialized from N(0, 0.02)
    # Embedding layer is initialized from N(0, 0.02)
    # All RMSNorm weights are initialized to 1.0
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANGPT_SCALE_INIT'):
                std *= (2 * self.config.num_hidden_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.RMSNorm):
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def clear_cache(self):
        """Clear KV cache in all attention layers"""
        for layer in self.layers:
            layer.attention.clear_cache()

    def create_causal_mask(self, seq_len, device):
        """Creates a causal attention mask where each position can only attend to previous positions"""
        # Create lower triangular matrix (including diagonal)
        # mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        # mask = torch.triu(torch.ones(1, 1, seq_len, seq_len), diagonal=1).bool()
        # # Invert and convert to float
        # return (~mask).float()
        return torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len).to(device)

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 20,
                 temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
        """
        Generate text using the model
        Args:
            input_ids: Starting token ids (B, T)
            max_new_tokens: Number of tokens to generate
            temperature: Controls randomness (1.0 = neutral, <1.0 = more deterministic, >1.0 = more random)
            top_k: Number of highest probability tokens to consider for sampling
        Returns:
            Generated token ids (B, T+max_new_tokens)
        """
        batch_size, seq_len = input_ids.shape

        # clear existing KV caching
        self.clear_cache()

        # Create a new tensor to store the generated tokens
        input_ids = torch.cat([input_ids, torch.zeros((batch_size, max_new_tokens),
                                                      dtype=torch.long, device=input_ids.device)], dim=1)

        # Generate tokens one at a time
        for idx in range(max_new_tokens):
            # print(f"Generating token {idx+1} of {max_new_tokens}")

            # Get the current sequence length including cached tokens
            current_seq_len = seq_len + idx

            next_mask = self.create_causal_mask(current_seq_len, device=input_ids.device)

            # Create mask that includes both the current input and cached tokens
            # if idx == 0:
            #     # First iteration - create mask for the full input sequence
            #     next_mask = self.create_causal_mask(current_seq_len, device=input_ids.device)
            # else:
            #     # Subsequent iterations - create mask for the new token attending to all previous tokens
            #     next_mask = torch.ones((1, 1, 1, current_seq_len), device=input_ids.device)

            # Process including the new tokens
            logits = self(input_ids[:, :current_seq_len], next_mask, use_cache=False)

            # Get the last token's logits
            next_token_logits = logits[:, -1, :] / temperature

            # Apply top-k filtering
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            probs = F.softmax(top_k_logits, dim=-1)

            # Sample from the filtered distribution
            next_token = top_k_indices[
                torch.arange(batch_size, device=input_ids.device),
                torch.multinomial(probs, num_samples=1).squeeze(1)
            ]

            # Update input_ids with the new token
            input_ids[:, current_seq_len] = next_token

        return input_ids
```

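A quick way to exercise this module end to end (an illustrative snippet, not a file in the repo) is to push a random batch through `forward` and run a short `generate` call on the untrained model; the shapes below follow the `Config` defaults.

```python
# Sketch: smoke-test SmolLM2 forward() and generate() with random token ids.
import torch
from config import Config
from model import SmolLM2

config = Config()
model = SmolLM2(config)
model.eval()

batch_size, seq_len = 2, 16
x = torch.randint(0, config.vocab_size, (batch_size, seq_len))

logits = model(x)
print(logits.shape)  # (2, 16, 49152): one logit row per position over the vocab

out = model.generate(x, max_new_tokens=5, temperature=1.0, top_k=50)
print(out.shape)     # (2, 21): the prompt plus 5 newly sampled tokens
```
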
requirements.txt
ADDED
@@ -0,0 +1,8 @@
```text
torch>=2.0.0
transformers>=4.30.0
datasets>=2.12.0
numpy>=1.24.0
tqdm>=4.65.0
huggingface-hub>=0.16.0
tokenizers>=0.13.0
gradio>=4.0.0
```

utils.py
ADDED
@@ -0,0 +1,26 @@
```python
import torch

def get_device(seed=1):
    # Seed is to generate the same random data for each run
    # For reproducibility
    torch.manual_seed(seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

    if torch.cuda.is_available():
        print(f"[INFO] GPU: {torch.cuda.get_device_name(0)}")
        print(f"[INFO] CUDA Version: {torch.version.cuda}\n")
        torch.cuda.manual_seed(seed)

    if not torch.backends.mps.is_available():
        if not torch.backends.mps.is_built():
            print("MPS not available because the current PyTorch install was not "
                  "built with MPS enabled.")
        else:
            print("MPS not available because the current MacOS version is not 12.3+ "
                  "and/or you do not have an MPS-enabled device on this machine.")
    else:
        torch.mps.manual_seed(seed)

    return device
```

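For reference, this helper is what `app.py` calls at startup; a minimal usage sketch:

```python
# Sketch: pick the best available device and seed the RNGs in one call.
from utils import get_device

device = get_device(seed=49)  # seed value matches Config.seed
print(device)                 # cuda, mps, or cpu depending on the machine
```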