Upload folder using huggingface_hub
Browse files- .gitattributes +3 -0
- LICENSE +21 -0
- README.md +160 -0
- checkpoints/best_model.pt +3 -0
- checkpoints/checkpoint_0.pt +3 -0
- checkpoints/checkpoint_1000.pt +3 -0
- checkpoints/checkpoint_10000.pt +3 -0
- checkpoints/checkpoint_11000.pt +3 -0
- checkpoints/checkpoint_12000.pt +3 -0
- checkpoints/checkpoint_13000.pt +3 -0
- checkpoints/checkpoint_14000.pt +3 -0
- checkpoints/checkpoint_15000.pt +3 -0
- checkpoints/checkpoint_16000.pt +3 -0
- checkpoints/checkpoint_17000.pt +3 -0
- checkpoints/checkpoint_18000.pt +3 -0
- checkpoints/checkpoint_19000.pt +3 -0
- checkpoints/checkpoint_2000.pt +3 -0
- checkpoints/checkpoint_20000.pt +3 -0
- checkpoints/checkpoint_3000.pt +3 -0
- checkpoints/checkpoint_4000.pt +3 -0
- checkpoints/checkpoint_5000.pt +3 -0
- checkpoints/checkpoint_6000.pt +3 -0
- checkpoints/checkpoint_7000.pt +3 -0
- checkpoints/checkpoint_8000.pt +3 -0
- checkpoints/checkpoint_9000.pt +3 -0
- checkpoints/final_model.pt +3 -0
- deepseek-arch.png +3 -0
- deepseek_training_metrics.png +3 -0
- process_data.py +16 -0
- requirements.txt +13 -0
- setup.sh +313 -0
- src/data/__pycache__/data_processor.cpython-310.pyc +0 -0
- src/data/data_processor.py +287 -0
- src/data/finetune.bin +3 -0
- src/data/train.bin +3 -0
- src/data/validation.bin +3 -0
- src/generate.py +281 -0
- src/model/__pycache__/deepseek.cpython-310.pyc +0 -0
- src/model/deepseek.py +513 -0
- src/run_training.py +307 -0
- src/training/__pycache__/trainer.cpython-310.pyc +0 -0
- src/training/trainer.py +408 -0
- training_metrics.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
deepseek-arch.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
deepseek_training_metrics.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
training_metrics.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 IdeaWeaver AI
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DeepSeek-Children-Stories
|
2 |
+
|
3 |
+
A state-of-the-art DeepSeek model optimized for children's story generation, featuring advanced architecture with just ~15-18M parameters.
|
4 |
+
|
5 |
+
## Architecture Highlights
|
6 |
+
|
7 |
+

|
8 |
+
|
9 |
+
- **Multihead Latent Attention (MLA)** - DeepSeek's efficient attention mechanism
|
10 |
+
- **Mixture of Experts (MoE)** - 4 experts with top-2 routing for increased capacity
|
11 |
+
- **Multi-token Prediction** - Predicts next 2 tokens simultaneously for efficiency
|
12 |
+
- **Rotary Positional Encodings (RoPE)** - Better position understanding
|
13 |
+
|
14 |
+
## Model Specifications
|
15 |
+
|
16 |
+
- **Parameters**: ~15-18M (6 layers, 8 heads, 512 embedding dim)
|
17 |
+
- **Context Window**: 1024 tokens
|
18 |
+
- **Vocabulary**: GPT-2 compatible (50,257 tokens)
|
19 |
+
- **Training Data**: 2,000+ children's stories from Hugging Face
|
20 |
+
|
21 |
+
## Hardware Used
|
22 |
+
|
23 |
+
Training was performed on the following hardware:
|
24 |
+
|
25 |
+
- **GPU**: NVIDIA RTX 4090 (24 GB VRAM)
|
26 |
+
- **RAM**: 41 GB
|
27 |
+
- **CPU**: 6 vCPU
|
28 |
+
|
29 |
+
## Quick Start
|
30 |
+
|
31 |
+
### Installation
|
32 |
+
|
33 |
+
```bash
|
34 |
+
# Clone the repository
|
35 |
+
git clone https://github.com/ideaweaver-ai/DeepSeek-Children-Stories-15M-model.git
|
36 |
+
cd DeepSeek-Children-Stories-15M-model
|
37 |
+
|
38 |
+
# Install dependencies
|
39 |
+
pip install -r requirements.txt
|
40 |
+
|
41 |
+
# Setup the environment
|
42 |
+
chmod +x setup.sh
|
43 |
+
./setup.sh
|
44 |
+
```
|
45 |
+
|
46 |
+
### Training
|
47 |
+
|
48 |
+
```bash
|
49 |
+
# Start training
|
50 |
+
python src/run_training.py
|
51 |
+
|
52 |
+
# With custom parameters
|
53 |
+
python src/run_training.py --batch-size 8 --max-iters 10000 --learning-rate 6e-4
|
54 |
+
```
|
55 |
+
|
56 |
+
### Generation
|
57 |
+
|
58 |
+
```bash
|
59 |
+
# Generate stories
|
60 |
+
python src/generate.py --prompt "Once upon a time, there was a brave little mouse"
|
61 |
+
|
62 |
+
# With custom parameters
|
63 |
+
python src/generate.py --prompt "A magical forest adventure" --max-tokens 200 --temperature 0.8
|
64 |
+
```
|
65 |
+
|
66 |
+
## 📖 Example Output
|
67 |
+
|
68 |
+
Here's an example of a story generated by the model:
|
69 |
+
|
70 |
+
**Prompt**: "Once upon a time"
|
71 |
+
|
72 |
+
**Generated Story**:
|
73 |
+
```
|
74 |
+
it was a bright, sunny day, and lily and her little brother max were playing in their backyard. they found a piece of paper with two sentence written on it. "let's make sense of some of these sentences," said max, pointing to the first sentence. "these people are playing on the grass," "but i don't know," replied lily. she thought for a moment. "maybe they only talk with the others or not, right?" she asked. max nodded. "yeah, and what about 'he', 'he', 'an', 'man', and 'man'?" lily explained, "it means they're playing with their dogs. but they don't say anything about someone talking." max asked, "but what about the others? we don't talk to each other!" lily thought for a moment before answering, "that's right! sometimes, people try to talk to each other. when we talk about something, we need to tell others
|
75 |
+
```
|
76 |
+
|
77 |
+
## Training Metrics
|
78 |
+
|
79 |
+
<p align="center">
|
80 |
+
<img src="training_metrics.png" alt="Training and Validation Loss and Learning Rate" width="800"/>
|
81 |
+
</p>
|
82 |
+
|
83 |
+
## Configuration
|
84 |
+
|
85 |
+
The model can be configured through command-line arguments:
|
86 |
+
|
87 |
+
```bash
|
88 |
+
# Model configuration
|
89 |
+
--n-layer 6 # Number of transformer layers
|
90 |
+
--n-head 8 # Number of attention heads
|
91 |
+
--n-embd 512 # Embedding dimension
|
92 |
+
--block-size 1024 # Context window size
|
93 |
+
|
94 |
+
# Training configuration
|
95 |
+
--batch-size 12 # Batch size
|
96 |
+
--max-iters 20000 # Maximum training iterations
|
97 |
+
--learning-rate 6e-4 # Learning rate
|
98 |
+
--eval-interval 1000 # Evaluation interval
|
99 |
+
|
100 |
+
# Advanced features
|
101 |
+
--moe-experts 4 # Number of MoE experts
|
102 |
+
--multi-token 2 # Multi-token prediction
|
103 |
+
```
|
104 |
+
|
105 |
+
## 🤗 Model Available on Hugging Face
|
106 |
+
|
107 |
+
The trained model is now available on Hugging Face Hub! You can use it directly:
|
108 |
+
|
109 |
+
**Model**: [lakhera2023/deepseek-children-stories](https://huggingface.co/lakhera2023/deepseek-children-stories)
|
110 |
+
|
111 |
+
## Features
|
112 |
+
|
113 |
+
### Advanced Architecture
|
114 |
+
- **MLA**: Efficient attention with shared key-value heads
|
115 |
+
- **MoE**: Mixture of experts for increased model capacity
|
116 |
+
- **Multi-token Prediction**: Simultaneous prediction of multiple tokens
|
117 |
+
- **RoPE**: Rotary positional encodings for better position understanding
|
118 |
+
|
119 |
+
### Training Optimizations
|
120 |
+
- Mixed precision training with gradient scaling
|
121 |
+
- PyTorch 2.0 compilation for speed
|
122 |
+
- Automatic checkpointing and model saving
|
123 |
+
- MoE auxiliary loss for load balancing
|
124 |
+
|
125 |
+
### Story Generation
|
126 |
+
- Creative and engaging children's stories
|
127 |
+
- Moral lessons and educational content
|
128 |
+
- Age-appropriate language and themes
|
129 |
+
- Consistent character development
|
130 |
+
|
131 |
+
## Performance
|
132 |
+
|
133 |
+
The model achieves:
|
134 |
+
- Efficient training with ~2.24GB GPU memory usage
|
135 |
+
- Fast inference for real-time story generation
|
136 |
+
- High-quality output suitable for children
|
137 |
+
- Scalable architecture for different use cases
|
138 |
+
|
139 |
+
|
140 |
+
## Contributing
|
141 |
+
|
142 |
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
143 |
+
|
144 |
+
## License
|
145 |
+
|
146 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
147 |
+
|
148 |
+
## Acknowledgments
|
149 |
+
|
150 |
+
- DeepSeek team for the original architecture
|
151 |
+
- Hugging Face for the children's stories dataset
|
152 |
+
- PyTorch team for the excellent framework
|
153 |
+
|
154 |
+
## Links
|
155 |
+
|
156 |
+
- **GitHub**: https://github.com/ideaweaver-ai/DeepSeek-Children-Stories-15M-model
|
157 |
+
|
158 |
+
---
|
159 |
+
|
160 |
+
⭐ **Star this repository if you think Advanced Architecture + Tiny Models can do Big Things!**
|
checkpoints/best_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6ed045ca7f558caa89ae9345afc8f279d85838fd88ee066525d3e0497c2b1903
|
3 |
+
size 943196083
|
checkpoints/checkpoint_0.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:596d20b3003def973b7ed263f8a3bc8178d4527182e3d8b3d73b9f08a880487a
|
3 |
+
size 942850634
|
checkpoints/checkpoint_1000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef3cd05895dab3ad7b9605e99fd44437886e504064074283f1c01d35d4bdde2c
|
3 |
+
size 942871055
|
checkpoints/checkpoint_10000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e932a30f92cf27cc5cc90220c11cca2a9ba26ca662e1dcc9fbe0c9e56947ea8
|
3 |
+
size 943036196
|
checkpoints/checkpoint_11000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:741587b23e54005aa59122359d1645f9647509283315b1e688cecaa244e14d37
|
3 |
+
size 943054443
|
checkpoints/checkpoint_12000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:68e0cfb7e5e1c841ddb4f63a4e35f609ed3ae33a71b2fcae3a42713807716288
|
3 |
+
size 943072690
|
checkpoints/checkpoint_13000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8faec5132a9233ae661cf85447612451692fd568f55b77a331ef372ce7298b1d
|
3 |
+
size 943091001
|
checkpoints/checkpoint_14000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21000abe6d1ce61af34a3a348379b75ceb32c5b3e34934b6bf7f0f56900d57da
|
3 |
+
size 943109248
|
checkpoints/checkpoint_15000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:989c4f47419fbdb682d798c440eb8fa657f28717fe46a8a977cff5438daf6a32
|
3 |
+
size 943127495
|
checkpoints/checkpoint_16000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d821499c970ea1176632b80fa534c13cd941991e7219a2e4b98d443f733ea8d9
|
3 |
+
size 943145742
|
checkpoints/checkpoint_17000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c4239cfcac50b335caba1be84fd182100b30a43e199558bdf4320eb6a569a355
|
3 |
+
size 943164053
|
checkpoints/checkpoint_18000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a95a2b08118aafbe1649a94cdcd25c092c9e3337fac55ffb176af1c829a997cd
|
3 |
+
size 943182300
|
checkpoints/checkpoint_19000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dc22f5fc4bbfce05ab2b422929f31666dfc5536e8ffe6cf066ae1aa64feb3215
|
3 |
+
size 943200547
|
checkpoints/checkpoint_2000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7031946ec00c7c3433b130ee42a593a8295bb6f9036bef87fa922ce8ef7d3c0
|
3 |
+
size 942889365
|
checkpoints/checkpoint_20000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7f2abb8b5c12e411ecc662223dcc00ebb8e76bb0e492bf3782d03b4a5f0d5298
|
3 |
+
size 943218531
|
checkpoints/checkpoint_3000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3e676cb28f0aa2fe17da68bb5df5af422601606c11141e56bd26dfe126238971
|
3 |
+
size 942907611
|
checkpoints/checkpoint_4000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5f363f026b391e07f1610299450a74a32e217a782cb867a4b4307bb2a7d75559
|
3 |
+
size 942925857
|
checkpoints/checkpoint_5000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:20a34a77821367b9708831979f89af4f32b45bac592466cc04e778ee99a973b5
|
3 |
+
size 942944103
|
checkpoints/checkpoint_6000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:935b2adcd4575c3ad9732e898afdff258fd36a7571b7e97ca2796d116526ee0a
|
3 |
+
size 942962413
|
checkpoints/checkpoint_7000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:390a641bc6d086b41d16c8ec6888fc79f2a7970525f6192528ea35a1b4f44728
|
3 |
+
size 942980659
|
checkpoints/checkpoint_8000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90b2614013ed2c2df7078052867a9c19bdd36acc386f3fa2ed17e3013bdbdfb4
|
3 |
+
size 942998905
|
checkpoints/checkpoint_9000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1cba9b64ec3237fe1c33366159a883c344d103d35b4cced89d514617f89278c2
|
3 |
+
size 943017151
|
checkpoints/final_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:62e5cf559d3abdb97036c62c7daa25e61256d98f6cbab403d692ca648db2e078
|
3 |
+
size 942849715
|
deepseek-arch.png
ADDED
![]() |
Git LFS Details
|
deepseek_training_metrics.png
ADDED
![]() |
Git LFS Details
|
process_data.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
# Add the src directory to Python path
|
5 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
6 |
+
|
7 |
+
from data.data_processor import DeepSeekDataProcessor
|
8 |
+
|
9 |
+
def main():
|
10 |
+
print("[+] Processing dataset into binary files...")
|
11 |
+
processor = DeepSeekDataProcessor()
|
12 |
+
processor.prepare_dataset()
|
13 |
+
print("[+] Data processing completed successfully!")
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
transformers>=4.30.0
|
3 |
+
datasets>=2.12.0
|
4 |
+
tiktoken>=0.5.0
|
5 |
+
numpy>=1.20.0
|
6 |
+
tqdm>=4.65.0
|
7 |
+
matplotlib>=3.5.0
|
8 |
+
peft>=0.4.0
|
9 |
+
accelerate>=0.20.0
|
10 |
+
bitsandbytes>=0.41.0
|
11 |
+
huggingface_hub>=0.16.0
|
12 |
+
wandb>=0.15.0
|
13 |
+
psutil>=5.8.0
|
setup.sh
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Colors for output
|
4 |
+
GREEN='\033[0;32m'
|
5 |
+
RED='\033[0;31m'
|
6 |
+
YELLOW='\033[1;33m'
|
7 |
+
BLUE='\033[0;34m'
|
8 |
+
NC='\033[0m' # No Color
|
9 |
+
|
10 |
+
# Default configuration
|
11 |
+
PROJECT_ROOT="${PROJECT_ROOT:-$(pwd)}"
|
12 |
+
VENV_PATH="${VENV_PATH:-${PROJECT_ROOT}/venv}"
|
13 |
+
CHECKPOINT_DIR="${CHECKPOINT_DIR:-${PROJECT_ROOT}/checkpoints}"
|
14 |
+
LORA_CHECKPOINT_DIR="${LORA_CHECKPOINT_DIR:-${PROJECT_ROOT}/lora_checkpoints}"
|
15 |
+
REQUIRED_SPACE_MB="${REQUIRED_SPACE_MB:-2000}"
|
16 |
+
|
17 |
+
# Function to print status messages
|
18 |
+
print_status() {
|
19 |
+
echo -e "${GREEN}[+] $1${NC}"
|
20 |
+
}
|
21 |
+
|
22 |
+
print_error() {
|
23 |
+
echo -e "${RED}[-] $1${NC}"
|
24 |
+
}
|
25 |
+
|
26 |
+
print_warning() {
|
27 |
+
echo -e "${YELLOW}[!] $1${NC}"
|
28 |
+
}
|
29 |
+
|
30 |
+
print_info() {
|
31 |
+
echo -e "${BLUE}[i] $1${NC}"
|
32 |
+
}
|
33 |
+
|
34 |
+
# Function to handle errors
|
35 |
+
handle_error() {
|
36 |
+
print_error "$1"
|
37 |
+
exit 1
|
38 |
+
}
|
39 |
+
|
40 |
+
# Function to check if a command exists
|
41 |
+
command_exists() {
|
42 |
+
command -v "$1" &> /dev/null
|
43 |
+
}
|
44 |
+
|
45 |
+
# Function to check disk space
|
46 |
+
check_disk_space() {
|
47 |
+
local available_space_mb=$(df -m . | awk 'NR==2 {print $4}')
|
48 |
+
if [ "$available_space_mb" -lt "$REQUIRED_SPACE_MB" ]; then
|
49 |
+
print_warning "Low disk space. Only ${available_space_mb}MB available, ${REQUIRED_SPACE_MB}MB required."
|
50 |
+
return 1
|
51 |
+
fi
|
52 |
+
return 0
|
53 |
+
}
|
54 |
+
|
55 |
+
# Function to check GPU memory
|
56 |
+
check_gpu_memory() {
|
57 |
+
if command_exists nvidia-smi; then
|
58 |
+
local total_memory=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits)
|
59 |
+
local free_memory=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits)
|
60 |
+
local used_memory=$((total_memory - free_memory))
|
61 |
+
print_status "GPU Memory: ${used_memory}MB used, ${free_memory}MB free of ${total_memory}MB total"
|
62 |
+
|
63 |
+
# Check if we have enough memory for training
|
64 |
+
if [ "$free_memory" -lt 4000 ]; then
|
65 |
+
print_warning "Low GPU memory. Consider reducing batch size or model size."
|
66 |
+
fi
|
67 |
+
else
|
68 |
+
print_warning "nvidia-smi not found. GPU training may not be available."
|
69 |
+
fi
|
70 |
+
}
|
71 |
+
|
72 |
+
# Function to create project structure
|
73 |
+
create_project_structure() {
|
74 |
+
print_status "Creating project structure..."
|
75 |
+
mkdir -p "${PROJECT_ROOT}/src/data" \
|
76 |
+
"${PROJECT_ROOT}/src/model" \
|
77 |
+
"${PROJECT_ROOT}/src/training" \
|
78 |
+
"${PROJECT_ROOT}/src/inference" \
|
79 |
+
"${CHECKPOINT_DIR}" \
|
80 |
+
"${LORA_CHECKPOINT_DIR}" || handle_error "Failed to create directories"
|
81 |
+
}
|
82 |
+
|
83 |
+
# Function to setup virtual environment
|
84 |
+
setup_virtual_env() {
|
85 |
+
print_status "Creating virtual environment..."
|
86 |
+
python3 -m venv "${VENV_PATH}" || handle_error "Failed to create virtual environment"
|
87 |
+
source "${VENV_PATH}/bin/activate" || handle_error "Failed to activate virtual environment"
|
88 |
+
|
89 |
+
print_status "Installing dependencies..."
|
90 |
+
pip install --upgrade pip
|
91 |
+
pip install -r requirements.txt || handle_error "Failed to install requirements"
|
92 |
+
}
|
93 |
+
|
94 |
+
# Function to prepare dataset
|
95 |
+
prepare_dataset() {
|
96 |
+
print_status "Preparing dataset..."
|
97 |
+
cd "${PROJECT_ROOT}" || handle_error "Failed to change to project directory"
|
98 |
+
|
99 |
+
# Create a Python script to process the data
|
100 |
+
cat > process_data.py << 'EOF'
|
101 |
+
import os
|
102 |
+
import sys
|
103 |
+
|
104 |
+
# Add the src directory to Python path
|
105 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
106 |
+
|
107 |
+
from data.data_processor import DeepSeekDataProcessor
|
108 |
+
|
109 |
+
def main():
|
110 |
+
print("[+] Processing dataset into binary files...")
|
111 |
+
processor = DeepSeekDataProcessor()
|
112 |
+
processor.prepare_dataset()
|
113 |
+
print("[+] Data processing completed successfully!")
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
main()
|
117 |
+
EOF
|
118 |
+
|
119 |
+
# Run the data processing script
|
120 |
+
python3 process_data.py || handle_error "Failed to process dataset"
|
121 |
+
|
122 |
+
# Verify the files were created
|
123 |
+
if [ ! -f "${PROJECT_ROOT}/src/data/train.bin" ] || [ ! -f "${PROJECT_ROOT}/src/data/validation.bin" ]; then
|
124 |
+
handle_error "Data processing failed - required files not created"
|
125 |
+
fi
|
126 |
+
}
|
127 |
+
|
128 |
+
# Function to train base model
|
129 |
+
train_base_model() {
|
130 |
+
print_status "Starting DeepSeek base model training..."
|
131 |
+
cd "${PROJECT_ROOT}" || handle_error "Failed to change to project directory"
|
132 |
+
|
133 |
+
python3 src/run_training.py \
|
134 |
+
--batch-size "${BATCH_SIZE:-12}" \
|
135 |
+
--max-iters "${MAX_ITERS:-20000}" \
|
136 |
+
--eval-interval "${EVAL_INTERVAL:-1000}" \
|
137 |
+
--eval-iters "${EVAL_ITERS:-200}" \
|
138 |
+
--learning-rate "${LEARNING_RATE:-6e-4}" \
|
139 |
+
--weight-decay "${WEIGHT_DECAY:-0.1}" \
|
140 |
+
--warmup-iters "${WARMUP_ITERS:-2000}" \
|
141 |
+
--lr-decay-iters "${LR_DECAY_ITERS:-20000}" \
|
142 |
+
--min-lr "${MIN_LR:-6e-5}" \
|
143 |
+
--moe-experts "${MOE_EXPERTS:-4}" \
|
144 |
+
--multi-token "${MULTI_TOKEN:-2}" || handle_error "Base model training failed"
|
145 |
+
}
|
146 |
+
|
147 |
+
# Function to perform LoRA finetuning
|
148 |
+
finetune_lora() {
|
149 |
+
while true; do
|
150 |
+
read -p "Do you want to perform LoRA finetuning? (y/n) " do_finetune
|
151 |
+
case $do_finetune in
|
152 |
+
[Yy]* )
|
153 |
+
print_status "Starting LoRA finetuning..."
|
154 |
+
cd "${PROJECT_ROOT}" || handle_error "Failed to change to project directory"
|
155 |
+
|
156 |
+
# Create LoRA finetuning script
|
157 |
+
cat > finetune_lora.py << 'EOF'
|
158 |
+
import torch
|
159 |
+
import os
|
160 |
+
import sys
|
161 |
+
sys.path.append('src')
|
162 |
+
|
163 |
+
from model.deepseek import DeepSeek, DeepSeekConfig
|
164 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
165 |
+
|
166 |
+
def main():
|
167 |
+
print("Loading base model...")
|
168 |
+
checkpoint = torch.load('checkpoints/best_model.pt', map_location='cuda' if torch.cuda.is_available() else 'cpu')
|
169 |
+
model = DeepSeek(checkpoint['config'])
|
170 |
+
model.load_state_dict(checkpoint['model'])
|
171 |
+
|
172 |
+
# Define LoRA configuration
|
173 |
+
lora_config = LoraConfig(
|
174 |
+
task_type=TaskType.CAUSAL_LM,
|
175 |
+
r=8, # rank
|
176 |
+
lora_alpha=32,
|
177 |
+
lora_dropout=0.1,
|
178 |
+
target_modules=["q_a_proj", "q_b_proj", "kv_a_proj", "kv_b_proj"]
|
179 |
+
)
|
180 |
+
|
181 |
+
# Get PEFT model
|
182 |
+
model = get_peft_model(model, lora_config)
|
183 |
+
model.print_trainable_parameters()
|
184 |
+
|
185 |
+
print("LoRA finetuning setup complete!")
|
186 |
+
|
187 |
+
if __name__ == "__main__":
|
188 |
+
main()
|
189 |
+
EOF
|
190 |
+
|
191 |
+
python3 finetune_lora.py || handle_error "LoRA finetuning failed"
|
192 |
+
break
|
193 |
+
;;
|
194 |
+
[Nn]* )
|
195 |
+
print_status "Skipping LoRA finetuning..."
|
196 |
+
break
|
197 |
+
;;
|
198 |
+
* )
|
199 |
+
echo "Please answer 'y' or 'n'"
|
200 |
+
;;
|
201 |
+
esac
|
202 |
+
done
|
203 |
+
}
|
204 |
+
|
205 |
+
# Function to test the trained model
|
206 |
+
test_model() {
|
207 |
+
while true; do
|
208 |
+
read -p "Do you want to test the trained model? (y/n) " do_test
|
209 |
+
case $do_test in
|
210 |
+
[Yy]* )
|
211 |
+
print_status "Testing the trained model..."
|
212 |
+
cd "${PROJECT_ROOT}" || handle_error "Failed to change to project directory"
|
213 |
+
|
214 |
+
# Create test prompts
|
215 |
+
prompts=(
|
216 |
+
"Once upon a time"
|
217 |
+
"In a magical forest"
|
218 |
+
"The little robot"
|
219 |
+
"The brave knight"
|
220 |
+
)
|
221 |
+
|
222 |
+
# Test each prompt
|
223 |
+
for prompt in "${prompts[@]}"; do
|
224 |
+
print_status "Testing with prompt: '$prompt'"
|
225 |
+
python3 src/generate.py \
|
226 |
+
--model-path "${CHECKPOINT_DIR}/best_model.pt" \
|
227 |
+
--prompt "$prompt" \
|
228 |
+
--max-tokens 100 \
|
229 |
+
--temperature 0.8 \
|
230 |
+
--top-k 40
|
231 |
+
echo
|
232 |
+
done
|
233 |
+
break
|
234 |
+
;;
|
235 |
+
[Nn]* )
|
236 |
+
print_status "Skipping model testing..."
|
237 |
+
break
|
238 |
+
;;
|
239 |
+
* )
|
240 |
+
echo "Please answer 'y' or 'n'"
|
241 |
+
;;
|
242 |
+
esac
|
243 |
+
done
|
244 |
+
}
|
245 |
+
|
246 |
+
# Function to show usage information
|
247 |
+
show_usage() {
|
248 |
+
print_info "DeepSeek Children's Stories Model Setup Complete!"
|
249 |
+
print_info ""
|
250 |
+
print_info "Next steps:"
|
251 |
+
print_info "1. Activate virtual environment: source venv/bin/activate"
|
252 |
+
print_info "2. Train the model: python src/run_training.py"
|
253 |
+
print_info "3. Generate stories: python src/generate.py --prompt 'your prompt'"
|
254 |
+
print_info "4. Interactive mode: python src/generate.py --interactive"
|
255 |
+
print_info ""
|
256 |
+
print_info "Model files:"
|
257 |
+
print_info "- Base model: checkpoints/best_model.pt"
|
258 |
+
print_info "- LoRA model: lora_checkpoints/best_lora_model.pt"
|
259 |
+
print_info ""
|
260 |
+
print_info "Configuration options:"
|
261 |
+
print_info "- Adjust model size: --n-layer, --n-head, --n-embd"
|
262 |
+
print_info "- Training parameters: --batch-size, --learning-rate, --max-iters"
|
263 |
+
print_info "- Advanced features: --moe-experts, --multi-token"
|
264 |
+
}
|
265 |
+
|
266 |
+
# Main setup function
|
267 |
+
main() {
|
268 |
+
print_info "DeepSeek Children's Stories Model Setup"
|
269 |
+
print_info "======================================"
|
270 |
+
|
271 |
+
# Check prerequisites
|
272 |
+
if ! command_exists python3; then
|
273 |
+
handle_error "Python 3 is required but not installed"
|
274 |
+
fi
|
275 |
+
|
276 |
+
if ! command_exists pip; then
|
277 |
+
handle_error "pip is required but not installed"
|
278 |
+
fi
|
279 |
+
|
280 |
+
# Check disk space
|
281 |
+
if ! check_disk_space; then
|
282 |
+
print_warning "Continuing with low disk space..."
|
283 |
+
fi
|
284 |
+
|
285 |
+
# Check GPU
|
286 |
+
check_gpu_memory
|
287 |
+
|
288 |
+
# Create project structure
|
289 |
+
create_project_structure
|
290 |
+
|
291 |
+
# Setup virtual environment
|
292 |
+
setup_virtual_env
|
293 |
+
|
294 |
+
# Prepare dataset
|
295 |
+
prepare_dataset
|
296 |
+
|
297 |
+
# Train base model
|
298 |
+
train_base_model
|
299 |
+
|
300 |
+
# Optional LoRA finetuning
|
301 |
+
finetune_lora
|
302 |
+
|
303 |
+
# Optional model testing
|
304 |
+
test_model
|
305 |
+
|
306 |
+
# Show usage information
|
307 |
+
show_usage
|
308 |
+
|
309 |
+
print_status "Setup completed successfully!"
|
310 |
+
}
|
311 |
+
|
312 |
+
# Run main function
|
313 |
+
main "$@"
|
src/data/__pycache__/data_processor.cpython-310.pyc
ADDED
Binary file (8.3 kB). View file
|
|
src/data/data_processor.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Data Processor for DeepSeek Children's Stories Model
|
3 |
+
Handles dataset loading, preprocessing, and tokenization for children's story generation
|
4 |
+
"""
|
5 |
+
|
6 |
+
import tiktoken
|
7 |
+
import os
|
8 |
+
import numpy as np
|
9 |
+
from datasets import load_dataset
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
import torch
|
12 |
+
from typing import Dict, List, Optional
|
13 |
+
|
14 |
+
def load_encoder_decoder():
|
15 |
+
"""Load the encoder and decoder for text processing"""
|
16 |
+
enc = tiktoken.get_encoding("gpt2")
|
17 |
+
return enc, enc
|
18 |
+
|
19 |
+
class DeepSeekDataProcessor:
|
20 |
+
def __init__(self, config=None):
|
21 |
+
# Initialize tokenizer with GPT-2 encoding
|
22 |
+
self.enc = tiktoken.get_encoding("gpt2")
|
23 |
+
|
24 |
+
# Special tokens for story structure (optimized for children's stories)
|
25 |
+
self.special_tokens = {
|
26 |
+
"story_start": "<|story|>",
|
27 |
+
"story_end": "</|story|>",
|
28 |
+
"prompt_start": "<|prompt|>",
|
29 |
+
"prompt_end": "</|prompt|>",
|
30 |
+
"moral_start": "<|moral|>",
|
31 |
+
"moral_end": "</|moral|>",
|
32 |
+
"character_start": "<|character|>",
|
33 |
+
"character_end": "</|character|>"
|
34 |
+
}
|
35 |
+
|
36 |
+
# Ensure data directory exists
|
37 |
+
self.data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
|
38 |
+
os.makedirs(self.data_dir, exist_ok=True)
|
39 |
+
print(f"Data directory: {self.data_dir}")
|
40 |
+
|
41 |
+
# Configuration for processing
|
42 |
+
self.max_length = 1024 # DeepSeek context window
|
43 |
+
self.min_length = 50 # Minimum story length
|
44 |
+
|
45 |
+
def preprocess_text(self, text: str) -> str:
|
46 |
+
"""Preprocess text for children's stories"""
|
47 |
+
# Basic text cleaning
|
48 |
+
text = text.lower() # Convert to lowercase for consistency
|
49 |
+
text = text.replace('\n', ' ') # Replace newlines with spaces
|
50 |
+
text = ' '.join(text.split()) # Normalize whitespace
|
51 |
+
|
52 |
+
# Remove any inappropriate content markers (basic filtering)
|
53 |
+
inappropriate_phrases = ['adult content', 'mature', 'explicit']
|
54 |
+
for phrase in inappropriate_phrases:
|
55 |
+
if phrase in text:
|
56 |
+
return ""
|
57 |
+
|
58 |
+
# Ensure the text is child-friendly
|
59 |
+
if len(text) < self.min_length:
|
60 |
+
return ""
|
61 |
+
|
62 |
+
return text
|
63 |
+
|
64 |
+
def extract_story_elements(self, example: Dict) -> Dict:
|
65 |
+
"""Extract story elements for better structure"""
|
66 |
+
prompt = self.preprocess_text(example.get('prompt', ''))
|
67 |
+
story = self.preprocess_text(example.get('text', ''))
|
68 |
+
|
69 |
+
# Extract potential moral or lesson
|
70 |
+
moral = ""
|
71 |
+
if 'moral' in example:
|
72 |
+
moral = self.preprocess_text(example['moral'])
|
73 |
+
elif 'lesson' in example:
|
74 |
+
moral = self.preprocess_text(example['lesson'])
|
75 |
+
|
76 |
+
# Extract main character if available
|
77 |
+
character = ""
|
78 |
+
if 'character' in example:
|
79 |
+
character = self.preprocess_text(example['character'])
|
80 |
+
|
81 |
+
return {
|
82 |
+
'prompt': prompt,
|
83 |
+
'story': story,
|
84 |
+
'moral': moral,
|
85 |
+
'character': character
|
86 |
+
}
|
87 |
+
|
88 |
+
def process(self, example: Dict) -> Dict:
|
89 |
+
"""Process a single example for DeepSeek model"""
|
90 |
+
# Extract story elements
|
91 |
+
elements = self.extract_story_elements(example)
|
92 |
+
|
93 |
+
# Skip if no valid content
|
94 |
+
if not elements['story'] or not elements['prompt']:
|
95 |
+
return {'ids': [], 'len': 0}
|
96 |
+
|
97 |
+
# Create structured text with special tokens
|
98 |
+
full_text = (
|
99 |
+
f"{self.special_tokens['prompt_start']} {elements['prompt']} {self.special_tokens['prompt_end']} "
|
100 |
+
)
|
101 |
+
|
102 |
+
# Add character information if available
|
103 |
+
if elements['character']:
|
104 |
+
full_text += f"{self.special_tokens['character_start']} {elements['character']} {self.special_tokens['character_end']} "
|
105 |
+
|
106 |
+
# Add the main story
|
107 |
+
full_text += f"{self.special_tokens['story_start']} {elements['story']} {self.special_tokens['story_end']}"
|
108 |
+
|
109 |
+
# Add moral if available
|
110 |
+
if elements['moral']:
|
111 |
+
full_text += f" {self.special_tokens['moral_start']} {elements['moral']} {self.special_tokens['moral_end']}"
|
112 |
+
|
113 |
+
# Tokenize with error handling
|
114 |
+
try:
|
115 |
+
ids = self.enc.encode_ordinary(full_text)
|
116 |
+
|
117 |
+
# Ensure the sequence isn't too long
|
118 |
+
if len(ids) > self.max_length:
|
119 |
+
ids = ids[:self.max_length]
|
120 |
+
|
121 |
+
# Skip if too short
|
122 |
+
if len(ids) < 20:
|
123 |
+
return {'ids': [], 'len': 0}
|
124 |
+
|
125 |
+
out = {'ids': ids, 'len': len(ids)}
|
126 |
+
return out
|
127 |
+
|
128 |
+
except Exception as e:
|
129 |
+
print(f"Error tokenizing text: {e}")
|
130 |
+
return {'ids': [], 'len': 0}
|
131 |
+
|
132 |
+
def prepare_dataset(self) -> Dict:
|
133 |
+
"""Prepare the Children Stories Collection dataset for DeepSeek training"""
|
134 |
+
# Load the Children Stories Collection dataset
|
135 |
+
print("Loading Children Stories Collection dataset...")
|
136 |
+
ds = load_dataset("ajibawa-2023/Children-Stories-Collection")
|
137 |
+
|
138 |
+
train_bin_path = os.path.join(self.data_dir, "train.bin")
|
139 |
+
val_bin_path = os.path.join(self.data_dir, "validation.bin")
|
140 |
+
finetune_bin_path = os.path.join(self.data_dir, "finetune.bin")
|
141 |
+
|
142 |
+
print(f"Checking for existing processed files...")
|
143 |
+
|
144 |
+
# Check if all files exist
|
145 |
+
if (os.path.exists(train_bin_path) and
|
146 |
+
os.path.exists(val_bin_path) and
|
147 |
+
os.path.exists(finetune_bin_path)):
|
148 |
+
|
149 |
+
print("Found existing processed files!")
|
150 |
+
print(f"Train file: {os.path.getsize(train_bin_path) / (1024*1024):.2f} MB")
|
151 |
+
print(f"Validation file: {os.path.getsize(val_bin_path) / (1024*1024):.2f} MB")
|
152 |
+
print(f"Finetune file: {os.path.getsize(finetune_bin_path) / (1024*1024):.2f} MB")
|
153 |
+
|
154 |
+
return {
|
155 |
+
"train": train_bin_path,
|
156 |
+
"validation": val_bin_path,
|
157 |
+
"finetune": finetune_bin_path
|
158 |
+
}
|
159 |
+
|
160 |
+
print("Processing dataset...")
|
161 |
+
|
162 |
+
# Filter out examples that are too short or too long
|
163 |
+
def filter_by_length(example):
|
164 |
+
text_length = len(example.get('text', ''))
|
165 |
+
return self.min_length <= text_length <= 2000 # Reasonable length for children's stories
|
166 |
+
|
167 |
+
ds = ds.filter(filter_by_length)
|
168 |
+
print(f"After filtering: {len(ds['train'])} examples")
|
169 |
+
|
170 |
+
# Split the dataset into train, validation, and finetune sets
|
171 |
+
train_val_test = ds["train"].train_test_split(test_size=0.2, seed=42)
|
172 |
+
val_finetune = train_val_test["test"].train_test_split(test_size=0.5, seed=42)
|
173 |
+
|
174 |
+
# Create a new dataset dictionary with all splits
|
175 |
+
ds = {
|
176 |
+
"train": train_val_test["train"],
|
177 |
+
"validation": val_finetune["train"],
|
178 |
+
"finetune": val_finetune["test"]
|
179 |
+
}
|
180 |
+
|
181 |
+
print(f"Dataset split sizes:")
|
182 |
+
print(f"Training set: {len(ds['train'])} examples")
|
183 |
+
print(f"Validation set: {len(ds['validation'])} examples")
|
184 |
+
print(f"Finetune set: {len(ds['finetune'])} examples")
|
185 |
+
|
186 |
+
# Process each split
|
187 |
+
for split_name, split_data in ds.items():
|
188 |
+
print(f"\nProcessing {split_name} split...")
|
189 |
+
|
190 |
+
# Process the data
|
191 |
+
tokenized = split_data.map(
|
192 |
+
self.process,
|
193 |
+
remove_columns=['text', 'prompt', 'text_token_length'],
|
194 |
+
desc=f"tokenizing {split_name} split",
|
195 |
+
num_proc=8,
|
196 |
+
)
|
197 |
+
|
198 |
+
# Filter out empty sequences
|
199 |
+
tokenized = tokenized.filter(lambda x: x['len'] > 0)
|
200 |
+
print(f"After processing: {len(tokenized)} valid examples")
|
201 |
+
|
202 |
+
# Save to binary file
|
203 |
+
filename = os.path.join(self.data_dir, f"{split_name}.bin")
|
204 |
+
print(f"Saving {split_name} split to: {filename}")
|
205 |
+
|
206 |
+
# Calculate total length
|
207 |
+
arr_len = np.sum(tokenized['len'], dtype=np.uint64)
|
208 |
+
dtype = np.uint16
|
209 |
+
arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
|
210 |
+
total_batches = 1024
|
211 |
+
|
212 |
+
idx = 0
|
213 |
+
for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
|
214 |
+
batch = tokenized.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
|
215 |
+
arr_batch = np.concatenate(batch['ids'])
|
216 |
+
arr[idx : idx + len(arr_batch)] = arr_batch
|
217 |
+
idx += len(arr_batch)
|
218 |
+
arr.flush()
|
219 |
+
|
220 |
+
# Verify file was created
|
221 |
+
if os.path.exists(filename):
|
222 |
+
print(f"Successfully created {filename}")
|
223 |
+
print(f"File size: {os.path.getsize(filename) / (1024*1024):.2f} MB")
|
224 |
+
else:
|
225 |
+
raise RuntimeError(f"Failed to create {filename}")
|
226 |
+
|
227 |
+
return {
|
228 |
+
"train": train_bin_path,
|
229 |
+
"validation": val_bin_path,
|
230 |
+
"finetune": finetune_bin_path
|
231 |
+
}
|
232 |
+
|
233 |
+
def load_binary_data(self, filepath: str) -> torch.Tensor:
|
234 |
+
"""Load binary data file as tensor"""
|
235 |
+
try:
|
236 |
+
data = np.memmap(filepath, dtype=np.uint16, mode='r')
|
237 |
+
return torch.from_numpy(data.copy())
|
238 |
+
except Exception as e:
|
239 |
+
print(f"Error loading data from {filepath}: {e}")
|
240 |
+
raise
|
241 |
+
|
242 |
+
def get_batch(self, data: torch.Tensor, batch_size: int, block_size: int) -> tuple:
|
243 |
+
"""Get a batch of data for training"""
|
244 |
+
# Generate random indices
|
245 |
+
ix = torch.randint(len(data) - block_size, (batch_size,))
|
246 |
+
|
247 |
+
# Get input sequences
|
248 |
+
x = torch.stack([data[i:i+block_size].long() for i in ix])
|
249 |
+
# Get target sequences (shifted by 1)
|
250 |
+
y = torch.stack([data[i+1:i+1+block_size].long() for i in ix])
|
251 |
+
|
252 |
+
return x, y
|
253 |
+
|
254 |
+
def decode_tokens(self, token_ids: List[int]) -> str:
|
255 |
+
"""Decode token IDs back to text"""
|
256 |
+
try:
|
257 |
+
return self.enc.decode(token_ids)
|
258 |
+
except Exception as e:
|
259 |
+
print(f"Error decoding tokens: {e}")
|
260 |
+
return ""
|
261 |
+
|
262 |
+
def encode_text(self, text: str) -> List[int]:
|
263 |
+
"""Encode text to token IDs"""
|
264 |
+
try:
|
265 |
+
return self.enc.encode_ordinary(text)
|
266 |
+
except Exception as e:
|
267 |
+
print(f"Error encoding text: {e}")
|
268 |
+
return []
|
269 |
+
|
270 |
+
|
271 |
+
def main():
|
272 |
+
"""Main function to process the dataset"""
|
273 |
+
print("DeepSeek Children's Stories Data Processor")
|
274 |
+
print("=" * 50)
|
275 |
+
|
276 |
+
processor = DeepSeekDataProcessor()
|
277 |
+
processor.prepare_dataset()
|
278 |
+
|
279 |
+
print("\nData processing completed successfully!")
|
280 |
+
print("Files created:")
|
281 |
+
print("- src/data/train.bin")
|
282 |
+
print("- src/data/validation.bin")
|
283 |
+
print("- src/data/finetune.bin")
|
284 |
+
|
285 |
+
|
286 |
+
if __name__ == "__main__":
|
287 |
+
main()
|
src/data/finetune.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d4f4819be598b40b35540e069ed44efc44ffc732ab6c29269e5f1a227fb9e77f
|
3 |
+
size 61167196
|
src/data/train.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:09234a8bf9f55e6e59f48b765bea680f3c2aa4a8305e9a553d9593de4652d0aa
|
3 |
+
size 488961682
|
src/data/validation.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:98148ea1408dc8b01abe08126221c6ef3cd01362240b9d0bb1ce39e477ae211c
|
3 |
+
size 61065952
|
src/generate.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
DeepSeek Children's Stories Text Generation
|
3 |
+
Generate children's stories using the trained DeepSeek model
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import argparse
|
9 |
+
import torch
|
10 |
+
import tiktoken
|
11 |
+
from typing import List, Optional
|
12 |
+
|
13 |
+
# Add the src directory to Python path
|
14 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
15 |
+
|
16 |
+
from model.deepseek import DeepSeek, DeepSeekConfig
|
17 |
+
|
18 |
+
# Allowlist DeepSeekConfig for safe deserialization
|
19 |
+
torch.serialization.add_safe_globals([DeepSeekConfig])
|
20 |
+
|
21 |
+
class DeepSeekStoryGenerator:
|
22 |
+
def __init__(self, model_path: str, device: str = 'auto'):
|
23 |
+
"""Initialize the story generator"""
|
24 |
+
self.device = self._get_device(device)
|
25 |
+
self.model = self._load_model(model_path)
|
26 |
+
self.tokenizer = tiktoken.get_encoding("gpt2")
|
27 |
+
|
28 |
+
# Special tokens for story structure
|
29 |
+
self.special_tokens = {
|
30 |
+
"story_start": "<|story|>",
|
31 |
+
"story_end": "</|story|>",
|
32 |
+
"prompt_start": "<|prompt|>",
|
33 |
+
"prompt_end": "</|prompt|>",
|
34 |
+
"moral_start": "<|moral|>",
|
35 |
+
"moral_end": "</|moral|>",
|
36 |
+
"character_start": "<|character|>",
|
37 |
+
"character_end": "</|character|>"
|
38 |
+
}
|
39 |
+
|
40 |
+
def _get_device(self, device: str) -> str:
|
41 |
+
"""Get the appropriate device"""
|
42 |
+
if device == 'auto':
|
43 |
+
return 'cuda' if torch.cuda.is_available() else 'cpu'
|
44 |
+
return device
|
45 |
+
|
46 |
+
def _load_model(self, model_path: str) -> DeepSeek:
|
47 |
+
"""Load the trained model"""
|
48 |
+
print(f"Loading model from {model_path}...")
|
49 |
+
|
50 |
+
# Load checkpoint
|
51 |
+
checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
|
52 |
+
|
53 |
+
# Create model with the same configuration
|
54 |
+
config = checkpoint['config']
|
55 |
+
model = DeepSeek(config)
|
56 |
+
|
57 |
+
# Handle compiled model state dict by removing _orig_mod prefix
|
58 |
+
state_dict = checkpoint['model']
|
59 |
+
if all(k.startswith('_orig_mod.') for k in state_dict.keys()):
|
60 |
+
state_dict = {k[10:]: v for k, v in state_dict.items()} # Remove '_orig_mod.' prefix
|
61 |
+
|
62 |
+
# Load model weights
|
63 |
+
model.load_state_dict(state_dict)
|
64 |
+
model.to(self.device)
|
65 |
+
model.eval()
|
66 |
+
|
67 |
+
print(f"Model loaded successfully!")
|
68 |
+
print(f"Model configuration: {config.n_layer}L/{config.n_head}H/{config.n_embd}D")
|
69 |
+
print(f"Device: {self.device}")
|
70 |
+
|
71 |
+
return model
|
72 |
+
|
73 |
+
def encode_prompt(self, prompt: str, character: Optional[str] = None) -> torch.Tensor:
|
74 |
+
"""Encode a prompt for generation"""
|
75 |
+
# Create structured prompt
|
76 |
+
full_prompt = f"{self.special_tokens['prompt_start']} {prompt.lower()} {self.special_tokens['prompt_end']}"
|
77 |
+
|
78 |
+
if character:
|
79 |
+
full_prompt += f" {self.special_tokens['character_start']} {character.lower()} {self.special_tokens['character_end']}"
|
80 |
+
|
81 |
+
full_prompt += f" {self.special_tokens['story_start']}"
|
82 |
+
|
83 |
+
# Tokenize
|
84 |
+
token_ids = self.tokenizer.encode_ordinary(full_prompt)
|
85 |
+
return torch.tensor([token_ids], dtype=torch.long, device=self.device)
|
86 |
+
|
87 |
+
def generate_story(self, prompt: str, character: Optional[str] = None,
|
88 |
+
max_tokens: int = 200, temperature: float = 0.8,
|
89 |
+
top_k: int = 40, top_p: float = 0.9) -> str:
|
90 |
+
"""Generate a children's story"""
|
91 |
+
print(f"Generating story for prompt: '{prompt}'")
|
92 |
+
if character:
|
93 |
+
print(f"Character: {character}")
|
94 |
+
|
95 |
+
# Encode prompt
|
96 |
+
input_ids = self.encode_prompt(prompt, character)
|
97 |
+
|
98 |
+
# Generate
|
99 |
+
with torch.no_grad():
|
100 |
+
generated_ids = self.model.generate(
|
101 |
+
input_ids,
|
102 |
+
max_new_tokens=max_tokens,
|
103 |
+
temperature=temperature,
|
104 |
+
top_k=top_k
|
105 |
+
)
|
106 |
+
|
107 |
+
# Decode the generated text
|
108 |
+
generated_text = self.tokenizer.decode(generated_ids[0].tolist())
|
109 |
+
|
110 |
+
# Extract the story part
|
111 |
+
story = self._extract_story(generated_text)
|
112 |
+
|
113 |
+
return story
|
114 |
+
|
115 |
+
def _extract_story(self, text: str) -> str:
|
116 |
+
"""Extract the story from the generated text"""
|
117 |
+
# Find story start and end markers
|
118 |
+
story_start = text.find(self.special_tokens['story_start'])
|
119 |
+
story_end = text.find(self.special_tokens['story_end'])
|
120 |
+
|
121 |
+
if story_start != -1 and story_end != -1:
|
122 |
+
# Extract story content
|
123 |
+
story_content = text[story_start + len(self.special_tokens['story_start']):story_end].strip()
|
124 |
+
return story_content
|
125 |
+
else:
|
126 |
+
# Fallback: return the text after the last prompt
|
127 |
+
prompt_end = text.find(self.special_tokens['prompt_end'])
|
128 |
+
if prompt_end != -1:
|
129 |
+
return text[prompt_end + len(self.special_tokens['prompt_end']):].strip()
|
130 |
+
else:
|
131 |
+
return text.strip()
|
132 |
+
|
133 |
+
def generate_multiple_stories(self, prompts: List[str], num_stories: int = 3,
|
134 |
+
**kwargs) -> List[str]:
|
135 |
+
"""Generate multiple stories from a list of prompts"""
|
136 |
+
stories = []
|
137 |
+
|
138 |
+
for i, prompt in enumerate(prompts):
|
139 |
+
print(f"\nGenerating story {i+1}/{len(prompts)}...")
|
140 |
+
story = self.generate_story(prompt, **kwargs)
|
141 |
+
stories.append(story)
|
142 |
+
|
143 |
+
return stories
|
144 |
+
|
145 |
+
def interactive_generation(self):
|
146 |
+
"""Interactive story generation mode"""
|
147 |
+
print("DeepSeek Children's Stories - Interactive Mode")
|
148 |
+
print("Type 'quit' to exit")
|
149 |
+
print("-" * 50)
|
150 |
+
|
151 |
+
while True:
|
152 |
+
try:
|
153 |
+
# Get prompt from user
|
154 |
+
prompt = input("\nEnter a story prompt: ").strip()
|
155 |
+
|
156 |
+
if prompt.lower() in ['quit', 'exit', 'q']:
|
157 |
+
print("Goodbye!")
|
158 |
+
break
|
159 |
+
|
160 |
+
if not prompt:
|
161 |
+
print("Please enter a valid prompt.")
|
162 |
+
continue
|
163 |
+
|
164 |
+
# Get character (optional)
|
165 |
+
character = input("Enter a character name (optional): ").strip()
|
166 |
+
if not character:
|
167 |
+
character = None
|
168 |
+
|
169 |
+
# Get generation parameters
|
170 |
+
try:
|
171 |
+
max_tokens = int(input("Max tokens (default 200): ") or "200")
|
172 |
+
temperature = float(input("Temperature (default 0.8): ") or "0.8")
|
173 |
+
except ValueError:
|
174 |
+
max_tokens = 200
|
175 |
+
temperature = 0.8
|
176 |
+
|
177 |
+
# Generate story
|
178 |
+
story = self.generate_story(
|
179 |
+
prompt,
|
180 |
+
character=character,
|
181 |
+
max_tokens=max_tokens,
|
182 |
+
temperature=temperature
|
183 |
+
)
|
184 |
+
|
185 |
+
# Display story
|
186 |
+
print("\n" + "="*50)
|
187 |
+
print("GENERATED STORY:")
|
188 |
+
print("="*50)
|
189 |
+
print(story)
|
190 |
+
print("="*50)
|
191 |
+
|
192 |
+
except KeyboardInterrupt:
|
193 |
+
print("\nGoodbye!")
|
194 |
+
break
|
195 |
+
except Exception as e:
|
196 |
+
print(f"Error generating story: {e}")
|
197 |
+
|
198 |
+
|
199 |
+
def main():
|
200 |
+
"""Main generation function"""
|
201 |
+
parser = argparse.ArgumentParser(description='Generate children\'s stories with DeepSeek')
|
202 |
+
|
203 |
+
# Model configuration
|
204 |
+
parser.add_argument('--model-path', type=str, default='checkpoints/best_model.pt',
|
205 |
+
help='Path to the trained model checkpoint')
|
206 |
+
parser.add_argument('--device', type=str, default='auto',
|
207 |
+
help='Device to use (auto, cuda, cpu)')
|
208 |
+
|
209 |
+
# Generation parameters
|
210 |
+
parser.add_argument('--prompt', type=str, help='Story prompt')
|
211 |
+
parser.add_argument('--character', type=str, help='Character name')
|
212 |
+
parser.add_argument('--max-tokens', type=int, default=200, help='Maximum tokens to generate')
|
213 |
+
parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
|
214 |
+
parser.add_argument('--top-k', type=int, default=40, help='Top-k sampling')
|
215 |
+
parser.add_argument('--top-p', type=float, default=0.9, help='Top-p sampling')
|
216 |
+
|
217 |
+
# Multiple generation
|
218 |
+
parser.add_argument('--num-stories', type=int, default=1, help='Number of stories to generate')
|
219 |
+
parser.add_argument('--interactive', action='store_true', help='Interactive mode')
|
220 |
+
|
221 |
+
args = parser.parse_args()
|
222 |
+
|
223 |
+
# Check if model exists
|
224 |
+
if not os.path.exists(args.model_path):
|
225 |
+
print(f"Error: Model file not found at {args.model_path}")
|
226 |
+
print("Please train the model first or specify the correct path.")
|
227 |
+
return
|
228 |
+
|
229 |
+
# Create generator
|
230 |
+
generator = DeepSeekStoryGenerator(args.model_path, args.device)
|
231 |
+
|
232 |
+
if args.interactive:
|
233 |
+
# Interactive mode
|
234 |
+
generator.interactive_generation()
|
235 |
+
else:
|
236 |
+
# Single or multiple generation
|
237 |
+
if args.prompt:
|
238 |
+
if args.num_stories == 1:
|
239 |
+
# Single story
|
240 |
+
story = generator.generate_story(
|
241 |
+
args.prompt,
|
242 |
+
character=args.character,
|
243 |
+
max_tokens=args.max_tokens,
|
244 |
+
temperature=args.temperature,
|
245 |
+
top_k=args.top_k,
|
246 |
+
top_p=args.top_p
|
247 |
+
)
|
248 |
+
|
249 |
+
print(f"\nPrompt: {args.prompt}")
|
250 |
+
if args.character:
|
251 |
+
print(f"Character: {args.character}")
|
252 |
+
print("\n" + "="*50)
|
253 |
+
print("GENERATED STORY:")
|
254 |
+
print("="*50)
|
255 |
+
print(story)
|
256 |
+
print("="*50)
|
257 |
+
else:
|
258 |
+
# Multiple stories
|
259 |
+
prompts = [args.prompt] * args.num_stories
|
260 |
+
stories = generator.generate_multiple_stories(
|
261 |
+
prompts,
|
262 |
+
num_stories=args.num_stories,
|
263 |
+
character=args.character,
|
264 |
+
max_tokens=args.max_tokens,
|
265 |
+
temperature=args.temperature,
|
266 |
+
top_k=args.top_k,
|
267 |
+
top_p=args.top_p
|
268 |
+
)
|
269 |
+
|
270 |
+
for i, story in enumerate(stories):
|
271 |
+
print(f"\nStory {i+1}:")
|
272 |
+
print("="*50)
|
273 |
+
print(story)
|
274 |
+
print("="*50)
|
275 |
+
else:
|
276 |
+
print("Please provide a prompt or use --interactive mode.")
|
277 |
+
print("Example: python generate.py --prompt 'A brave little mouse' --character 'Mickey'")
|
278 |
+
|
279 |
+
|
280 |
+
if __name__ == "__main__":
|
281 |
+
main()
|
src/model/__pycache__/deepseek.cpython-310.pyc
ADDED
Binary file (13.8 kB). View file
|
|
src/model/deepseek.py
ADDED
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
DeepSeek Model Architecture for Children's Stories
|
3 |
+
Implements advanced features:
|
4 |
+
- Multihead Latent Attention (MLA)
|
5 |
+
- Mixture of Experts (MoE)
|
6 |
+
- Multi-token prediction
|
7 |
+
- Quantization support
|
8 |
+
- Rotary Positional Encodings (RoPE)
|
9 |
+
- Optimized for children's story generation
|
10 |
+
"""
|
11 |
+
|
12 |
+
import math
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from typing import Optional, Tuple, List
|
17 |
+
from dataclasses import dataclass
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class DeepSeekConfig:
|
22 |
+
"""Configuration for DeepSeek model optimized for children's stories"""
|
23 |
+
vocab_size: int = 50257 # GPT-2 vocabulary size
|
24 |
+
n_layer: int = 6 # Reduced for efficiency
|
25 |
+
n_head: int = 8 # Number of attention heads
|
26 |
+
n_embd: int = 512 # Embedding dimension
|
27 |
+
block_size: int = 1024 # Context window
|
28 |
+
dropout: float = 0.1 # Dropout rate
|
29 |
+
bias: bool = True # Use bias in linear layers
|
30 |
+
|
31 |
+
# MLA (Multihead Latent Attention) config
|
32 |
+
use_mla: bool = True # Enable MLA
|
33 |
+
mla_kv_heads: int = 4 # Number of key-value heads for MLA
|
34 |
+
mla_q_lora_rank: int = 32 # LoRA rank for query projection
|
35 |
+
mla_kv_lora_rank: int = 16 # LoRA rank for key-value projection
|
36 |
+
|
37 |
+
# MoE (Mixture of Experts) config
|
38 |
+
moe_num_experts: int = 4 # Number of experts
|
39 |
+
moe_top_k: int = 2 # Number of experts per token
|
40 |
+
moe_expert_capacity: float = 1.25
|
41 |
+
moe_aux_loss_coeff: float = 0.01
|
42 |
+
|
43 |
+
# Multi-token prediction
|
44 |
+
multi_token_predict: int = 2 # Predict next 2 tokens for children's stories
|
45 |
+
|
46 |
+
# Quantization
|
47 |
+
use_quantization: bool = False
|
48 |
+
quantization_bits: int = 8
|
49 |
+
|
50 |
+
|
51 |
+
class RoPEPositionalEncoding(nn.Module):
|
52 |
+
"""Rotary Positional Encoding (RoPE) for better position understanding"""
|
53 |
+
|
54 |
+
def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
|
55 |
+
super().__init__()
|
56 |
+
self.dim = dim
|
57 |
+
self.max_seq_len = max_seq_len
|
58 |
+
self.base = base
|
59 |
+
|
60 |
+
# Precompute frequency matrix
|
61 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
62 |
+
self.register_buffer('inv_freq', inv_freq)
|
63 |
+
|
64 |
+
# Cache for efficiency
|
65 |
+
self._cached_cos = None
|
66 |
+
self._cached_sin = None
|
67 |
+
self._cached_seq_len = 0
|
68 |
+
|
69 |
+
def _compute_cos_sin(self, seq_len: int, device: torch.device):
|
70 |
+
"""Compute cosine and sine values for given sequence length"""
|
71 |
+
if seq_len > self._cached_seq_len or self._cached_cos is None:
|
72 |
+
# Create position indices
|
73 |
+
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
|
74 |
+
|
75 |
+
# Compute frequencies
|
76 |
+
freqs = torch.outer(t, self.inv_freq)
|
77 |
+
|
78 |
+
# Create rotation matrix components
|
79 |
+
cos_vals = torch.cos(freqs)
|
80 |
+
sin_vals = torch.sin(freqs)
|
81 |
+
|
82 |
+
# Cache results
|
83 |
+
self._cached_cos = cos_vals
|
84 |
+
self._cached_sin = sin_vals
|
85 |
+
self._cached_seq_len = seq_len
|
86 |
+
|
87 |
+
return self._cached_cos[:seq_len], self._cached_sin[:seq_len]
|
88 |
+
|
89 |
+
def apply_rope(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
|
90 |
+
"""Apply RoPE to input tensor"""
|
91 |
+
batch_size, seq_len, n_heads, head_dim = x.shape
|
92 |
+
|
93 |
+
# Get cos/sin values
|
94 |
+
cos, sin = self._compute_cos_sin(seq_len, x.device)
|
95 |
+
|
96 |
+
# Handle position_ids if provided
|
97 |
+
if position_ids is not None:
|
98 |
+
cos = cos[position_ids]
|
99 |
+
sin = sin[position_ids]
|
100 |
+
|
101 |
+
# Reshape for broadcasting
|
102 |
+
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, head_dim//2]
|
103 |
+
sin = sin.unsqueeze(0).unsqueeze(2)
|
104 |
+
|
105 |
+
# Split x into two halves
|
106 |
+
x1 = x[..., ::2] # Even indices
|
107 |
+
x2 = x[..., 1::2] # Odd indices
|
108 |
+
|
109 |
+
# Apply rotation
|
110 |
+
rotated_x1 = x1 * cos - x2 * sin
|
111 |
+
rotated_x2 = x1 * sin + x2 * cos
|
112 |
+
|
113 |
+
# Recombine
|
114 |
+
rotated_x = torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
|
115 |
+
|
116 |
+
return rotated_x
|
117 |
+
|
118 |
+
|
119 |
+
class MultiheadLatentAttention(nn.Module):
|
120 |
+
"""
|
121 |
+
Multihead Latent Attention (MLA) - DeepSeek's efficient attention mechanism
|
122 |
+
Uses shared key-value heads with LoRA-style projections for efficiency
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(self, config: DeepSeekConfig):
|
126 |
+
super().__init__()
|
127 |
+
self.config = config
|
128 |
+
self.n_head = config.n_head
|
129 |
+
self.n_embd = config.n_embd
|
130 |
+
self.head_dim = config.n_embd // config.n_head
|
131 |
+
self.kv_heads = config.mla_kv_heads
|
132 |
+
self.kv_head_dim = self.head_dim
|
133 |
+
|
134 |
+
# Query projection with LoRA-style decomposition
|
135 |
+
self.q_a_proj = nn.Linear(config.n_embd, config.mla_q_lora_rank, bias=False)
|
136 |
+
self.q_b_proj = nn.Linear(config.mla_q_lora_rank, config.n_embd, bias=False)
|
137 |
+
|
138 |
+
# Key-Value projection with shared heads
|
139 |
+
self.kv_a_proj = nn.Linear(config.n_embd, config.mla_kv_lora_rank, bias=False)
|
140 |
+
self.kv_b_proj = nn.Linear(config.mla_kv_lora_rank, self.kv_heads * self.head_dim * 2, bias=False)
|
141 |
+
|
142 |
+
# Output projection
|
143 |
+
self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
|
144 |
+
|
145 |
+
# RoPE for positional encoding
|
146 |
+
self.rope = RoPEPositionalEncoding(self.head_dim)
|
147 |
+
|
148 |
+
# Dropout
|
149 |
+
self.dropout = nn.Dropout(config.dropout)
|
150 |
+
|
151 |
+
# Scaling factor
|
152 |
+
self.scale = self.head_dim ** -0.5
|
153 |
+
|
154 |
+
def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
|
155 |
+
batch_size, seq_len, _ = x.shape
|
156 |
+
|
157 |
+
# Query projection through LoRA-style decomposition
|
158 |
+
q_latent = self.q_a_proj(x) # [B, T, rank]
|
159 |
+
q = self.q_b_proj(q_latent) # [B, T, n_embd]
|
160 |
+
q = q.view(batch_size, seq_len, self.n_head, self.head_dim)
|
161 |
+
|
162 |
+
# Key-Value projection through shared heads
|
163 |
+
kv_latent = self.kv_a_proj(x) # [B, T, kv_rank]
|
164 |
+
kv = self.kv_b_proj(kv_latent) # [B, T, kv_heads * kv_head_dim * 2]
|
165 |
+
kv = kv.view(batch_size, seq_len, self.kv_heads, self.head_dim, 2)
|
166 |
+
k, v = kv.unbind(dim=-1) # Each: [B, T, kv_heads, kv_head_dim]
|
167 |
+
|
168 |
+
# Apply RoPE to queries and keys before expansion
|
169 |
+
q = self.rope.apply_rope(q)
|
170 |
+
k = self.rope.apply_rope(k)
|
171 |
+
|
172 |
+
# Expand key-value to match query heads
|
173 |
+
k = k.repeat_interleave(self.n_head // self.kv_heads, dim=2)
|
174 |
+
v = v.repeat_interleave(self.n_head // self.kv_heads, dim=2)
|
175 |
+
|
176 |
+
# Transpose for attention computation
|
177 |
+
q = q.transpose(1, 2) # [B, n_head, T, head_dim]
|
178 |
+
k = k.transpose(1, 2) # [B, n_head, T, head_dim]
|
179 |
+
v = v.transpose(1, 2) # [B, n_head, T, head_dim]
|
180 |
+
|
181 |
+
# Compute attention scores
|
182 |
+
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
183 |
+
|
184 |
+
# Apply causal mask
|
185 |
+
if attention_mask is None:
|
186 |
+
causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
|
187 |
+
attn_scores.masked_fill_(causal_mask, float('-inf'))
|
188 |
+
else:
|
189 |
+
attn_scores = attn_scores + attention_mask
|
190 |
+
|
191 |
+
# Apply softmax
|
192 |
+
attn_weights = F.softmax(attn_scores, dim=-1)
|
193 |
+
attn_weights = self.dropout(attn_weights)
|
194 |
+
|
195 |
+
# Apply attention to values
|
196 |
+
out = torch.matmul(attn_weights, v) # [B, n_head, T, head_dim]
|
197 |
+
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd)
|
198 |
+
|
199 |
+
# Output projection
|
200 |
+
out = self.out_proj(out)
|
201 |
+
|
202 |
+
return out
|
203 |
+
|
204 |
+
|
205 |
+
class MoEExpert(nn.Module):
|
206 |
+
"""Expert network for Mixture of Experts"""
|
207 |
+
|
208 |
+
def __init__(self, config: DeepSeekConfig):
|
209 |
+
super().__init__()
|
210 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
211 |
+
self.gelu = nn.GELU()
|
212 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
213 |
+
self.dropout = nn.Dropout(config.dropout)
|
214 |
+
|
215 |
+
def forward(self, x: torch.Tensor):
|
216 |
+
return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
|
217 |
+
|
218 |
+
|
219 |
+
class MixtureOfExperts(nn.Module):
|
220 |
+
"""Mixture of Experts (MoE) for increased model capacity"""
|
221 |
+
|
222 |
+
def __init__(self, config: DeepSeekConfig):
|
223 |
+
super().__init__()
|
224 |
+
self.config = config
|
225 |
+
self.num_experts = config.moe_num_experts
|
226 |
+
self.top_k = config.moe_top_k
|
227 |
+
self.expert_capacity = config.moe_expert_capacity
|
228 |
+
|
229 |
+
# Router
|
230 |
+
self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
|
231 |
+
|
232 |
+
# Experts
|
233 |
+
self.experts = nn.ModuleList([MoEExpert(config) for _ in range(config.moe_num_experts)])
|
234 |
+
|
235 |
+
# Layer norm
|
236 |
+
self.ln = nn.LayerNorm(config.n_embd, bias=config.bias)
|
237 |
+
|
238 |
+
def forward(self, x: torch.Tensor):
|
239 |
+
batch_size, seq_len, hidden_dim = x.shape
|
240 |
+
|
241 |
+
# Get router logits
|
242 |
+
router_logits = self.router(x) # [B, T, num_experts]
|
243 |
+
|
244 |
+
# Get top-k experts
|
245 |
+
top_k_logits, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)
|
246 |
+
top_k_probs = F.softmax(top_k_logits, dim=-1)
|
247 |
+
|
248 |
+
# Initialize output
|
249 |
+
output = torch.zeros_like(x)
|
250 |
+
|
251 |
+
# Process each expert
|
252 |
+
for expert_idx in range(self.num_experts):
|
253 |
+
# Find tokens that use this expert
|
254 |
+
expert_mask = (top_k_indices == expert_idx).any(dim=-1) # [B, T]
|
255 |
+
|
256 |
+
if expert_mask.any():
|
257 |
+
# Get tokens for this expert
|
258 |
+
expert_tokens = x[expert_mask] # [num_tokens, hidden_dim]
|
259 |
+
|
260 |
+
# Get routing weights for this expert
|
261 |
+
expert_weights = top_k_probs[expert_mask] # [num_tokens, top_k]
|
262 |
+
expert_weights = expert_weights[top_k_indices[expert_mask] == expert_idx] # [num_tokens]
|
263 |
+
|
264 |
+
# Apply expert
|
265 |
+
expert_output = self.experts[expert_idx](expert_tokens) # [num_tokens, hidden_dim]
|
266 |
+
|
267 |
+
# Weight the output
|
268 |
+
weighted_output = expert_output * expert_weights.unsqueeze(-1)
|
269 |
+
|
270 |
+
# Add to output
|
271 |
+
output[expert_mask] += weighted_output
|
272 |
+
|
273 |
+
# Apply layer norm
|
274 |
+
output = self.ln(output)
|
275 |
+
|
276 |
+
return output, router_logits
|
277 |
+
|
278 |
+
def _compute_aux_loss(self, router_logits: torch.Tensor):
|
279 |
+
"""Compute auxiliary loss for load balancing"""
|
280 |
+
router_probs = F.softmax(router_logits, dim=-1)
|
281 |
+
mean_expert_usage = router_probs.mean(dim=[0, 1]) # [num_experts]
|
282 |
+
target_usage = 1.0 / self.num_experts
|
283 |
+
|
284 |
+
aux_loss = torch.sum((mean_expert_usage - target_usage) ** 2)
|
285 |
+
return aux_loss
|
286 |
+
|
287 |
+
|
288 |
+
class DeepSeekBlock(nn.Module):
|
289 |
+
"""DeepSeek transformer block with MLA and MoE"""
|
290 |
+
|
291 |
+
def __init__(self, config: DeepSeekConfig):
|
292 |
+
super().__init__()
|
293 |
+
self.config = config
|
294 |
+
|
295 |
+
# Layer norms
|
296 |
+
self.ln1 = nn.LayerNorm(config.n_embd, bias=config.bias)
|
297 |
+
self.ln2 = nn.LayerNorm(config.n_embd, bias=config.bias)
|
298 |
+
|
299 |
+
# Attention - use MLA if enabled, otherwise use standard attention
|
300 |
+
if config.use_mla:
|
301 |
+
self.attn = MultiheadLatentAttention(config)
|
302 |
+
else:
|
303 |
+
# Standard multihead attention as fallback
|
304 |
+
self.attn = nn.MultiheadAttention(
|
305 |
+
config.n_embd,
|
306 |
+
config.n_head,
|
307 |
+
dropout=config.dropout,
|
308 |
+
bias=config.bias,
|
309 |
+
batch_first=True
|
310 |
+
)
|
311 |
+
|
312 |
+
# MoE
|
313 |
+
self.moe = MixtureOfExperts(config)
|
314 |
+
|
315 |
+
def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
|
316 |
+
# Attention with residual connection
|
317 |
+
if self.config.use_mla:
|
318 |
+
x = x + self.attn(self.ln1(x), attention_mask)
|
319 |
+
else:
|
320 |
+
attn_out, _ = self.attn(self.ln1(x), self.ln1(x), self.ln1(x), attn_mask=attention_mask)
|
321 |
+
x = x + attn_out
|
322 |
+
|
323 |
+
# MoE with residual connection
|
324 |
+
moe_output, router_logits = self.moe(self.ln2(x))
|
325 |
+
x = x + moe_output
|
326 |
+
|
327 |
+
return x, router_logits
|
328 |
+
|
329 |
+
|
330 |
+
class MultiTokenPredictor(nn.Module):
|
331 |
+
"""Multi-token prediction head for improved training efficiency"""
|
332 |
+
|
333 |
+
def __init__(self, config: DeepSeekConfig):
|
334 |
+
super().__init__()
|
335 |
+
self.config = config
|
336 |
+
self.num_tokens = config.multi_token_predict
|
337 |
+
|
338 |
+
# Separate prediction heads for each future token
|
339 |
+
self.predictors = nn.ModuleList([
|
340 |
+
nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
341 |
+
for _ in range(config.multi_token_predict)
|
342 |
+
])
|
343 |
+
|
344 |
+
def forward(self, hidden_states: torch.Tensor):
|
345 |
+
"""Forward pass for multi-token prediction"""
|
346 |
+
batch_size, seq_len, hidden_dim = hidden_states.shape
|
347 |
+
|
348 |
+
# Predict multiple future tokens
|
349 |
+
logits = []
|
350 |
+
for i, predictor in enumerate(self.predictors):
|
351 |
+
# Use hidden states shifted by i+1 positions
|
352 |
+
if i + 1 < seq_len:
|
353 |
+
token_logits = predictor(hidden_states[:, i+1:i+2, :]) # [B, 1, vocab_size]
|
354 |
+
logits.append(token_logits)
|
355 |
+
else:
|
356 |
+
# Pad with zeros if not enough sequence length
|
357 |
+
token_logits = torch.zeros(batch_size, 1, self.config.vocab_size,
|
358 |
+
device=hidden_states.device)
|
359 |
+
logits.append(token_logits)
|
360 |
+
|
361 |
+
return torch.cat(logits, dim=1) # [B, num_tokens, vocab_size]
|
362 |
+
|
363 |
+
|
364 |
+
class DeepSeek(nn.Module):
|
365 |
+
"""DeepSeek model for children's story generation"""
|
366 |
+
|
367 |
+
def __init__(self, config: DeepSeekConfig):
|
368 |
+
super().__init__()
|
369 |
+
assert isinstance(config, DeepSeekConfig), "config must be an instance of DeepSeekConfig"
|
370 |
+
self.config = config
|
371 |
+
|
372 |
+
# Token and position embeddings
|
373 |
+
self.transformer = nn.ModuleDict(dict(
|
374 |
+
wte=nn.Embedding(config.vocab_size, config.n_embd),
|
375 |
+
wpe=nn.Embedding(config.block_size, config.n_embd),
|
376 |
+
drop=nn.Dropout(config.dropout),
|
377 |
+
h=nn.ModuleList([DeepSeekBlock(config) for _ in range(config.n_layer)]),
|
378 |
+
ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
|
379 |
+
))
|
380 |
+
|
381 |
+
# Language model head
|
382 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
383 |
+
|
384 |
+
# Multi-token predictor
|
385 |
+
if config.multi_token_predict > 0:
|
386 |
+
self.multi_token_predictor = MultiTokenPredictor(config)
|
387 |
+
else:
|
388 |
+
self.multi_token_predictor = None
|
389 |
+
|
390 |
+
# Weight tying
|
391 |
+
self.transformer.wte.weight = self.lm_head.weight
|
392 |
+
|
393 |
+
# Initialize weights
|
394 |
+
self.apply(self._init_weights)
|
395 |
+
|
396 |
+
# Setup quantization if enabled
|
397 |
+
if config.use_quantization:
|
398 |
+
self._setup_quantization()
|
399 |
+
|
400 |
+
def _init_weights(self, module):
|
401 |
+
"""Initialize model weights"""
|
402 |
+
if isinstance(module, nn.Linear):
|
403 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
404 |
+
if module.bias is not None:
|
405 |
+
nn.init.zeros_(module.bias)
|
406 |
+
elif isinstance(module, nn.Embedding):
|
407 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
408 |
+
elif isinstance(module, nn.LayerNorm):
|
409 |
+
nn.init.ones_(module.weight)
|
410 |
+
if module.bias is not None:
|
411 |
+
nn.init.zeros_(module.bias)
|
412 |
+
|
413 |
+
def _setup_quantization(self):
|
414 |
+
"""Setup quantization for the model"""
|
415 |
+
# This would implement quantization logic
|
416 |
+
# For now, just a placeholder
|
417 |
+
pass
|
418 |
+
|
419 |
+
def forward(self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None):
|
420 |
+
"""Forward pass"""
|
421 |
+
device = input_ids.device
|
422 |
+
batch_size, seq_len = input_ids.size()
|
423 |
+
assert seq_len <= self.config.block_size
|
424 |
+
|
425 |
+
# Position indices
|
426 |
+
pos = torch.arange(0, seq_len, dtype=torch.long, device=device)
|
427 |
+
|
428 |
+
# Token and position embeddings
|
429 |
+
tok_emb = self.transformer.wte(input_ids)
|
430 |
+
pos_emb = self.transformer.wpe(pos)
|
431 |
+
|
432 |
+
x = self.transformer.drop(tok_emb + pos_emb)
|
433 |
+
|
434 |
+
# Forward through transformer blocks
|
435 |
+
router_logits_list = []
|
436 |
+
for block in self.transformer.h:
|
437 |
+
x, router_logits = block(x)
|
438 |
+
router_logits_list.append(router_logits)
|
439 |
+
|
440 |
+
# Final layer norm
|
441 |
+
x = self.transformer.ln_f(x)
|
442 |
+
|
443 |
+
if targets is not None:
|
444 |
+
# Training mode
|
445 |
+
if self.multi_token_predictor is not None:
|
446 |
+
# Multi-token prediction
|
447 |
+
multi_logits = self.multi_token_predictor(x)
|
448 |
+
loss = self._compute_multi_token_loss(multi_logits, targets)
|
449 |
+
else:
|
450 |
+
# Standard single-token prediction
|
451 |
+
logits = self.lm_head(x)
|
452 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
|
453 |
+
targets.view(-1), ignore_index=-1)
|
454 |
+
|
455 |
+
# Add MoE auxiliary loss
|
456 |
+
if router_logits_list:
|
457 |
+
aux_loss = sum(self.transformer.h[i].moe._compute_aux_loss(router_logits_list[i])
|
458 |
+
for i in range(len(router_logits_list)))
|
459 |
+
loss += self.config.moe_aux_loss_coeff * aux_loss
|
460 |
+
|
461 |
+
return logits if self.multi_token_predictor is None else multi_logits, loss
|
462 |
+
else:
|
463 |
+
# Inference mode
|
464 |
+
logits = self.lm_head(x[:, [-1], :])
|
465 |
+
return logits, None
|
466 |
+
|
467 |
+
def _compute_multi_token_loss(self, logits: torch.Tensor, targets: torch.Tensor):
|
468 |
+
"""Compute loss for multi-token prediction"""
|
469 |
+
batch_size, num_tokens, vocab_size = logits.shape
|
470 |
+
|
471 |
+
# Reshape for loss computation
|
472 |
+
logits_flat = logits.view(-1, vocab_size)
|
473 |
+
targets_flat = targets.view(-1)
|
474 |
+
|
475 |
+
# Compute cross-entropy loss
|
476 |
+
loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=-1)
|
477 |
+
|
478 |
+
return loss
|
479 |
+
|
480 |
+
@torch.no_grad()
|
481 |
+
def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
|
482 |
+
temperature: float = 1.0, top_k: Optional[int] = None):
|
483 |
+
"""Generate text using the model"""
|
484 |
+
for _ in range(max_new_tokens):
|
485 |
+
# Ensure input doesn't exceed block size
|
486 |
+
idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
|
487 |
+
|
488 |
+
# Forward pass
|
489 |
+
logits, _ = self(idx_cond)
|
490 |
+
logits = logits[:, -1, :] / temperature
|
491 |
+
|
492 |
+
# Apply top-k filtering
|
493 |
+
if top_k is not None:
|
494 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
495 |
+
logits[logits < v[:, [-1]]] = -float('Inf')
|
496 |
+
|
497 |
+
# Sample next token
|
498 |
+
probs = F.softmax(logits, dim=-1)
|
499 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
500 |
+
input_ids = torch.cat((input_ids, idx_next), dim=1)
|
501 |
+
|
502 |
+
return input_ids
|
503 |
+
|
504 |
+
@classmethod
|
505 |
+
def from_pretrained(cls, model_type: str, override_args: Optional[dict] = None):
|
506 |
+
"""Load a pretrained model"""
|
507 |
+
# This would implement loading from pretrained weights
|
508 |
+
# For now, return a default configuration
|
509 |
+
config = DeepSeekConfig()
|
510 |
+
if override_args:
|
511 |
+
for key, value in override_args.items():
|
512 |
+
setattr(config, key, value)
|
513 |
+
return cls(config)
|
src/run_training.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
DeepSeek Children's Stories Training Script
|
3 |
+
Main training script for the DeepSeek model on children's stories
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import argparse
|
9 |
+
import torch
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
+
# Add the src directory to Python path
|
14 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
15 |
+
|
16 |
+
from model.deepseek import DeepSeek, DeepSeekConfig
|
17 |
+
from training.trainer import DeepSeekTrainer, create_deepseek_trainer
|
18 |
+
from data.data_processor import DeepSeekDataProcessor
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class TrainingConfig:
|
23 |
+
"""Configuration for DeepSeek training"""
|
24 |
+
# Model configuration
|
25 |
+
vocab_size: int = 50257
|
26 |
+
n_layer: int = 6
|
27 |
+
n_head: int = 8
|
28 |
+
n_embd: int = 512
|
29 |
+
block_size: int = 1024
|
30 |
+
dropout: float = 0.1
|
31 |
+
bias: bool = True
|
32 |
+
|
33 |
+
# MLA configuration
|
34 |
+
use_mla: bool = True
|
35 |
+
mla_kv_heads: int = 4
|
36 |
+
mla_q_lora_rank: int = 32
|
37 |
+
mla_kv_lora_rank: int = 16
|
38 |
+
|
39 |
+
# MoE configuration
|
40 |
+
moe_num_experts: int = 4
|
41 |
+
moe_top_k: int = 2
|
42 |
+
moe_expert_capacity: float = 1.25
|
43 |
+
moe_aux_loss_coeff: float = 0.01
|
44 |
+
|
45 |
+
# Multi-token prediction
|
46 |
+
multi_token_predict: int = 0 # Predict next 2 tokens for efficiency
|
47 |
+
|
48 |
+
# Quantization
|
49 |
+
use_quantization: bool = False
|
50 |
+
quantization_bits: int = 8
|
51 |
+
|
52 |
+
# Training configuration
|
53 |
+
batch_size: int = 12
|
54 |
+
max_iters: int = 20000
|
55 |
+
eval_interval: int = 1000
|
56 |
+
eval_iters: int = 200
|
57 |
+
learning_rate: float = 6e-4
|
58 |
+
weight_decay: float = 0.1
|
59 |
+
warmup_iters: int = 2000
|
60 |
+
lr_decay_iters: int = 20000
|
61 |
+
min_lr: float = 6e-5
|
62 |
+
|
63 |
+
# System configuration
|
64 |
+
checkpoint_dir: str = 'checkpoints'
|
65 |
+
use_mixed_precision: bool = True
|
66 |
+
compile_model: bool = True
|
67 |
+
|
68 |
+
# Data configuration
|
69 |
+
dataset_name: str = "ajibawa-2023/Children-Stories-Collection"
|
70 |
+
data_dir: str = 'src/data'
|
71 |
+
|
72 |
+
|
73 |
+
def setup_environment():
|
74 |
+
"""Setup the training environment"""
|
75 |
+
print("Setting up DeepSeek Children's Stories training environment...")
|
76 |
+
|
77 |
+
# Check CUDA availability
|
78 |
+
if torch.cuda.is_available():
|
79 |
+
print(f"CUDA available: {torch.cuda.get_device_name(0)}")
|
80 |
+
print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
81 |
+
else:
|
82 |
+
print("CUDA not available, using CPU")
|
83 |
+
|
84 |
+
# Create necessary directories
|
85 |
+
os.makedirs('checkpoints', exist_ok=True)
|
86 |
+
os.makedirs('lora_checkpoints', exist_ok=True)
|
87 |
+
os.makedirs('src/data', exist_ok=True)
|
88 |
+
|
89 |
+
print("Environment setup complete!")
|
90 |
+
|
91 |
+
|
92 |
+
def prepare_data():
|
93 |
+
"""Prepare the dataset for training"""
|
94 |
+
print("Preparing dataset...")
|
95 |
+
|
96 |
+
processor = DeepSeekDataProcessor()
|
97 |
+
data_files = processor.prepare_dataset()
|
98 |
+
|
99 |
+
print("Dataset preparation complete!")
|
100 |
+
return data_files
|
101 |
+
|
102 |
+
|
103 |
+
def create_model(config: TrainingConfig) -> DeepSeek:
|
104 |
+
"""Create the DeepSeek model"""
|
105 |
+
print("Creating DeepSeek model...")
|
106 |
+
|
107 |
+
# Create model configuration
|
108 |
+
model_config = DeepSeekConfig(
|
109 |
+
vocab_size=config.vocab_size,
|
110 |
+
n_layer=config.n_layer,
|
111 |
+
n_head=config.n_head,
|
112 |
+
n_embd=config.n_embd,
|
113 |
+
block_size=config.block_size,
|
114 |
+
dropout=config.dropout,
|
115 |
+
bias=config.bias,
|
116 |
+
use_mla=config.use_mla,
|
117 |
+
mla_kv_heads=config.mla_kv_heads,
|
118 |
+
mla_q_lora_rank=config.mla_q_lora_rank,
|
119 |
+
mla_kv_lora_rank=config.mla_kv_lora_rank,
|
120 |
+
moe_num_experts=config.moe_num_experts,
|
121 |
+
moe_top_k=config.moe_top_k,
|
122 |
+
moe_expert_capacity=config.moe_expert_capacity,
|
123 |
+
moe_aux_loss_coeff=config.moe_aux_loss_coeff,
|
124 |
+
multi_token_predict=config.multi_token_predict,
|
125 |
+
use_quantization=config.use_quantization,
|
126 |
+
quantization_bits=config.quantization_bits
|
127 |
+
)
|
128 |
+
|
129 |
+
# Create model
|
130 |
+
model = DeepSeek(model_config)
|
131 |
+
|
132 |
+
# Compile model if requested
|
133 |
+
if config.compile_model and hasattr(torch, 'compile'):
|
134 |
+
print("Compiling model with torch.compile...")
|
135 |
+
model = torch.compile(model)
|
136 |
+
|
137 |
+
# Print model info
|
138 |
+
total_params = sum(p.numel() for p in model.parameters())
|
139 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
140 |
+
|
141 |
+
print(f"Model created successfully!")
|
142 |
+
print(f"Total parameters: {total_params:,}")
|
143 |
+
print(f"Trainable parameters: {trainable_params:,}")
|
144 |
+
print(f"Model configuration:")
|
145 |
+
print(f" - Layers: {config.n_layer}")
|
146 |
+
print(f" - Heads: {config.n_head}")
|
147 |
+
print(f" - Embedding dim: {config.n_embd}")
|
148 |
+
print(f" - MLA enabled: {config.use_mla}")
|
149 |
+
print(f" - MLA KV heads: {config.mla_kv_heads}")
|
150 |
+
print(f" - MoE experts: {config.moe_num_experts}")
|
151 |
+
print(f" - Multi-token prediction: {config.multi_token_predict}")
|
152 |
+
|
153 |
+
return model
|
154 |
+
|
155 |
+
|
156 |
+
def train_model(model: DeepSeek, config: TrainingConfig):
|
157 |
+
"""Train the DeepSeek model"""
|
158 |
+
print(f"[+] Starting training with config:")
|
159 |
+
print(f" - Model size: {sum(p.numel() for p in model.parameters()):,} parameters")
|
160 |
+
print(f" - Multi-token prediction: {config.multi_token_predict}")
|
161 |
+
print(f" - MoE experts: {config.moe_num_experts}")
|
162 |
+
print(f" - MLA enabled: {config.use_mla}")
|
163 |
+
|
164 |
+
# Setup device
|
165 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
166 |
+
model = model.to(device)
|
167 |
+
|
168 |
+
# Create optimizer
|
169 |
+
optimizer = torch.optim.AdamW(
|
170 |
+
model.parameters(),
|
171 |
+
lr=config.learning_rate,
|
172 |
+
weight_decay=config.weight_decay,
|
173 |
+
betas=(0.9, 0.95)
|
174 |
+
)
|
175 |
+
|
176 |
+
# Initialize trainer with individual parameters
|
177 |
+
trainer = DeepSeekTrainer(
|
178 |
+
model=model,
|
179 |
+
optimizer=optimizer,
|
180 |
+
device=device,
|
181 |
+
batch_size=config.batch_size,
|
182 |
+
max_iters=config.max_iters,
|
183 |
+
eval_interval=config.eval_interval,
|
184 |
+
eval_iters=config.eval_iters,
|
185 |
+
learning_rate=config.learning_rate,
|
186 |
+
weight_decay=config.weight_decay,
|
187 |
+
warmup_iters=config.warmup_iters,
|
188 |
+
lr_decay_iters=config.lr_decay_iters,
|
189 |
+
min_lr=config.min_lr,
|
190 |
+
checkpoint_dir=config.checkpoint_dir,
|
191 |
+
use_mixed_precision=config.use_mixed_precision
|
192 |
+
)
|
193 |
+
|
194 |
+
try:
|
195 |
+
# Start training
|
196 |
+
trainer.train()
|
197 |
+
print("[+] Training completed successfully!")
|
198 |
+
|
199 |
+
# Save final model
|
200 |
+
final_model_path = os.path.join(config.checkpoint_dir, "final_model.pt")
|
201 |
+
torch.save({
|
202 |
+
'model_state_dict': model.state_dict(),
|
203 |
+
'config': config,
|
204 |
+
'optimizer_state_dict': trainer.optimizer.state_dict(),
|
205 |
+
}, final_model_path)
|
206 |
+
print(f"[+] Final model saved to {final_model_path}")
|
207 |
+
|
208 |
+
except Exception as e:
|
209 |
+
print(f"[-] Training failed: {e}")
|
210 |
+
import traceback
|
211 |
+
traceback.print_exc()
|
212 |
+
raise
|
213 |
+
|
214 |
+
|
215 |
+
def main():
|
216 |
+
"""Main training function"""
|
217 |
+
parser = argparse.ArgumentParser(description='Train DeepSeek model on children\'s stories')
|
218 |
+
|
219 |
+
# Model configuration
|
220 |
+
parser.add_argument('--n-layer', type=int, default=6, help='Number of layers')
|
221 |
+
parser.add_argument('--n-head', type=int, default=8, help='Number of attention heads')
|
222 |
+
parser.add_argument('--n-embd', type=int, default=512, help='Embedding dimension')
|
223 |
+
parser.add_argument('--block-size', type=int, default=1024, help='Context window size')
|
224 |
+
|
225 |
+
# Training configuration
|
226 |
+
parser.add_argument('--batch-size', type=int, default=12, help='Batch size')
|
227 |
+
parser.add_argument('--max-iters', type=int, default=20000, help='Maximum iterations')
|
228 |
+
parser.add_argument('--learning-rate', type=float, default=6e-4, help='Learning rate')
|
229 |
+
parser.add_argument('--eval-interval', type=int, default=1000, help='Evaluation interval')
|
230 |
+
parser.add_argument('--eval-iters', type=int, default=200, help='Number of evaluation iterations')
|
231 |
+
parser.add_argument('--weight-decay', type=float, default=0.1, help='Weight decay')
|
232 |
+
parser.add_argument('--warmup-iters', type=int, default=2000, help='Warmup iterations')
|
233 |
+
parser.add_argument('--lr-decay-iters', type=int, default=20000, help='Learning rate decay iterations')
|
234 |
+
parser.add_argument('--min-lr', type=float, default=6e-5, help='Minimum learning rate')
|
235 |
+
|
236 |
+
# Advanced features
|
237 |
+
parser.add_argument('--moe-experts', type=int, default=4, help='Number of MoE experts')
|
238 |
+
parser.add_argument('--multi-token', type=int, default=2, help='Multi-token prediction')
|
239 |
+
parser.add_argument('--no-compile', action='store_true', help='Disable model compilation')
|
240 |
+
parser.add_argument('--no-mixed-precision', action='store_true', help='Disable mixed precision')
|
241 |
+
|
242 |
+
# Resume training
|
243 |
+
parser.add_argument('--resume', type=str, help='Resume from checkpoint')
|
244 |
+
|
245 |
+
args = parser.parse_args()
|
246 |
+
|
247 |
+
# Create configuration
|
248 |
+
config = TrainingConfig(
|
249 |
+
n_layer=args.n_layer,
|
250 |
+
n_head=args.n_head,
|
251 |
+
n_embd=args.n_embd,
|
252 |
+
block_size=args.block_size,
|
253 |
+
batch_size=args.batch_size,
|
254 |
+
max_iters=args.max_iters,
|
255 |
+
learning_rate=args.learning_rate,
|
256 |
+
eval_interval=args.eval_interval,
|
257 |
+
eval_iters=args.eval_iters,
|
258 |
+
weight_decay=args.weight_decay,
|
259 |
+
warmup_iters=args.warmup_iters,
|
260 |
+
lr_decay_iters=args.lr_decay_iters,
|
261 |
+
min_lr=args.min_lr,
|
262 |
+
moe_num_experts=args.moe_experts,
|
263 |
+
multi_token_predict=args.multi_token,
|
264 |
+
compile_model=not args.no_compile,
|
265 |
+
use_mixed_precision=not args.no_mixed_precision
|
266 |
+
)
|
267 |
+
|
268 |
+
print("DeepSeek Children's Stories Training")
|
269 |
+
print("=" * 50)
|
270 |
+
print(f"Configuration:")
|
271 |
+
print(f" - Model: {config.n_layer}L/{config.n_head}H/{config.n_embd}D")
|
272 |
+
print(f" - MoE: {config.moe_num_experts} experts")
|
273 |
+
print(f" - Multi-token: {config.multi_token_predict}")
|
274 |
+
print(f" - Batch size: {config.batch_size}")
|
275 |
+
print(f" - Max iterations: {config.max_iters}")
|
276 |
+
print(f" - Learning rate: {config.learning_rate}")
|
277 |
+
print(f" - Weight decay: {config.weight_decay}")
|
278 |
+
print(f" - Warmup iterations: {config.warmup_iters}")
|
279 |
+
print(f" - LR decay iterations: {config.lr_decay_iters}")
|
280 |
+
print(f" - Min learning rate: {config.min_lr}")
|
281 |
+
print("=" * 50)
|
282 |
+
|
283 |
+
# Setup environment
|
284 |
+
setup_environment()
|
285 |
+
|
286 |
+
# Prepare data
|
287 |
+
data_files = prepare_data()
|
288 |
+
|
289 |
+
# Create model
|
290 |
+
model = create_model(config)
|
291 |
+
|
292 |
+
# Resume from checkpoint if specified
|
293 |
+
if args.resume:
|
294 |
+
print(f"Resuming from checkpoint: {args.resume}")
|
295 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
296 |
+
model.load_state_dict(checkpoint['model'])
|
297 |
+
print("Checkpoint loaded successfully!")
|
298 |
+
|
299 |
+
# Train model
|
300 |
+
train_model(model, config)
|
301 |
+
|
302 |
+
print("Training completed successfully!")
|
303 |
+
print("Best model saved to: checkpoints/best_model.pt")
|
304 |
+
|
305 |
+
|
306 |
+
if __name__ == "__main__":
|
307 |
+
main()
|
src/training/__pycache__/trainer.cpython-310.pyc
ADDED
Binary file (10.8 kB). View file
|
|
src/training/trainer.py
ADDED
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
DeepSeek Trainer for Children's Stories
|
3 |
+
Advanced training with MLA, MoE, and multi-token prediction
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
from tqdm.auto import tqdm
|
9 |
+
from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import os
|
12 |
+
import datetime
|
13 |
+
import time
|
14 |
+
import shutil
|
15 |
+
import psutil
|
16 |
+
import math
|
17 |
+
import gc
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
from torch.utils.data.distributed import DistributedSampler
|
21 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
22 |
+
from torch.distributed import init_process_group, destroy_process_group
|
23 |
+
from typing import Dict, List, Optional, Tuple
|
24 |
+
|
25 |
+
class DeepSeekTrainer:
|
26 |
+
def __init__(self, model, optimizer, device, batch_size, max_iters, eval_interval,
|
27 |
+
eval_iters, learning_rate, weight_decay, warmup_iters, lr_decay_iters,
|
28 |
+
min_lr, checkpoint_dir='checkpoints', use_mixed_precision=True):
|
29 |
+
self.model = model
|
30 |
+
self.optimizer = optimizer
|
31 |
+
self.device = device
|
32 |
+
self.batch_size = batch_size
|
33 |
+
self.max_iters = max_iters
|
34 |
+
self.eval_interval = eval_interval
|
35 |
+
self.eval_iters = eval_iters
|
36 |
+
self.learning_rate = learning_rate
|
37 |
+
self.weight_decay = weight_decay
|
38 |
+
self.warmup_iters = warmup_iters
|
39 |
+
self.lr_decay_iters = lr_decay_iters
|
40 |
+
self.min_lr = min_lr
|
41 |
+
self.checkpoint_dir = checkpoint_dir
|
42 |
+
self.use_mixed_precision = use_mixed_precision
|
43 |
+
self.best_loss = float('inf')
|
44 |
+
|
45 |
+
# Training state
|
46 |
+
self.current_iter = 0
|
47 |
+
self.train_losses = []
|
48 |
+
self.val_losses = []
|
49 |
+
self.learning_rates = []
|
50 |
+
|
51 |
+
# Create checkpoint directory if it doesn't exist
|
52 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
53 |
+
|
54 |
+
# Initialize gradient scaler for mixed precision training
|
55 |
+
if use_mixed_precision and device == 'cuda':
|
56 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
57 |
+
else:
|
58 |
+
self.scaler = None
|
59 |
+
|
60 |
+
# Initialize training metrics
|
61 |
+
self.metrics = {
|
62 |
+
'train_loss': [],
|
63 |
+
'val_loss': [],
|
64 |
+
'learning_rates': [],
|
65 |
+
'grad_norm': [],
|
66 |
+
'memory_usage': [],
|
67 |
+
'moe_aux_loss': [],
|
68 |
+
'multi_token_loss': []
|
69 |
+
}
|
70 |
+
|
71 |
+
# Load data
|
72 |
+
self.data = self.load_data()
|
73 |
+
self.n = len(self.data)
|
74 |
+
|
75 |
+
def load_data(self):
|
76 |
+
"""Load the training data"""
|
77 |
+
try:
|
78 |
+
data_file = os.path.join('src', 'data', 'train.bin')
|
79 |
+
if not os.path.exists(data_file):
|
80 |
+
raise FileNotFoundError(f"Training data file not found at {data_file}")
|
81 |
+
|
82 |
+
# Load data as numpy array first
|
83 |
+
data = np.memmap(data_file, dtype=np.uint16, mode='r')
|
84 |
+
# Convert to tensor
|
85 |
+
data = torch.from_numpy(data.copy()) # Make a copy to ensure it's writable
|
86 |
+
return data
|
87 |
+
except Exception as e:
|
88 |
+
print(f"Error loading data: {str(e)}")
|
89 |
+
raise
|
90 |
+
|
91 |
+
def get_batch(self, split):
|
92 |
+
"""Get a batch of data"""
|
93 |
+
try:
|
94 |
+
# Generate random indices
|
95 |
+
ix = torch.randint(len(self.data) - self.model.config.block_size, (self.batch_size,))
|
96 |
+
|
97 |
+
# Get input sequences
|
98 |
+
x = torch.stack([self.data[i:i+self.model.config.block_size].long() for i in ix])
|
99 |
+
# Get target sequences (shifted by 1)
|
100 |
+
y = torch.stack([self.data[i+1:i+1+self.model.config.block_size].long() for i in ix])
|
101 |
+
|
102 |
+
# Move to device
|
103 |
+
x, y = x.to(self.device), y.to(self.device)
|
104 |
+
return x, y
|
105 |
+
except Exception as e:
|
106 |
+
print(f"Error in get_batch: {str(e)}")
|
107 |
+
raise
|
108 |
+
|
109 |
+
def get_lr(self, it):
|
110 |
+
"""Get learning rate for current iteration"""
|
111 |
+
# 1) linear warmup for warmup_iters steps
|
112 |
+
if it < self.warmup_iters:
|
113 |
+
return self.learning_rate * it / self.warmup_iters
|
114 |
+
# 2) if it > lr_decay_iters, return min learning rate
|
115 |
+
if it > self.lr_decay_iters:
|
116 |
+
return self.min_lr
|
117 |
+
# 3) in between, use cosine decay down to min learning rate
|
118 |
+
decay_ratio = (it - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters)
|
119 |
+
assert 0 <= decay_ratio <= 1
|
120 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
|
121 |
+
return self.min_lr + coeff * (self.learning_rate - self.min_lr)
|
122 |
+
|
123 |
+
def estimate_loss(self):
|
124 |
+
"""Estimate loss on validation set"""
|
125 |
+
out = {}
|
126 |
+
self.model.eval()
|
127 |
+
for split in ['train', 'val']:
|
128 |
+
losses = torch.zeros(self.eval_iters)
|
129 |
+
for k in range(self.eval_iters):
|
130 |
+
try:
|
131 |
+
X, Y = self.get_batch(split)
|
132 |
+
with torch.no_grad():
|
133 |
+
if self.scaler is not None:
|
134 |
+
with torch.cuda.amp.autocast():
|
135 |
+
logits, loss = self.model(X, Y)
|
136 |
+
else:
|
137 |
+
logits, loss = self.model(X, Y)
|
138 |
+
losses[k] = loss.item()
|
139 |
+
except Exception as e:
|
140 |
+
print(f"Error during evaluation: {str(e)}")
|
141 |
+
continue
|
142 |
+
out[split] = losses.mean()
|
143 |
+
self.model.train()
|
144 |
+
return out
|
145 |
+
|
146 |
+
def check_disk_space(self, required_space_mb=1000):
|
147 |
+
"""Check if there's enough disk space for saving the model"""
|
148 |
+
try:
|
149 |
+
# Get disk usage statistics
|
150 |
+
disk_usage = psutil.disk_usage('/')
|
151 |
+
free_space_mb = disk_usage.free / (1024 * 1024) # Convert to MB
|
152 |
+
|
153 |
+
if free_space_mb < required_space_mb:
|
154 |
+
print(f"Warning: Low disk space. Only {free_space_mb:.2f}MB free, {required_space_mb}MB required")
|
155 |
+
return False
|
156 |
+
return True
|
157 |
+
except Exception as e:
|
158 |
+
print(f"Warning: Could not check disk space: {e}")
|
159 |
+
return True # Continue anyway if we can't check
|
160 |
+
|
161 |
+
def save_checkpoint(self, iter_num, loss, is_best=False):
|
162 |
+
"""Save model checkpoint"""
|
163 |
+
try:
|
164 |
+
checkpoint = {
|
165 |
+
'model': self.model.state_dict(),
|
166 |
+
'optimizer': self.optimizer.state_dict(),
|
167 |
+
'iter_num': iter_num,
|
168 |
+
'loss': loss,
|
169 |
+
'config': self.model.config,
|
170 |
+
'train_losses': self.train_losses,
|
171 |
+
'val_losses': self.val_losses,
|
172 |
+
'learning_rates': self.learning_rates,
|
173 |
+
'metrics': self.metrics,
|
174 |
+
'best_loss': self.best_loss
|
175 |
+
}
|
176 |
+
checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_{iter_num}.pt')
|
177 |
+
torch.save(checkpoint, checkpoint_path)
|
178 |
+
|
179 |
+
if is_best:
|
180 |
+
best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
|
181 |
+
torch.save(checkpoint, best_path)
|
182 |
+
print(f"Saved best model with loss {loss:.4f}")
|
183 |
+
|
184 |
+
print(f"Saved checkpoint to {checkpoint_path}")
|
185 |
+
return True
|
186 |
+
except Exception as e:
|
187 |
+
print(f"Error saving checkpoint: {str(e)}")
|
188 |
+
return False
|
189 |
+
|
190 |
+
def load_checkpoint(self, checkpoint_path):
|
191 |
+
"""Load model checkpoint with error handling"""
|
192 |
+
try:
|
193 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
194 |
+
self.model.load_state_dict(checkpoint['model'])
|
195 |
+
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
196 |
+
self.current_iter = checkpoint['iter_num']
|
197 |
+
self.best_loss = checkpoint['loss']
|
198 |
+
self.train_losses = checkpoint.get('train_losses', [])
|
199 |
+
self.val_losses = checkpoint.get('val_losses', [])
|
200 |
+
self.learning_rates = checkpoint.get('learning_rates', [])
|
201 |
+
self.metrics = checkpoint.get('metrics', self.metrics)
|
202 |
+
print(f"Successfully loaded checkpoint from iteration {self.current_iter}")
|
203 |
+
return True
|
204 |
+
except Exception as e:
|
205 |
+
print(f"Error loading checkpoint: {e}")
|
206 |
+
return False
|
207 |
+
|
208 |
+
def train(self):
|
209 |
+
"""Train the DeepSeek model"""
|
210 |
+
print(f"DeepSeek Training started at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
211 |
+
print(f"Model: {self.model.config.n_layer} layers, {self.model.config.n_head} heads, {self.model.config.n_embd} dims")
|
212 |
+
print(f"MLA: {self.model.config.mla_kv_heads} KV heads, MoE: {self.model.config.moe_num_experts} experts")
|
213 |
+
print(f"Multi-token prediction: {self.model.config.multi_token_predict} tokens")
|
214 |
+
start_time = time.time()
|
215 |
+
|
216 |
+
try:
|
217 |
+
# Initialize training
|
218 |
+
X, Y = self.get_batch('train')
|
219 |
+
best_loss = float('inf')
|
220 |
+
current_loss = None
|
221 |
+
|
222 |
+
for iter_num in range(self.current_iter, self.max_iters):
|
223 |
+
self.current_iter = iter_num
|
224 |
+
|
225 |
+
# Determine and set the learning rate for this iteration
|
226 |
+
lr = self.get_lr(iter_num)
|
227 |
+
for param_group in self.optimizer.param_groups:
|
228 |
+
param_group['lr'] = lr
|
229 |
+
|
230 |
+
# Forward pass with mixed precision
|
231 |
+
if self.scaler is not None:
|
232 |
+
with torch.cuda.amp.autocast():
|
233 |
+
logits, loss = self.model(X, Y)
|
234 |
+
else:
|
235 |
+
logits, loss = self.model(X, Y)
|
236 |
+
|
237 |
+
# Backward pass
|
238 |
+
if self.scaler is not None:
|
239 |
+
self.scaler.scale(loss).backward()
|
240 |
+
self.scaler.step(self.optimizer)
|
241 |
+
self.scaler.update()
|
242 |
+
else:
|
243 |
+
loss.backward()
|
244 |
+
self.optimizer.step()
|
245 |
+
|
246 |
+
self.optimizer.zero_grad(set_to_none=True)
|
247 |
+
|
248 |
+
# Get new batch
|
249 |
+
X, Y = self.get_batch('train')
|
250 |
+
|
251 |
+
# Track metrics
|
252 |
+
current_loss = loss.item()
|
253 |
+
self.train_losses.append(current_loss)
|
254 |
+
self.learning_rates.append(lr)
|
255 |
+
|
256 |
+
# Update best loss
|
257 |
+
if current_loss < best_loss:
|
258 |
+
best_loss = current_loss
|
259 |
+
|
260 |
+
# Evaluation
|
261 |
+
if iter_num % self.eval_interval == 0:
|
262 |
+
losses = self.estimate_loss()
|
263 |
+
self.val_losses.append(losses['val'])
|
264 |
+
|
265 |
+
# Save checkpoint if it's the best so far
|
266 |
+
if losses['val'] < self.best_loss:
|
267 |
+
self.best_loss = losses['val']
|
268 |
+
self.save_checkpoint(iter_num, losses['val'], is_best=True)
|
269 |
+
|
270 |
+
# Regular checkpoint saving
|
271 |
+
if iter_num % (self.eval_interval * 5) == 0:
|
272 |
+
self.save_checkpoint(iter_num, losses['val'])
|
273 |
+
|
274 |
+
# Print progress
|
275 |
+
elapsed = time.time() - start_time
|
276 |
+
print(f"iter {iter_num}: train_loss {current_loss:.4f}, val_loss {losses['val']:.4f}, "
|
277 |
+
f"lr {lr:.2e}, time {elapsed:.2f}s")
|
278 |
+
|
279 |
+
# Memory usage
|
280 |
+
if self.device == 'cuda':
|
281 |
+
memory_used = torch.cuda.memory_allocated() / 1024**3
|
282 |
+
print(f"GPU memory: {memory_used:.2f} GB")
|
283 |
+
|
284 |
+
# Memory cleanup
|
285 |
+
if iter_num % 100 == 0:
|
286 |
+
gc.collect()
|
287 |
+
if self.device == 'cuda':
|
288 |
+
torch.cuda.empty_cache()
|
289 |
+
|
290 |
+
# Final checkpoint
|
291 |
+
self.save_checkpoint(self.max_iters, current_loss)
|
292 |
+
|
293 |
+
# Plot training metrics
|
294 |
+
self.plot_metrics()
|
295 |
+
|
296 |
+
print(f"Training completed in {time.time() - start_time:.2f} seconds")
|
297 |
+
|
298 |
+
except Exception as e:
|
299 |
+
print(f"Error during training: {str(e)}")
|
300 |
+
# Save emergency checkpoint
|
301 |
+
if current_loss is not None:
|
302 |
+
self.save_checkpoint(self.current_iter, current_loss)
|
303 |
+
raise
|
304 |
+
|
305 |
+
def plot_losses(self, train_losses, val_losses):
|
306 |
+
"""Plot training and validation losses"""
|
307 |
+
plt.figure(figsize=(12, 4))
|
308 |
+
|
309 |
+
plt.subplot(1, 2, 1)
|
310 |
+
plt.plot(train_losses, label='Training Loss')
|
311 |
+
plt.plot(val_losses, label='Validation Loss')
|
312 |
+
plt.title('Training and Validation Loss')
|
313 |
+
plt.xlabel('Iteration')
|
314 |
+
plt.ylabel('Loss')
|
315 |
+
plt.legend()
|
316 |
+
plt.grid(True)
|
317 |
+
|
318 |
+
plt.subplot(1, 2, 2)
|
319 |
+
plt.plot(self.learning_rates)
|
320 |
+
plt.title('Learning Rate Schedule')
|
321 |
+
plt.xlabel('Iteration')
|
322 |
+
plt.ylabel('Learning Rate')
|
323 |
+
plt.grid(True)
|
324 |
+
|
325 |
+
plt.tight_layout()
|
326 |
+
plt.savefig('training_metrics.png', dpi=300, bbox_inches='tight')
|
327 |
+
plt.close()
|
328 |
+
|
329 |
+
def plot_metrics(self):
|
330 |
+
"""Plot comprehensive training metrics"""
|
331 |
+
if not self.train_losses or not self.val_losses:
|
332 |
+
print("No metrics to plot")
|
333 |
+
return
|
334 |
+
|
335 |
+
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
|
336 |
+
|
337 |
+
# Training and validation loss
|
338 |
+
axes[0, 0].plot(self.train_losses, label='Training Loss', alpha=0.7)
|
339 |
+
axes[0, 0].plot(self.val_losses, label='Validation Loss', alpha=0.7)
|
340 |
+
axes[0, 0].set_title('Training and Validation Loss')
|
341 |
+
axes[0, 0].set_xlabel('Iteration')
|
342 |
+
axes[0, 0].set_ylabel('Loss')
|
343 |
+
axes[0, 0].legend()
|
344 |
+
axes[0, 0].grid(True)
|
345 |
+
|
346 |
+
# Learning rate
|
347 |
+
axes[0, 1].plot(self.learning_rates)
|
348 |
+
axes[0, 1].set_title('Learning Rate Schedule')
|
349 |
+
axes[0, 1].set_xlabel('Iteration')
|
350 |
+
axes[0, 1].set_ylabel('Learning Rate')
|
351 |
+
axes[0, 1].grid(True)
|
352 |
+
|
353 |
+
# Memory usage
|
354 |
+
if self.metrics['memory_usage']:
|
355 |
+
axes[1, 0].plot(self.metrics['memory_usage'])
|
356 |
+
axes[1, 0].set_title('GPU Memory Usage')
|
357 |
+
axes[1, 0].set_xlabel('Iteration')
|
358 |
+
axes[1, 0].set_ylabel('Memory (GB)')
|
359 |
+
axes[1, 0].grid(True)
|
360 |
+
|
361 |
+
# Gradient norm
|
362 |
+
if self.metrics['grad_norm']:
|
363 |
+
axes[1, 1].plot(self.metrics['grad_norm'])
|
364 |
+
axes[1, 1].set_title('Gradient Norm')
|
365 |
+
axes[1, 1].set_xlabel('Iteration')
|
366 |
+
axes[1, 1].set_ylabel('Norm')
|
367 |
+
axes[1, 1].grid(True)
|
368 |
+
|
369 |
+
plt.tight_layout()
|
370 |
+
plt.savefig('deepseek_training_metrics.png', dpi=300, bbox_inches='tight')
|
371 |
+
plt.close()
|
372 |
+
|
373 |
+
print("Training metrics saved to deepseek_training_metrics.png")
|
374 |
+
|
375 |
+
|
376 |
+
def create_deepseek_trainer(model, config):
|
377 |
+
"""Create a DeepSeek trainer with the given configuration"""
|
378 |
+
# Optimizer
|
379 |
+
optimizer = torch.optim.AdamW(
|
380 |
+
model.parameters(),
|
381 |
+
lr=config.learning_rate,
|
382 |
+
weight_decay=config.weight_decay,
|
383 |
+
betas=(0.9, 0.95)
|
384 |
+
)
|
385 |
+
|
386 |
+
# Device
|
387 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
388 |
+
model = model.to(device)
|
389 |
+
|
390 |
+
# Trainer
|
391 |
+
trainer = DeepSeekTrainer(
|
392 |
+
model=model,
|
393 |
+
optimizer=optimizer,
|
394 |
+
device=device,
|
395 |
+
batch_size=config.batch_size,
|
396 |
+
max_iters=config.max_iters,
|
397 |
+
eval_interval=config.eval_interval,
|
398 |
+
eval_iters=config.eval_iters,
|
399 |
+
learning_rate=config.learning_rate,
|
400 |
+
weight_decay=config.weight_decay,
|
401 |
+
warmup_iters=config.warmup_iters,
|
402 |
+
lr_decay_iters=config.lr_decay_iters,
|
403 |
+
min_lr=config.min_lr,
|
404 |
+
checkpoint_dir=config.checkpoint_dir,
|
405 |
+
use_mixed_precision=config.use_mixed_precision
|
406 |
+
)
|
407 |
+
|
408 |
+
return trainer
|
training_metrics.png
ADDED
![]() |
Git LFS Details
|