ThomasTheMaker committed · Commit 01ae771 · verified · 1 Parent(s): 4620167

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ deepseek-arch.png filter=lfs diff=lfs merge=lfs -text
+ deepseek_training_metrics.png filter=lfs diff=lfs merge=lfs -text
+ training_metrics.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 IdeaWeaver AI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,160 @@
+ # DeepSeek-Children-Stories
+
+ A state-of-the-art DeepSeek-style model optimized for children's story generation, packing an advanced architecture into just ~15-18M parameters.
+
+ ## Architecture Highlights
+
+ ![DeepSeek Architecture](deepseek-arch.png)
+
+ - **Multihead Latent Attention (MLA)** - DeepSeek's efficient attention mechanism
+ - **Mixture of Experts (MoE)** - 4 experts with top-2 routing for increased capacity
+ - **Multi-token Prediction** - Predicts the next 2 tokens simultaneously for efficiency
+ - **Rotary Positional Encodings (RoPE)** - Better position understanding
+
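+ The sketch below shows how these features map onto `DeepSeekConfig` fields as defined in `src/model/deepseek.py`; the values shown are the committed defaults:
+
+ ```python
+ # Assumes src/ is on the Python path, as in src/generate.py
+ from model.deepseek import DeepSeek, DeepSeekConfig
+
+ config = DeepSeekConfig(
+     use_mla=True,           # Multihead Latent Attention instead of standard attention
+     mla_kv_heads=4,         # shared key-value heads used by MLA
+     moe_num_experts=4,      # number of MoE experts
+     moe_top_k=2,            # experts routed per token (top-2 routing)
+     multi_token_predict=2,  # predict the next 2 tokens per step
+ )
+ model = DeepSeek(config)
+ ```
+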
+ ## Model Specifications
+
+ - **Parameters**: ~15-18M (6 layers, 8 heads, 512 embedding dim)
+ - **Context Window**: 1024 tokens
+ - **Vocabulary**: GPT-2 compatible (50,257 tokens)
+ - **Training Data**: 2,000+ children's stories from Hugging Face
+
+ ## Hardware Used
+
+ Training was performed on the following hardware:
+
+ - **GPU**: NVIDIA RTX 4090 (24 GB VRAM)
+ - **RAM**: 41 GB
+ - **CPU**: 6 vCPU
+
+ ## Quick Start
+
+ ### Installation
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/ideaweaver-ai/DeepSeek-Children-Stories-15M-model.git
+ cd DeepSeek-Children-Stories-15M-model
+
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Set up the environment
+ chmod +x setup.sh
+ ./setup.sh
+ ```
+
+ ### Training
+
+ ```bash
+ # Start training
+ python src/run_training.py
+
+ # With custom parameters
+ python src/run_training.py --batch-size 8 --max-iters 10000 --learning-rate 6e-4
+ ```
+
+ ### Generation
+
+ ```bash
+ # Generate stories
+ python src/generate.py --prompt "Once upon a time, there was a brave little mouse"
+
+ # With custom parameters
+ python src/generate.py --prompt "A magical forest adventure" --max-tokens 200 --temperature 0.8
+ ```
+
+ ## 📖 Example Output
+
+ Here's an example of a story generated by the model:
+
+ **Prompt**: "Once upon a time"
+
+ **Generated Story**:
+ ```
+ it was a bright, sunny day, and lily and her little brother max were playing in their backyard. they found a piece of paper with two sentence written on it. "let's make sense of some of these sentences," said max, pointing to the first sentence. "these people are playing on the grass," "but i don't know," replied lily. she thought for a moment. "maybe they only talk with the others or not, right?" she asked. max nodded. "yeah, and what about 'he', 'he', 'an', 'man', and 'man'?" lily explained, "it means they're playing with their dogs. but they don't say anything about someone talking." max asked, "but what about the others? we don't talk to each other!" lily thought for a moment before answering, "that's right! sometimes, people try to talk to each other. when we talk about something, we need to tell others
+ ```
+
+ ## Training Metrics
+
+ <p align="center">
+   <img src="training_metrics.png" alt="Training and Validation Loss and Learning Rate" width="800"/>
+ </p>
+
+ ## Configuration
+
+ The model can be configured through command-line arguments:
+
+ ```bash
+ # Model configuration
+ --n-layer 6          # Number of transformer layers
+ --n-head 8           # Number of attention heads
+ --n-embd 512         # Embedding dimension
+ --block-size 1024    # Context window size
+
+ # Training configuration
+ --batch-size 12      # Batch size
+ --max-iters 20000    # Maximum training iterations
+ --learning-rate 6e-4 # Learning rate
+ --eval-interval 1000 # Evaluation interval
+
+ # Advanced features
+ --moe-experts 4      # Number of MoE experts
+ --multi-token 2      # Multi-token prediction
+ ```
+
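+ When driving training through `setup.sh`, the same knobs can be overridden via environment variables (names as read in `setup.sh` in this commit):
+
+ ```bash
+ # Override defaults without editing the script
+ BATCH_SIZE=8 MAX_ITERS=10000 LEARNING_RATE=6e-4 ./setup.sh
+ ```
+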
+ ## 🤗 Model Available on Hugging Face
+
+ The trained model is now available on Hugging Face Hub! You can use it directly:
+
+ **Model**: [lakhera2023/deepseek-children-stories](https://huggingface.co/lakhera2023/deepseek-children-stories)
+
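+ A minimal loading sketch, assuming the Hub repo exposes the raw PyTorch checkpoint under the same name as this repository's `checkpoints/best_model.pt` (the filename on the Hub is an assumption):
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ # Download the checkpoint file from the Hub (filename assumed)
+ ckpt_path = hf_hub_download(
+     repo_id="lakhera2023/deepseek-children-stories",
+     filename="best_model.pt",
+ )
+
+ # The checkpoint stores the config alongside the weights; see src/generate.py
+ # for full loading, including handling of torch.compile's '_orig_mod.' prefix.
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+ ```
+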
+ ## Features
+
+ ### Advanced Architecture
+ - **MLA**: Efficient attention with shared key-value heads
+ - **MoE**: Mixture of experts for increased model capacity
+ - **Multi-token Prediction**: Simultaneous prediction of multiple tokens
+ - **RoPE**: Rotary positional encodings for better position understanding
+
+ ### Training Optimizations
+ - Mixed precision training with gradient scaling (see the sketch below)
+ - PyTorch 2.0 compilation for speed
+ - Automatic checkpointing and model saving
+ - MoE auxiliary loss for load balancing
+
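+ A hedged sketch of the mixed-precision and compilation pattern described above (standard PyTorch APIs; the `model(x, y) -> (logits, loss)` signature and `loader` are assumptions, not taken verbatim from this repo's trainer):
+
+ ```python
+ import torch
+
+ model = torch.compile(model)          # PyTorch 2.0 compilation
+ scaler = torch.cuda.amp.GradScaler()  # gradient scaling for fp16
+
+ for x, y in loader:                   # hypothetical dataloader
+     optimizer.zero_grad(set_to_none=True)
+     with torch.autocast(device_type="cuda", dtype=torch.float16):
+         logits, loss = model(x, y)    # assumed signature
+     scaler.scale(loss).backward()     # scale loss to avoid fp16 underflow
+     scaler.step(optimizer)            # unscales gradients, then steps
+     scaler.update()
+ ```
+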
+ ### Story Generation
+ - Creative and engaging children's stories
+ - Moral lessons and educational content
+ - Age-appropriate language and themes
+ - Consistent character development
+
+ ## Performance
+
+ The model achieves:
+ - Efficient training with ~2.24 GB GPU memory usage (reproducible with the snippet below)
+ - Fast inference for real-time story generation
+ - High-quality output suitable for children
+ - Scalable architecture for different use cases
+
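+ One way to reproduce the peak-memory figure above (standard PyTorch accounting; this measures the CUDA allocator's peak, not total process memory):
+
+ ```python
+ import torch
+
+ torch.cuda.reset_peak_memory_stats()
+ # ... run a few training iterations here ...
+ peak_gb = torch.cuda.max_memory_allocated() / 1024**3
+ print(f"Peak GPU memory allocated: {peak_gb:.2f} GB")
+ ```
+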
+ ## Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request.
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
+
+ ## Acknowledgments
+
+ - DeepSeek team for the original architecture
+ - Hugging Face for the children's stories dataset
+ - PyTorch team for the excellent framework
+
+ ## Links
+
+ - **GitHub**: https://github.com/ideaweaver-ai/DeepSeek-Children-Stories-15M-model
+
+ ---
+
+ ⭐ **Star this repository if you think Advanced Architecture + Tiny Models can do Big Things!**
checkpoints/best_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ed045ca7f558caa89ae9345afc8f279d85838fd88ee066525d3e0497c2b1903
+ size 943196083
checkpoints/checkpoint_0.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:596d20b3003def973b7ed263f8a3bc8178d4527182e3d8b3d73b9f08a880487a
+ size 942850634
checkpoints/checkpoint_1000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef3cd05895dab3ad7b9605e99fd44437886e504064074283f1c01d35d4bdde2c
+ size 942871055
checkpoints/checkpoint_10000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e932a30f92cf27cc5cc90220c11cca2a9ba26ca662e1dcc9fbe0c9e56947ea8
+ size 943036196
checkpoints/checkpoint_11000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:741587b23e54005aa59122359d1645f9647509283315b1e688cecaa244e14d37
+ size 943054443
checkpoints/checkpoint_12000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:68e0cfb7e5e1c841ddb4f63a4e35f609ed3ae33a71b2fcae3a42713807716288
+ size 943072690
checkpoints/checkpoint_13000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8faec5132a9233ae661cf85447612451692fd568f55b77a331ef372ce7298b1d
+ size 943091001
checkpoints/checkpoint_14000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:21000abe6d1ce61af34a3a348379b75ceb32c5b3e34934b6bf7f0f56900d57da
+ size 943109248
checkpoints/checkpoint_15000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:989c4f47419fbdb682d798c440eb8fa657f28717fe46a8a977cff5438daf6a32
+ size 943127495
checkpoints/checkpoint_16000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d821499c970ea1176632b80fa534c13cd941991e7219a2e4b98d443f733ea8d9
+ size 943145742
checkpoints/checkpoint_17000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c4239cfcac50b335caba1be84fd182100b30a43e199558bdf4320eb6a569a355
+ size 943164053
checkpoints/checkpoint_18000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a95a2b08118aafbe1649a94cdcd25c092c9e3337fac55ffb176af1c829a997cd
+ size 943182300
checkpoints/checkpoint_19000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dc22f5fc4bbfce05ab2b422929f31666dfc5536e8ffe6cf066ae1aa64feb3215
+ size 943200547
checkpoints/checkpoint_2000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7031946ec00c7c3433b130ee42a593a8295bb6f9036bef87fa922ce8ef7d3c0
+ size 942889365
checkpoints/checkpoint_20000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f2abb8b5c12e411ecc662223dcc00ebb8e76bb0e492bf3782d03b4a5f0d5298
+ size 943218531
checkpoints/checkpoint_3000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3e676cb28f0aa2fe17da68bb5df5af422601606c11141e56bd26dfe126238971
+ size 942907611
checkpoints/checkpoint_4000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5f363f026b391e07f1610299450a74a32e217a782cb867a4b4307bb2a7d75559
+ size 942925857
checkpoints/checkpoint_5000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:20a34a77821367b9708831979f89af4f32b45bac592466cc04e778ee99a973b5
+ size 942944103
checkpoints/checkpoint_6000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:935b2adcd4575c3ad9732e898afdff258fd36a7571b7e97ca2796d116526ee0a
+ size 942962413
checkpoints/checkpoint_7000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:390a641bc6d086b41d16c8ec6888fc79f2a7970525f6192528ea35a1b4f44728
+ size 942980659
checkpoints/checkpoint_8000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90b2614013ed2c2df7078052867a9c19bdd36acc386f3fa2ed17e3013bdbdfb4
+ size 942998905
checkpoints/checkpoint_9000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1cba9b64ec3237fe1c33366159a883c344d103d35b4cced89d514617f89278c2
+ size 943017151
checkpoints/final_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:62e5cf559d3abdb97036c62c7daa25e61256d98f6cbab403d692ca648db2e078
+ size 942849715
deepseek-arch.png ADDED

Git LFS Details

  • SHA256: 5447c3856700c59997bb6e8463fb50b6e577c3d75db93c5c291a9a6af26ab32d
  • Pointer size: 131 Bytes
  • Size of remote file: 123 kB
deepseek_training_metrics.png ADDED

Git LFS Details

  • SHA256: d4034df4cf3c8bb9f9c463ef6b884aa4724d6bb27a7b2fa82b9ea16a8aeb8f7d
  • Pointer size: 131 Bytes
  • Size of remote file: 288 kB
process_data.py ADDED
@@ -0,0 +1,16 @@
+ import os
+ import sys
+
+ # Add the src directory to Python path
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
+
+ from data.data_processor import DeepSeekDataProcessor
+
+ def main():
+     print("[+] Processing dataset into binary files...")
+     processor = DeepSeekDataProcessor()
+     processor.prepare_dataset()
+     print("[+] Data processing completed successfully!")
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch>=2.0.0
+ transformers>=4.30.0
+ datasets>=2.12.0
+ tiktoken>=0.5.0
+ numpy>=1.20.0
+ tqdm>=4.65.0
+ matplotlib>=3.5.0
+ peft>=0.4.0
+ accelerate>=0.20.0
+ bitsandbytes>=0.41.0
+ huggingface_hub>=0.16.0
+ wandb>=0.15.0
+ psutil>=5.8.0
setup.sh ADDED
@@ -0,0 +1,313 @@
+ #!/bin/bash
+
+ # Colors for output
+ GREEN='\033[0;32m'
+ RED='\033[0;31m'
+ YELLOW='\033[1;33m'
+ BLUE='\033[0;34m'
+ NC='\033[0m' # No Color
+
+ # Default configuration
+ PROJECT_ROOT="${PROJECT_ROOT:-$(pwd)}"
+ VENV_PATH="${VENV_PATH:-${PROJECT_ROOT}/venv}"
+ CHECKPOINT_DIR="${CHECKPOINT_DIR:-${PROJECT_ROOT}/checkpoints}"
+ LORA_CHECKPOINT_DIR="${LORA_CHECKPOINT_DIR:-${PROJECT_ROOT}/lora_checkpoints}"
+ REQUIRED_SPACE_MB="${REQUIRED_SPACE_MB:-2000}"
+
+ # Function to print status messages
+ print_status() {
+     echo -e "${GREEN}[+] $1${NC}"
+ }
+
+ print_error() {
+     echo -e "${RED}[-] $1${NC}"
+ }
+
+ print_warning() {
+     echo -e "${YELLOW}[!] $1${NC}"
+ }
+
+ print_info() {
+     echo -e "${BLUE}[i] $1${NC}"
+ }
+
+ # Function to handle errors
+ handle_error() {
+     print_error "$1"
+     exit 1
+ }
+
+ # Function to check if a command exists
+ command_exists() {
+     command -v "$1" &> /dev/null
+ }
+
+ # Function to check disk space
+ check_disk_space() {
+     local available_space_mb=$(df -m . | awk 'NR==2 {print $4}')
+     if [ "$available_space_mb" -lt "$REQUIRED_SPACE_MB" ]; then
+         print_warning "Low disk space. Only ${available_space_mb}MB available, ${REQUIRED_SPACE_MB}MB required."
+         return 1
+     fi
+     return 0
+ }
+
+ # Function to check GPU memory
+ check_gpu_memory() {
+     if command_exists nvidia-smi; then
+         local total_memory=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits)
+         local free_memory=$(nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits)
+         local used_memory=$((total_memory - free_memory))
+         print_status "GPU Memory: ${used_memory}MB used, ${free_memory}MB free of ${total_memory}MB total"
+
+         # Check if we have enough memory for training
+         if [ "$free_memory" -lt 4000 ]; then
+             print_warning "Low GPU memory. Consider reducing batch size or model size."
+         fi
+     else
+         print_warning "nvidia-smi not found. GPU training may not be available."
+     fi
+ }
+
+ # Function to create project structure
+ create_project_structure() {
+     print_status "Creating project structure..."
+     mkdir -p "${PROJECT_ROOT}/src/data" \
+         "${PROJECT_ROOT}/src/model" \
+         "${PROJECT_ROOT}/src/training" \
+         "${PROJECT_ROOT}/src/inference" \
+         "${CHECKPOINT_DIR}" \
+         "${LORA_CHECKPOINT_DIR}" || handle_error "Failed to create directories"
+ }
+
+ # Function to setup virtual environment
+ setup_virtual_env() {
+     print_status "Creating virtual environment..."
+     python3 -m venv "${VENV_PATH}" || handle_error "Failed to create virtual environment"
+     source "${VENV_PATH}/bin/activate" || handle_error "Failed to activate virtual environment"
+
+     print_status "Installing dependencies..."
+     pip install --upgrade pip
+     pip install -r requirements.txt || handle_error "Failed to install requirements"
+ }
+
+ # Function to prepare dataset
+ prepare_dataset() {
+     print_status "Preparing dataset..."
+     cd "${PROJECT_ROOT}" || handle_error "Failed to change to project directory"
+
+     # Create a Python script to process the data
+     cat > process_data.py << 'EOF'
+ import os
+ import sys
+
+ # Add the src directory to Python path
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
+
+ from data.data_processor import DeepSeekDataProcessor
+
+ def main():
+     print("[+] Processing dataset into binary files...")
+     processor = DeepSeekDataProcessor()
+     processor.prepare_dataset()
+     print("[+] Data processing completed successfully!")
+
+ if __name__ == "__main__":
+     main()
+ EOF
+
+     # Run the data processing script
+     python3 process_data.py || handle_error "Failed to process dataset"
+
+     # Verify the files were created
+     if [ ! -f "${PROJECT_ROOT}/src/data/train.bin" ] || [ ! -f "${PROJECT_ROOT}/src/data/validation.bin" ]; then
+         handle_error "Data processing failed - required files not created"
+     fi
+ }
+
+ # Function to train base model
+ train_base_model() {
+     print_status "Starting DeepSeek base model training..."
+     cd "${PROJECT_ROOT}" || handle_error "Failed to change to project directory"
+
+     python3 src/run_training.py \
+         --batch-size "${BATCH_SIZE:-12}" \
+         --max-iters "${MAX_ITERS:-20000}" \
+         --eval-interval "${EVAL_INTERVAL:-1000}" \
+         --eval-iters "${EVAL_ITERS:-200}" \
+         --learning-rate "${LEARNING_RATE:-6e-4}" \
+         --weight-decay "${WEIGHT_DECAY:-0.1}" \
+         --warmup-iters "${WARMUP_ITERS:-2000}" \
+         --lr-decay-iters "${LR_DECAY_ITERS:-20000}" \
+         --min-lr "${MIN_LR:-6e-5}" \
+         --moe-experts "${MOE_EXPERTS:-4}" \
+         --multi-token "${MULTI_TOKEN:-2}" || handle_error "Base model training failed"
+ }
+
+ # Function to perform LoRA finetuning
+ finetune_lora() {
+     while true; do
+         read -p "Do you want to perform LoRA finetuning? (y/n) " do_finetune
+         case $do_finetune in
+             [Yy]* )
+                 print_status "Starting LoRA finetuning..."
+                 cd "${PROJECT_ROOT}" || handle_error "Failed to change to project directory"
+
+                 # Create LoRA finetuning script
+                 cat > finetune_lora.py << 'EOF'
+ import torch
+ import os
+ import sys
+ sys.path.append('src')
+
+ from model.deepseek import DeepSeek, DeepSeekConfig
+ from peft import get_peft_model, LoraConfig, TaskType
+
+ def main():
+     print("Loading base model...")
+     checkpoint = torch.load('checkpoints/best_model.pt', map_location='cuda' if torch.cuda.is_available() else 'cpu')
+     model = DeepSeek(checkpoint['config'])
+     model.load_state_dict(checkpoint['model'])
+
+     # Define LoRA configuration
+     lora_config = LoraConfig(
+         task_type=TaskType.CAUSAL_LM,
+         r=8,  # rank
+         lora_alpha=32,
+         lora_dropout=0.1,
+         target_modules=["q_a_proj", "q_b_proj", "kv_a_proj", "kv_b_proj"]
+     )
+
+     # Get PEFT model
+     model = get_peft_model(model, lora_config)
+     model.print_trainable_parameters()
+
+     print("LoRA finetuning setup complete!")
+
+ if __name__ == "__main__":
+     main()
+ EOF
+
+                 python3 finetune_lora.py || handle_error "LoRA finetuning failed"
+                 break
+                 ;;
+             [Nn]* )
+                 print_status "Skipping LoRA finetuning..."
+                 break
+                 ;;
+             * )
+                 echo "Please answer 'y' or 'n'"
+                 ;;
+         esac
+     done
+ }
+
+ # Function to test the trained model
+ test_model() {
+     while true; do
+         read -p "Do you want to test the trained model? (y/n) " do_test
+         case $do_test in
+             [Yy]* )
+                 print_status "Testing the trained model..."
+                 cd "${PROJECT_ROOT}" || handle_error "Failed to change to project directory"
+
+                 # Create test prompts
+                 prompts=(
+                     "Once upon a time"
+                     "In a magical forest"
+                     "The little robot"
+                     "The brave knight"
+                 )
+
+                 # Test each prompt
+                 for prompt in "${prompts[@]}"; do
+                     print_status "Testing with prompt: '$prompt'"
+                     python3 src/generate.py \
+                         --model-path "${CHECKPOINT_DIR}/best_model.pt" \
+                         --prompt "$prompt" \
+                         --max-tokens 100 \
+                         --temperature 0.8 \
+                         --top-k 40
+                     echo
+                 done
+                 break
+                 ;;
+             [Nn]* )
+                 print_status "Skipping model testing..."
+                 break
+                 ;;
+             * )
+                 echo "Please answer 'y' or 'n'"
+                 ;;
+         esac
+     done
+ }
+
+ # Function to show usage information
+ show_usage() {
+     print_info "DeepSeek Children's Stories Model Setup Complete!"
+     print_info ""
+     print_info "Next steps:"
+     print_info "1. Activate virtual environment: source venv/bin/activate"
+     print_info "2. Train the model: python src/run_training.py"
+     print_info "3. Generate stories: python src/generate.py --prompt 'your prompt'"
+     print_info "4. Interactive mode: python src/generate.py --interactive"
+     print_info ""
+     print_info "Model files:"
+     print_info "- Base model: checkpoints/best_model.pt"
+     print_info "- LoRA model: lora_checkpoints/best_lora_model.pt"
+     print_info ""
+     print_info "Configuration options:"
+     print_info "- Adjust model size: --n-layer, --n-head, --n-embd"
+     print_info "- Training parameters: --batch-size, --learning-rate, --max-iters"
+     print_info "- Advanced features: --moe-experts, --multi-token"
+ }
+
+ # Main setup function
+ main() {
+     print_info "DeepSeek Children's Stories Model Setup"
+     print_info "======================================"
+
+     # Check prerequisites
+     if ! command_exists python3; then
+         handle_error "Python 3 is required but not installed"
+     fi
+
+     if ! command_exists pip; then
+         handle_error "pip is required but not installed"
+     fi
+
+     # Check disk space
+     if ! check_disk_space; then
+         print_warning "Continuing with low disk space..."
+     fi
+
+     # Check GPU
+     check_gpu_memory
+
+     # Create project structure
+     create_project_structure
+
+     # Setup virtual environment
+     setup_virtual_env
+
+     # Prepare dataset
+     prepare_dataset
+
+     # Train base model
+     train_base_model
+
+     # Optional LoRA finetuning
+     finetune_lora
+
+     # Optional model testing
+     test_model
+
+     # Show usage information
+     show_usage
+
+     print_status "Setup completed successfully!"
+ }
+
+ # Run main function
+ main "$@"
src/data/__pycache__/data_processor.cpython-310.pyc ADDED
Binary file (8.3 kB).
 
src/data/data_processor.py ADDED
@@ -0,0 +1,287 @@
+ """
+ Data Processor for DeepSeek Children's Stories Model
+ Handles dataset loading, preprocessing, and tokenization for children's story generation
+ """
+
+ import tiktoken
+ import os
+ import numpy as np
+ from datasets import load_dataset
+ from tqdm.auto import tqdm
+ import torch
+ from typing import Dict, List, Optional
+
+ def load_encoder_decoder():
+     """Load the encoder and decoder for text processing"""
+     enc = tiktoken.get_encoding("gpt2")
+     return enc, enc
+
+ class DeepSeekDataProcessor:
+     def __init__(self, config=None):
+         # Initialize tokenizer with GPT-2 encoding
+         self.enc = tiktoken.get_encoding("gpt2")
+
+         # Special tokens for story structure (optimized for children's stories)
+         self.special_tokens = {
+             "story_start": "<|story|>",
+             "story_end": "</|story|>",
+             "prompt_start": "<|prompt|>",
+             "prompt_end": "</|prompt|>",
+             "moral_start": "<|moral|>",
+             "moral_end": "</|moral|>",
+             "character_start": "<|character|>",
+             "character_end": "</|character|>"
+         }
+
+         # Ensure data directory exists
+         self.data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
+         os.makedirs(self.data_dir, exist_ok=True)
+         print(f"Data directory: {self.data_dir}")
+
+         # Configuration for processing
+         self.max_length = 1024  # DeepSeek context window
+         self.min_length = 50  # Minimum story length
+
+     def preprocess_text(self, text: str) -> str:
+         """Preprocess text for children's stories"""
+         # Basic text cleaning
+         text = text.lower()  # Convert to lowercase for consistency
+         text = text.replace('\n', ' ')  # Replace newlines with spaces
+         text = ' '.join(text.split())  # Normalize whitespace
+
+         # Remove any inappropriate content markers (basic filtering)
+         inappropriate_phrases = ['adult content', 'mature', 'explicit']
+         for phrase in inappropriate_phrases:
+             if phrase in text:
+                 return ""
+
+         # Ensure the text is child-friendly
+         if len(text) < self.min_length:
+             return ""
+
+         return text
+
+     def extract_story_elements(self, example: Dict) -> Dict:
+         """Extract story elements for better structure"""
+         prompt = self.preprocess_text(example.get('prompt', ''))
+         story = self.preprocess_text(example.get('text', ''))
+
+         # Extract potential moral or lesson
+         moral = ""
+         if 'moral' in example:
+             moral = self.preprocess_text(example['moral'])
+         elif 'lesson' in example:
+             moral = self.preprocess_text(example['lesson'])
+
+         # Extract main character if available
+         character = ""
+         if 'character' in example:
+             character = self.preprocess_text(example['character'])
+
+         return {
+             'prompt': prompt,
+             'story': story,
+             'moral': moral,
+             'character': character
+         }
+
+     def process(self, example: Dict) -> Dict:
+         """Process a single example for DeepSeek model"""
+         # Extract story elements
+         elements = self.extract_story_elements(example)
+
+         # Skip if no valid content
+         if not elements['story'] or not elements['prompt']:
+             return {'ids': [], 'len': 0}
+
+         # Create structured text with special tokens
+         full_text = (
+             f"{self.special_tokens['prompt_start']} {elements['prompt']} {self.special_tokens['prompt_end']} "
+         )
+
+         # Add character information if available
+         if elements['character']:
+             full_text += f"{self.special_tokens['character_start']} {elements['character']} {self.special_tokens['character_end']} "
+
+         # Add the main story
+         full_text += f"{self.special_tokens['story_start']} {elements['story']} {self.special_tokens['story_end']}"
+
+         # Add moral if available
+         if elements['moral']:
+             full_text += f" {self.special_tokens['moral_start']} {elements['moral']} {self.special_tokens['moral_end']}"
+
+         # Tokenize with error handling
+         try:
+             ids = self.enc.encode_ordinary(full_text)
+
+             # Ensure the sequence isn't too long
+             if len(ids) > self.max_length:
+                 ids = ids[:self.max_length]
+
+             # Skip if too short
+             if len(ids) < 20:
+                 return {'ids': [], 'len': 0}
+
+             out = {'ids': ids, 'len': len(ids)}
+             return out
+
+         except Exception as e:
+             print(f"Error tokenizing text: {e}")
+             return {'ids': [], 'len': 0}
+
+     def prepare_dataset(self) -> Dict:
+         """Prepare the Children Stories Collection dataset for DeepSeek training"""
+         # Load the Children Stories Collection dataset
+         print("Loading Children Stories Collection dataset...")
+         ds = load_dataset("ajibawa-2023/Children-Stories-Collection")
+
+         train_bin_path = os.path.join(self.data_dir, "train.bin")
+         val_bin_path = os.path.join(self.data_dir, "validation.bin")
+         finetune_bin_path = os.path.join(self.data_dir, "finetune.bin")
+
+         print(f"Checking for existing processed files...")
+
+         # Check if all files exist
+         if (os.path.exists(train_bin_path) and
+             os.path.exists(val_bin_path) and
+             os.path.exists(finetune_bin_path)):
+
+             print("Found existing processed files!")
+             print(f"Train file: {os.path.getsize(train_bin_path) / (1024*1024):.2f} MB")
+             print(f"Validation file: {os.path.getsize(val_bin_path) / (1024*1024):.2f} MB")
+             print(f"Finetune file: {os.path.getsize(finetune_bin_path) / (1024*1024):.2f} MB")
+
+             return {
+                 "train": train_bin_path,
+                 "validation": val_bin_path,
+                 "finetune": finetune_bin_path
+             }
+
+         print("Processing dataset...")
+
+         # Filter out examples that are too short or too long
+         def filter_by_length(example):
+             text_length = len(example.get('text', ''))
+             return self.min_length <= text_length <= 2000  # Reasonable length for children's stories
+
+         ds = ds.filter(filter_by_length)
+         print(f"After filtering: {len(ds['train'])} examples")
+
+         # Split the dataset into train, validation, and finetune sets
+         train_val_test = ds["train"].train_test_split(test_size=0.2, seed=42)
+         val_finetune = train_val_test["test"].train_test_split(test_size=0.5, seed=42)
+
+         # Create a new dataset dictionary with all splits
+         ds = {
+             "train": train_val_test["train"],
+             "validation": val_finetune["train"],
+             "finetune": val_finetune["test"]
+         }
+
+         print(f"Dataset split sizes:")
+         print(f"Training set: {len(ds['train'])} examples")
+         print(f"Validation set: {len(ds['validation'])} examples")
+         print(f"Finetune set: {len(ds['finetune'])} examples")
+
+         # Process each split
+         for split_name, split_data in ds.items():
+             print(f"\nProcessing {split_name} split...")
+
+             # Process the data
+             tokenized = split_data.map(
+                 self.process,
+                 remove_columns=['text', 'prompt', 'text_token_length'],
+                 desc=f"tokenizing {split_name} split",
+                 num_proc=8,
+             )
+
+             # Filter out empty sequences
+             tokenized = tokenized.filter(lambda x: x['len'] > 0)
+             print(f"After processing: {len(tokenized)} valid examples")
+
+             # Save to binary file
+             filename = os.path.join(self.data_dir, f"{split_name}.bin")
+             print(f"Saving {split_name} split to: {filename}")
+
+             # Calculate total length
+             arr_len = np.sum(tokenized['len'], dtype=np.uint64)
+             dtype = np.uint16
+             arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
+             total_batches = 1024
+
+             idx = 0
+             for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
+                 batch = tokenized.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
+                 arr_batch = np.concatenate(batch['ids'])
+                 arr[idx : idx + len(arr_batch)] = arr_batch
+                 idx += len(arr_batch)
+             arr.flush()
+
+             # Verify file was created
+             if os.path.exists(filename):
+                 print(f"Successfully created {filename}")
+                 print(f"File size: {os.path.getsize(filename) / (1024*1024):.2f} MB")
+             else:
+                 raise RuntimeError(f"Failed to create {filename}")
+
+         return {
+             "train": train_bin_path,
+             "validation": val_bin_path,
+             "finetune": finetune_bin_path
+         }
+
+     def load_binary_data(self, filepath: str) -> torch.Tensor:
+         """Load binary data file as tensor"""
+         try:
+             data = np.memmap(filepath, dtype=np.uint16, mode='r')
+             return torch.from_numpy(data.copy())
+         except Exception as e:
+             print(f"Error loading data from {filepath}: {e}")
+             raise
+
+     def get_batch(self, data: torch.Tensor, batch_size: int, block_size: int) -> tuple:
+         """Get a batch of data for training"""
+         # Generate random indices
+         ix = torch.randint(len(data) - block_size, (batch_size,))
+
+         # Get input sequences
+         x = torch.stack([data[i:i+block_size].long() for i in ix])
+         # Get target sequences (shifted by 1)
+         y = torch.stack([data[i+1:i+1+block_size].long() for i in ix])
+
+         return x, y
+
+     def decode_tokens(self, token_ids: List[int]) -> str:
+         """Decode token IDs back to text"""
+         try:
+             return self.enc.decode(token_ids)
+         except Exception as e:
+             print(f"Error decoding tokens: {e}")
+             return ""
+
+     def encode_text(self, text: str) -> List[int]:
+         """Encode text to token IDs"""
+         try:
+             return self.enc.encode_ordinary(text)
+         except Exception as e:
+             print(f"Error encoding text: {e}")
+             return []
+
+
+ def main():
+     """Main function to process the dataset"""
+     print("DeepSeek Children's Stories Data Processor")
+     print("=" * 50)
+
+     processor = DeepSeekDataProcessor()
+     processor.prepare_dataset()
+
+     print("\nData processing completed successfully!")
+     print("Files created:")
+     print("- src/data/train.bin")
+     print("- src/data/validation.bin")
+     print("- src/data/finetune.bin")
+
+
+ if __name__ == "__main__":
+     main()
src/data/finetune.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d4f4819be598b40b35540e069ed44efc44ffc732ab6c29269e5f1a227fb9e77f
+ size 61167196
src/data/train.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:09234a8bf9f55e6e59f48b765bea680f3c2aa4a8305e9a553d9593de4652d0aa
+ size 488961682
src/data/validation.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:98148ea1408dc8b01abe08126221c6ef3cd01362240b9d0bb1ce39e477ae211c
+ size 61065952
src/generate.py ADDED
@@ -0,0 +1,281 @@
+ """
+ DeepSeek Children's Stories Text Generation
+ Generate children's stories using the trained DeepSeek model
+ """
+
+ import os
+ import sys
+ import argparse
+ import torch
+ import tiktoken
+ from typing import List, Optional
+
+ # Add the src directory to Python path
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+ from model.deepseek import DeepSeek, DeepSeekConfig
+
+ # Allowlist DeepSeekConfig for safe deserialization
+ torch.serialization.add_safe_globals([DeepSeekConfig])
+
+ class DeepSeekStoryGenerator:
+     def __init__(self, model_path: str, device: str = 'auto'):
+         """Initialize the story generator"""
+         self.device = self._get_device(device)
+         self.model = self._load_model(model_path)
+         self.tokenizer = tiktoken.get_encoding("gpt2")
+
+         # Special tokens for story structure
+         self.special_tokens = {
+             "story_start": "<|story|>",
+             "story_end": "</|story|>",
+             "prompt_start": "<|prompt|>",
+             "prompt_end": "</|prompt|>",
+             "moral_start": "<|moral|>",
+             "moral_end": "</|moral|>",
+             "character_start": "<|character|>",
+             "character_end": "</|character|>"
+         }
+
+     def _get_device(self, device: str) -> str:
+         """Get the appropriate device"""
+         if device == 'auto':
+             return 'cuda' if torch.cuda.is_available() else 'cpu'
+         return device
+
+     def _load_model(self, model_path: str) -> DeepSeek:
+         """Load the trained model"""
+         print(f"Loading model from {model_path}...")
+
+         # Load checkpoint
+         checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
+
+         # Create model with the same configuration
+         config = checkpoint['config']
+         model = DeepSeek(config)
+
+         # Handle compiled model state dict by removing _orig_mod prefix
+         state_dict = checkpoint['model']
+         if all(k.startswith('_orig_mod.') for k in state_dict.keys()):
+             state_dict = {k[10:]: v for k, v in state_dict.items()}  # Remove '_orig_mod.' prefix
+
+         # Load model weights
+         model.load_state_dict(state_dict)
+         model.to(self.device)
+         model.eval()
+
+         print(f"Model loaded successfully!")
+         print(f"Model configuration: {config.n_layer}L/{config.n_head}H/{config.n_embd}D")
+         print(f"Device: {self.device}")
+
+         return model
+
+     def encode_prompt(self, prompt: str, character: Optional[str] = None) -> torch.Tensor:
+         """Encode a prompt for generation"""
+         # Create structured prompt
+         full_prompt = f"{self.special_tokens['prompt_start']} {prompt.lower()} {self.special_tokens['prompt_end']}"
+
+         if character:
+             full_prompt += f" {self.special_tokens['character_start']} {character.lower()} {self.special_tokens['character_end']}"
+
+         full_prompt += f" {self.special_tokens['story_start']}"
+
+         # Tokenize
+         token_ids = self.tokenizer.encode_ordinary(full_prompt)
+         return torch.tensor([token_ids], dtype=torch.long, device=self.device)
+
+     def generate_story(self, prompt: str, character: Optional[str] = None,
+                        max_tokens: int = 200, temperature: float = 0.8,
+                        top_k: int = 40, top_p: float = 0.9) -> str:
+         """Generate a children's story"""
+         print(f"Generating story for prompt: '{prompt}'")
+         if character:
+             print(f"Character: {character}")
+
+         # Encode prompt
+         input_ids = self.encode_prompt(prompt, character)
+
+         # Generate
+         with torch.no_grad():
+             generated_ids = self.model.generate(
+                 input_ids,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_k=top_k
+             )
+
+         # Decode the generated text
+         generated_text = self.tokenizer.decode(generated_ids[0].tolist())
+
+         # Extract the story part
+         story = self._extract_story(generated_text)
+
+         return story
+
+     def _extract_story(self, text: str) -> str:
+         """Extract the story from the generated text"""
+         # Find story start and end markers
+         story_start = text.find(self.special_tokens['story_start'])
+         story_end = text.find(self.special_tokens['story_end'])
+
+         if story_start != -1 and story_end != -1:
+             # Extract story content
+             story_content = text[story_start + len(self.special_tokens['story_start']):story_end].strip()
+             return story_content
+         else:
+             # Fallback: return the text after the last prompt
+             prompt_end = text.find(self.special_tokens['prompt_end'])
+             if prompt_end != -1:
+                 return text[prompt_end + len(self.special_tokens['prompt_end']):].strip()
+             else:
+                 return text.strip()
+
+     def generate_multiple_stories(self, prompts: List[str], num_stories: int = 3,
+                                   **kwargs) -> List[str]:
+         """Generate multiple stories from a list of prompts"""
+         stories = []
+
+         for i, prompt in enumerate(prompts):
+             print(f"\nGenerating story {i+1}/{len(prompts)}...")
+             story = self.generate_story(prompt, **kwargs)
+             stories.append(story)
+
+         return stories
+
+     def interactive_generation(self):
+         """Interactive story generation mode"""
+         print("DeepSeek Children's Stories - Interactive Mode")
+         print("Type 'quit' to exit")
+         print("-" * 50)
+
+         while True:
+             try:
+                 # Get prompt from user
+                 prompt = input("\nEnter a story prompt: ").strip()
+
+                 if prompt.lower() in ['quit', 'exit', 'q']:
+                     print("Goodbye!")
+                     break
+
+                 if not prompt:
+                     print("Please enter a valid prompt.")
+                     continue
+
+                 # Get character (optional)
+                 character = input("Enter a character name (optional): ").strip()
+                 if not character:
+                     character = None
+
+                 # Get generation parameters
+                 try:
+                     max_tokens = int(input("Max tokens (default 200): ") or "200")
+                     temperature = float(input("Temperature (default 0.8): ") or "0.8")
+                 except ValueError:
+                     max_tokens = 200
+                     temperature = 0.8
+
+                 # Generate story
+                 story = self.generate_story(
+                     prompt,
+                     character=character,
+                     max_tokens=max_tokens,
+                     temperature=temperature
+                 )
+
+                 # Display story
+                 print("\n" + "="*50)
+                 print("GENERATED STORY:")
+                 print("="*50)
+                 print(story)
+                 print("="*50)
+
+             except KeyboardInterrupt:
+                 print("\nGoodbye!")
+                 break
+             except Exception as e:
+                 print(f"Error generating story: {e}")
+
+
+ def main():
+     """Main generation function"""
+     parser = argparse.ArgumentParser(description='Generate children\'s stories with DeepSeek')
+
+     # Model configuration
+     parser.add_argument('--model-path', type=str, default='checkpoints/best_model.pt',
+                         help='Path to the trained model checkpoint')
+     parser.add_argument('--device', type=str, default='auto',
+                         help='Device to use (auto, cuda, cpu)')
+
+     # Generation parameters
+     parser.add_argument('--prompt', type=str, help='Story prompt')
+     parser.add_argument('--character', type=str, help='Character name')
+     parser.add_argument('--max-tokens', type=int, default=200, help='Maximum tokens to generate')
+     parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
+     parser.add_argument('--top-k', type=int, default=40, help='Top-k sampling')
+     parser.add_argument('--top-p', type=float, default=0.9, help='Top-p sampling')
+
+     # Multiple generation
+     parser.add_argument('--num-stories', type=int, default=1, help='Number of stories to generate')
+     parser.add_argument('--interactive', action='store_true', help='Interactive mode')
+
+     args = parser.parse_args()
+
+     # Check if model exists
+     if not os.path.exists(args.model_path):
+         print(f"Error: Model file not found at {args.model_path}")
+         print("Please train the model first or specify the correct path.")
+         return
+
+     # Create generator
+     generator = DeepSeekStoryGenerator(args.model_path, args.device)
+
+     if args.interactive:
+         # Interactive mode
+         generator.interactive_generation()
+     else:
+         # Single or multiple generation
+         if args.prompt:
+             if args.num_stories == 1:
+                 # Single story
+                 story = generator.generate_story(
+                     args.prompt,
+                     character=args.character,
+                     max_tokens=args.max_tokens,
+                     temperature=args.temperature,
+                     top_k=args.top_k,
+                     top_p=args.top_p
+                 )
+
+                 print(f"\nPrompt: {args.prompt}")
+                 if args.character:
+                     print(f"Character: {args.character}")
+                 print("\n" + "="*50)
+                 print("GENERATED STORY:")
+                 print("="*50)
+                 print(story)
+                 print("="*50)
+             else:
+                 # Multiple stories
+                 prompts = [args.prompt] * args.num_stories
+                 stories = generator.generate_multiple_stories(
+                     prompts,
+                     num_stories=args.num_stories,
+                     character=args.character,
+                     max_tokens=args.max_tokens,
+                     temperature=args.temperature,
+                     top_k=args.top_k,
+                     top_p=args.top_p
+                 )
+
+                 for i, story in enumerate(stories):
+                     print(f"\nStory {i+1}:")
+                     print("="*50)
+                     print(story)
+                     print("="*50)
+         else:
+             print("Please provide a prompt or use --interactive mode.")
+             print("Example: python generate.py --prompt 'A brave little mouse' --character 'Mickey'")
+
+
+ if __name__ == "__main__":
+     main()
src/model/__pycache__/deepseek.cpython-310.pyc ADDED
Binary file (13.8 kB).
 
src/model/deepseek.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepSeek Model Architecture for Children's Stories
3
+ Implements advanced features:
4
+ - Multihead Latent Attention (MLA)
5
+ - Mixture of Experts (MoE)
6
+ - Multi-token prediction
7
+ - Quantization support
8
+ - Rotary Positional Encodings (RoPE)
9
+ - Optimized for children's story generation
10
+ """
11
+
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from typing import Optional, Tuple, List
17
+ from dataclasses import dataclass
18
+
19
+
20
+ @dataclass
21
+ class DeepSeekConfig:
22
+ """Configuration for DeepSeek model optimized for children's stories"""
23
+ vocab_size: int = 50257 # GPT-2 vocabulary size
24
+ n_layer: int = 6 # Reduced for efficiency
25
+ n_head: int = 8 # Number of attention heads
26
+ n_embd: int = 512 # Embedding dimension
27
+ block_size: int = 1024 # Context window
28
+ dropout: float = 0.1 # Dropout rate
29
+ bias: bool = True # Use bias in linear layers
30
+
31
+ # MLA (Multihead Latent Attention) config
32
+ use_mla: bool = True # Enable MLA
33
+ mla_kv_heads: int = 4 # Number of key-value heads for MLA
34
+ mla_q_lora_rank: int = 32 # LoRA rank for query projection
35
+ mla_kv_lora_rank: int = 16 # LoRA rank for key-value projection
36
+
37
+ # MoE (Mixture of Experts) config
38
+ moe_num_experts: int = 4 # Number of experts
39
+ moe_top_k: int = 2 # Number of experts per token
40
+ moe_expert_capacity: float = 1.25
41
+ moe_aux_loss_coeff: float = 0.01
42
+
43
+ # Multi-token prediction
44
+ multi_token_predict: int = 2 # Predict next 2 tokens for children's stories
45
+
46
+ # Quantization
47
+ use_quantization: bool = False
48
+ quantization_bits: int = 8
49
+
50
+
51
+ class RoPEPositionalEncoding(nn.Module):
52
+ """Rotary Positional Encoding (RoPE) for better position understanding"""
53
+
54
+ def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
55
+ super().__init__()
56
+ self.dim = dim
57
+ self.max_seq_len = max_seq_len
58
+ self.base = base
59
+
60
+ # Precompute frequency matrix
61
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
62
+ self.register_buffer('inv_freq', inv_freq)
63
+
64
+ # Cache for efficiency
65
+ self._cached_cos = None
66
+ self._cached_sin = None
67
+ self._cached_seq_len = 0
68
+
69
+ def _compute_cos_sin(self, seq_len: int, device: torch.device):
70
+ """Compute cosine and sine values for given sequence length"""
71
+ if seq_len > self._cached_seq_len or self._cached_cos is None:
72
+ # Create position indices
73
+ t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
74
+
75
+ # Compute frequencies
76
+ freqs = torch.outer(t, self.inv_freq)
77
+
78
+ # Create rotation matrix components
79
+ cos_vals = torch.cos(freqs)
80
+ sin_vals = torch.sin(freqs)
81
+
82
+ # Cache results
83
+ self._cached_cos = cos_vals
84
+ self._cached_sin = sin_vals
85
+ self._cached_seq_len = seq_len
86
+
87
+ return self._cached_cos[:seq_len], self._cached_sin[:seq_len]
88
+
89
+ def apply_rope(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
90
+ """Apply RoPE to input tensor"""
91
+ batch_size, seq_len, n_heads, head_dim = x.shape
92
+
93
+ # Get cos/sin values
94
+ cos, sin = self._compute_cos_sin(seq_len, x.device)
95
+
96
+ # Handle position_ids if provided
97
+ if position_ids is not None:
98
+ cos = cos[position_ids]
99
+ sin = sin[position_ids]
100
+
101
+ # Reshape for broadcasting
102
+ cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, head_dim//2]
103
+ sin = sin.unsqueeze(0).unsqueeze(2)
104
+
105
+ # Split x into two halves
106
+ x1 = x[..., ::2] # Even indices
107
+ x2 = x[..., 1::2] # Odd indices
108
+
109
+ # Apply rotation
110
+ rotated_x1 = x1 * cos - x2 * sin
111
+ rotated_x2 = x1 * sin + x2 * cos
112
+
113
+ # Recombine
114
+ rotated_x = torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
115
+
116
+ return rotated_x
117
+
118
+
119
+ class MultiheadLatentAttention(nn.Module):
120
+ """
121
+ Multihead Latent Attention (MLA) - DeepSeek's efficient attention mechanism
122
+ Uses shared key-value heads with LoRA-style projections for efficiency
123
+ """
124
+
125
+ def __init__(self, config: DeepSeekConfig):
126
+ super().__init__()
127
+ self.config = config
128
+ self.n_head = config.n_head
129
+ self.n_embd = config.n_embd
130
+ self.head_dim = config.n_embd // config.n_head
131
+ self.kv_heads = config.mla_kv_heads
132
+ self.kv_head_dim = self.head_dim
133
+
134
+ # Query projection with LoRA-style decomposition
135
+ self.q_a_proj = nn.Linear(config.n_embd, config.mla_q_lora_rank, bias=False)
136
+ self.q_b_proj = nn.Linear(config.mla_q_lora_rank, config.n_embd, bias=False)
137
+
138
+ # Key-Value projection with shared heads
139
+ self.kv_a_proj = nn.Linear(config.n_embd, config.mla_kv_lora_rank, bias=False)
140
+ self.kv_b_proj = nn.Linear(config.mla_kv_lora_rank, self.kv_heads * self.head_dim * 2, bias=False)
141
+
142
+ # Output projection
143
+ self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
144
+
145
+ # RoPE for positional encoding
146
+ self.rope = RoPEPositionalEncoding(self.head_dim)
147
+
148
+ # Dropout
149
+ self.dropout = nn.Dropout(config.dropout)
150
+
151
+ # Scaling factor
152
+ self.scale = self.head_dim ** -0.5
153
+
154
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
155
+ batch_size, seq_len, _ = x.shape
156
+
157
+ # Query projection through LoRA-style decomposition
158
+ q_latent = self.q_a_proj(x) # [B, T, rank]
159
+ q = self.q_b_proj(q_latent) # [B, T, n_embd]
160
+ q = q.view(batch_size, seq_len, self.n_head, self.head_dim)
161
+
162
+ # Key-Value projection through shared heads
163
+ kv_latent = self.kv_a_proj(x) # [B, T, kv_rank]
164
+ kv = self.kv_b_proj(kv_latent) # [B, T, kv_heads * kv_head_dim * 2]
165
+ kv = kv.view(batch_size, seq_len, self.kv_heads, self.head_dim, 2)
166
+ k, v = kv.unbind(dim=-1) # Each: [B, T, kv_heads, kv_head_dim]
167
+
168
+ # Apply RoPE to queries and keys before expansion
169
+ q = self.rope.apply_rope(q)
170
+ k = self.rope.apply_rope(k)
171
+
172
+ # Expand key-value to match query heads
173
+ k = k.repeat_interleave(self.n_head // self.kv_heads, dim=2)
174
+ v = v.repeat_interleave(self.n_head // self.kv_heads, dim=2)
175
+
176
+ # Transpose for attention computation
177
+ q = q.transpose(1, 2) # [B, n_head, T, head_dim]
178
+ k = k.transpose(1, 2) # [B, n_head, T, head_dim]
179
+ v = v.transpose(1, 2) # [B, n_head, T, head_dim]
180
+
181
+ # Compute attention scores
182
+ attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
183
+
184
+ # Apply causal mask
185
+ if attention_mask is None:
186
+ causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
187
+ attn_scores.masked_fill_(causal_mask, float('-inf'))
188
+ else:
189
+ attn_scores = attn_scores + attention_mask
190
+
191
+ # Apply softmax
192
+ attn_weights = F.softmax(attn_scores, dim=-1)
193
+ attn_weights = self.dropout(attn_weights)
194
+
195
+ # Apply attention to values
196
+ out = torch.matmul(attn_weights, v) # [B, n_head, T, head_dim]
197
+ out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd)
198
+
199
+ # Output projection
200
+ out = self.out_proj(out)
201
+
202
+ return out
203
+
204
+
205
+ class MoEExpert(nn.Module):
206
+ """Expert network for Mixture of Experts"""
207
+
208
+ def __init__(self, config: DeepSeekConfig):
209
+ super().__init__()
210
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
211
+ self.gelu = nn.GELU()
212
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
213
+ self.dropout = nn.Dropout(config.dropout)
214
+
215
+ def forward(self, x: torch.Tensor):
216
+ return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
217
+
218
+
219
+ class MixtureOfExperts(nn.Module):
+     """Mixture of Experts (MoE) for increased model capacity"""
+ 
+     def __init__(self, config: DeepSeekConfig):
+         super().__init__()
+         self.config = config
+         self.num_experts = config.moe_num_experts
+         self.top_k = config.moe_top_k
+         self.expert_capacity = config.moe_expert_capacity
+ 
+         # Router
+         self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)
+ 
+         # Experts
+         self.experts = nn.ModuleList([MoEExpert(config) for _ in range(config.moe_num_experts)])
+ 
+         # Layer norm
+         self.ln = nn.LayerNorm(config.n_embd, bias=config.bias)
+ 
+     def forward(self, x: torch.Tensor):
+         batch_size, seq_len, hidden_dim = x.shape
+ 
+         # Get router logits
+         router_logits = self.router(x)  # [B, T, num_experts]
+ 
+         # Get top-k experts
+         top_k_logits, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)
+         top_k_probs = F.softmax(top_k_logits, dim=-1)
+ 
+         # Initialize output
+         output = torch.zeros_like(x)
+ 
+         # Process each expert
+         for expert_idx in range(self.num_experts):
+             # Find tokens that use this expert
+             expert_mask = (top_k_indices == expert_idx).any(dim=-1)  # [B, T]
+ 
+             if expert_mask.any():
+                 # Get tokens for this expert
+                 expert_tokens = x[expert_mask]  # [num_tokens, hidden_dim]
+ 
+                 # Get routing weights for this expert
+                 expert_weights = top_k_probs[expert_mask]  # [num_tokens, top_k]
+                 expert_weights = expert_weights[top_k_indices[expert_mask] == expert_idx]  # [num_tokens]
+ 
+                 # Apply expert
+                 expert_output = self.experts[expert_idx](expert_tokens)  # [num_tokens, hidden_dim]
+ 
+                 # Weight the output
+                 weighted_output = expert_output * expert_weights.unsqueeze(-1)
+ 
+                 # Add to output
+                 output[expert_mask] += weighted_output
+ 
+         # Apply layer norm
+         output = self.ln(output)
+ 
+         return output, router_logits
+ 
+     def _compute_aux_loss(self, router_logits: torch.Tensor):
+         """Compute auxiliary loss for load balancing"""
+         router_probs = F.softmax(router_logits, dim=-1)
+         mean_expert_usage = router_probs.mean(dim=[0, 1])  # [num_experts]
+         target_usage = 1.0 / self.num_experts
+ 
+         aux_loss = torch.sum((mean_expert_usage - target_usage) ** 2)
+         return aux_loss
+ 
+ 
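The routing math above is easier to see on toy tensors. Below is a minimal standalone sketch (not repository code; it assumes only `torch`) of top-k expert selection and the load-balancing auxiliary loss:

```python
# Toy sketch of top-k routing and the load-balancing auxiliary loss above.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, H, E, K = 2, 5, 8, 4, 2             # batch, seq len, hidden dim, experts, top-k

x = torch.randn(B, T, H)
router = torch.nn.Linear(H, E, bias=False)

router_logits = router(x)                  # [B, T, E]
top_k_logits, top_k_indices = torch.topk(router_logits, K, dim=-1)
top_k_probs = F.softmax(top_k_logits, dim=-1)   # weights renormalized over the chosen experts
print(top_k_indices.shape)                 # torch.Size([2, 5, 2])

# Auxiliary loss: squared deviation of mean expert usage from the uniform target
router_probs = F.softmax(router_logits, dim=-1)
mean_usage = router_probs.mean(dim=[0, 1])       # [E]
aux_loss = torch.sum((mean_usage - 1.0 / E) ** 2)
print(aux_loss.item())                     # near 0 when routing is balanced
```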
+ class DeepSeekBlock(nn.Module):
+     """DeepSeek transformer block with MLA and MoE"""
+ 
+     def __init__(self, config: DeepSeekConfig):
+         super().__init__()
+         self.config = config
+ 
+         # Layer norms
+         self.ln1 = nn.LayerNorm(config.n_embd, bias=config.bias)
+         self.ln2 = nn.LayerNorm(config.n_embd, bias=config.bias)
+ 
+         # Attention - use MLA if enabled, otherwise use standard attention
+         if config.use_mla:
+             self.attn = MultiheadLatentAttention(config)
+         else:
+             # Standard multihead attention as fallback
+             self.attn = nn.MultiheadAttention(
+                 config.n_embd,
+                 config.n_head,
+                 dropout=config.dropout,
+                 bias=config.bias,
+                 batch_first=True
+             )
+ 
+         # MoE
+         self.moe = MixtureOfExperts(config)
+ 
+     def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
+         # Attention with residual connection
+         if self.config.use_mla:
+             x = x + self.attn(self.ln1(x), attention_mask)
+         else:
+             attn_out, _ = self.attn(self.ln1(x), self.ln1(x), self.ln1(x), attn_mask=attention_mask)
+             x = x + attn_out
+ 
+         # MoE with residual connection
+         moe_output, router_logits = self.moe(self.ln2(x))
+         x = x + moe_output
+ 
+         return x, router_logits
+ 
+ 
+ class MultiTokenPredictor(nn.Module):
+     """Multi-token prediction head for improved training efficiency"""
+ 
+     def __init__(self, config: DeepSeekConfig):
+         super().__init__()
+         self.config = config
+         self.num_tokens = config.multi_token_predict
+ 
+         # Separate prediction heads for each future token
+         self.predictors = nn.ModuleList([
+             nn.Linear(config.n_embd, config.vocab_size, bias=False)
+             for _ in range(config.multi_token_predict)
+         ])
+ 
+     def forward(self, hidden_states: torch.Tensor):
+         """Forward pass for multi-token prediction"""
+         batch_size, seq_len, hidden_dim = hidden_states.shape
+ 
+         # Predict multiple future tokens
+         logits = []
+         for i, predictor in enumerate(self.predictors):
+             # Use hidden states shifted by i+1 positions
+             if i + 1 < seq_len:
+                 token_logits = predictor(hidden_states[:, i+1:i+2, :])  # [B, 1, vocab_size]
+                 logits.append(token_logits)
+             else:
+                 # Pad with zeros if not enough sequence length
+                 token_logits = torch.zeros(batch_size, 1, self.config.vocab_size,
+                                            device=hidden_states.device)
+                 logits.append(token_logits)
+ 
+         return torch.cat(logits, dim=1)  # [B, num_tokens, vocab_size]
+ 
+ 
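As a shape check, here is a standalone toy sketch (not repository code) of the per-head slicing used above: each head maps one slice of the hidden states to vocabulary logits, and the slices are concatenated along the token axis:

```python
# Shape sketch of per-position prediction heads, mirroring MultiTokenPredictor.
import torch
import torch.nn as nn

B, T, H, V, N = 2, 16, 32, 100, 2         # toy sizes; N future tokens
hidden = torch.randn(B, T, H)
heads = nn.ModuleList([nn.Linear(H, V, bias=False) for _ in range(N)])

logits = torch.cat(
    [head(hidden[:, i + 1:i + 2, :]) for i, head in enumerate(heads)],
    dim=1,
)
print(logits.shape)                       # torch.Size([2, 2, 100])
```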
+ class DeepSeek(nn.Module):
+     """DeepSeek model for children's story generation"""
+ 
+     def __init__(self, config: DeepSeekConfig):
+         super().__init__()
+         assert isinstance(config, DeepSeekConfig), "config must be an instance of DeepSeekConfig"
+         self.config = config
+ 
+         # Token and position embeddings
+         self.transformer = nn.ModuleDict(dict(
+             wte=nn.Embedding(config.vocab_size, config.n_embd),
+             wpe=nn.Embedding(config.block_size, config.n_embd),
+             drop=nn.Dropout(config.dropout),
+             h=nn.ModuleList([DeepSeekBlock(config) for _ in range(config.n_layer)]),
+             ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
+         ))
+ 
+         # Language model head
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ 
+         # Multi-token predictor
+         if config.multi_token_predict > 0:
+             self.multi_token_predictor = MultiTokenPredictor(config)
+         else:
+             self.multi_token_predictor = None
+ 
+         # Weight tying
+         self.transformer.wte.weight = self.lm_head.weight
+ 
+         # Initialize weights
+         self.apply(self._init_weights)
+ 
+         # Setup quantization if enabled
+         if config.use_quantization:
+             self._setup_quantization()
+ 
+     def _init_weights(self, module):
+         """Initialize model weights"""
+         if isinstance(module, nn.Linear):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=0.02)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.ones_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+ 
+     def _setup_quantization(self):
+         """Setup quantization for the model"""
+         # This would implement quantization logic
+         # For now, just a placeholder
+         pass
+ 
+     def forward(self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None):
+         """Forward pass"""
+         device = input_ids.device
+         batch_size, seq_len = input_ids.size()
+         assert seq_len <= self.config.block_size
+ 
+         # Position indices
+         pos = torch.arange(0, seq_len, dtype=torch.long, device=device)
+ 
+         # Token and position embeddings
+         tok_emb = self.transformer.wte(input_ids)
+         pos_emb = self.transformer.wpe(pos)
+ 
+         x = self.transformer.drop(tok_emb + pos_emb)
+ 
+         # Forward through transformer blocks
+         router_logits_list = []
+         for block in self.transformer.h:
+             x, router_logits = block(x)
+             router_logits_list.append(router_logits)
+ 
+         # Final layer norm
+         x = self.transformer.ln_f(x)
+ 
+         if targets is not None:
+             # Training mode
+             if self.multi_token_predictor is not None:
+                 # Multi-token prediction
+                 multi_logits = self.multi_token_predictor(x)
+                 loss = self._compute_multi_token_loss(multi_logits, targets)
+                 logits = multi_logits
+             else:
+                 # Standard single-token prediction
+                 logits = self.lm_head(x)
+                 loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
+                                        targets.view(-1), ignore_index=-1)
+ 
+             # Add MoE auxiliary loss
+             if router_logits_list:
+                 aux_loss = sum(self.transformer.h[i].moe._compute_aux_loss(router_logits_list[i])
+                                for i in range(len(router_logits_list)))
+                 loss += self.config.moe_aux_loss_coeff * aux_loss
+ 
+             return logits, loss
+         else:
+             # Inference mode
+             logits = self.lm_head(x[:, [-1], :])
+             return logits, None
+ 
+     def _compute_multi_token_loss(self, logits: torch.Tensor, targets: torch.Tensor):
+         """Compute loss for multi-token prediction"""
+         batch_size, num_tokens, vocab_size = logits.shape
+ 
+         # The predictor emits logits for num_tokens positions only, so align
+         # the targets to the same number of positions before flattening
+         # (flattening the full [B, T] targets would not match the logits).
+         targets = targets[:, :num_tokens]
+ 
+         # Reshape for loss computation
+         logits_flat = logits.reshape(-1, vocab_size)
+         targets_flat = targets.reshape(-1)
+ 
+         # Compute cross-entropy loss
+         loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=-1)
+ 
+         return loss
+ 
+     @torch.no_grad()
+     def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
+                  temperature: float = 1.0, top_k: Optional[int] = None):
+         """Generate text using the model"""
+         for _ in range(max_new_tokens):
+             # Ensure input doesn't exceed block size
+             idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
+ 
+             # Forward pass
+             logits, _ = self(idx_cond)
+             logits = logits[:, -1, :] / temperature
+ 
+             # Apply top-k filtering
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float('Inf')
+ 
+             # Sample next token
+             probs = F.softmax(logits, dim=-1)
+             idx_next = torch.multinomial(probs, num_samples=1)
+             input_ids = torch.cat((input_ids, idx_next), dim=1)
+ 
+         return input_ids
+ 
+     @classmethod
+     def from_pretrained(cls, model_type: str, override_args: Optional[dict] = None):
+         """Load a pretrained model"""
+         # This would implement loading from pretrained weights
+         # For now, return a default configuration
+         config = DeepSeekConfig()
+         if override_args:
+             for key, value in override_args.items():
+                 setattr(config, key, value)
+         return cls(config)
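A minimal usage sketch for the model as a whole, assuming the file is importable as `model.deepseek` and that the `DeepSeekConfig` defaults match the README (6 layers, 8 heads, 512-dim embeddings):

```python
import torch
from model.deepseek import DeepSeek, DeepSeekConfig

config = DeepSeekConfig()                 # assumed defaults: 6L / 8H / 512D
model = DeepSeek(config).eval()

prompt = torch.randint(0, config.vocab_size, (1, 8))   # stand-in for GPT-2 token ids
with torch.no_grad():
    logits, _ = model(prompt)             # inference path: logits for the last position
    out = model.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=50)
print(logits.shape, out.shape)            # [1, 1, 50257] and [1, 28]
```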
src/run_training.py ADDED
@@ -0,0 +1,307 @@
+ """
+ DeepSeek Children's Stories Training Script
+ Main training script for the DeepSeek model on children's stories
+ """
+ 
+ import os
+ import sys
+ import argparse
+ import torch
+ from dataclasses import dataclass
+ from typing import Optional
+ 
+ # Add the repository root (the parent of src/) to the Python path
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+ 
+ from model.deepseek import DeepSeek, DeepSeekConfig
+ from training.trainer import DeepSeekTrainer, create_deepseek_trainer
+ from data.data_processor import DeepSeekDataProcessor
+ 
+ 
+ @dataclass
+ class TrainingConfig:
+     """Configuration for DeepSeek training"""
+     # Model configuration
+     vocab_size: int = 50257
+     n_layer: int = 6
+     n_head: int = 8
+     n_embd: int = 512
+     block_size: int = 1024
+     dropout: float = 0.1
+     bias: bool = True
+ 
+     # MLA configuration
+     use_mla: bool = True
+     mla_kv_heads: int = 4
+     mla_q_lora_rank: int = 32
+     mla_kv_lora_rank: int = 16
+ 
+     # MoE configuration
+     moe_num_experts: int = 4
+     moe_top_k: int = 2
+     moe_expert_capacity: float = 1.25
+     moe_aux_loss_coeff: float = 0.01
+ 
+     # Multi-token prediction
+     multi_token_predict: int = 0  # number of future tokens to predict; 0 disables it (the CLI default is 2)
+ 
+     # Quantization
+     use_quantization: bool = False
+     quantization_bits: int = 8
+ 
+     # Training configuration
+     batch_size: int = 12
+     max_iters: int = 20000
+     eval_interval: int = 1000
+     eval_iters: int = 200
+     learning_rate: float = 6e-4
+     weight_decay: float = 0.1
+     warmup_iters: int = 2000
+     lr_decay_iters: int = 20000
+     min_lr: float = 6e-5
+ 
+     # System configuration
+     checkpoint_dir: str = 'checkpoints'
+     use_mixed_precision: bool = True
+     compile_model: bool = True
+ 
+     # Data configuration
+     dataset_name: str = "ajibawa-2023/Children-Stories-Collection"
+     data_dir: str = 'src/data'
+ 
+ 
+ def setup_environment():
+     """Setup the training environment"""
+     print("Setting up DeepSeek Children's Stories training environment...")
+ 
+     # Check CUDA availability
+     if torch.cuda.is_available():
+         print(f"CUDA available: {torch.cuda.get_device_name(0)}")
+         print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
+     else:
+         print("CUDA not available, using CPU")
+ 
+     # Create necessary directories
+     os.makedirs('checkpoints', exist_ok=True)
+     os.makedirs('lora_checkpoints', exist_ok=True)
+     os.makedirs('src/data', exist_ok=True)
+ 
+     print("Environment setup complete!")
+ 
+ 
+ def prepare_data():
+     """Prepare the dataset for training"""
+     print("Preparing dataset...")
+ 
+     processor = DeepSeekDataProcessor()
+     data_files = processor.prepare_dataset()
+ 
+     print("Dataset preparation complete!")
+     return data_files
+ 
+ 
+ def create_model(config: TrainingConfig) -> DeepSeek:
+     """Create the DeepSeek model"""
+     print("Creating DeepSeek model...")
+ 
+     # Create model configuration
+     model_config = DeepSeekConfig(
+         vocab_size=config.vocab_size,
+         n_layer=config.n_layer,
+         n_head=config.n_head,
+         n_embd=config.n_embd,
+         block_size=config.block_size,
+         dropout=config.dropout,
+         bias=config.bias,
+         use_mla=config.use_mla,
+         mla_kv_heads=config.mla_kv_heads,
+         mla_q_lora_rank=config.mla_q_lora_rank,
+         mla_kv_lora_rank=config.mla_kv_lora_rank,
+         moe_num_experts=config.moe_num_experts,
+         moe_top_k=config.moe_top_k,
+         moe_expert_capacity=config.moe_expert_capacity,
+         moe_aux_loss_coeff=config.moe_aux_loss_coeff,
+         multi_token_predict=config.multi_token_predict,
+         use_quantization=config.use_quantization,
+         quantization_bits=config.quantization_bits
+     )
+ 
+     # Create model
+     model = DeepSeek(model_config)
+ 
+     # Compile model if requested
+     if config.compile_model and hasattr(torch, 'compile'):
+         print("Compiling model with torch.compile...")
+         model = torch.compile(model)
+ 
+     # Print model info
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ 
+     print("Model created successfully!")
+     print(f"Total parameters: {total_params:,}")
+     print(f"Trainable parameters: {trainable_params:,}")
+     print("Model configuration:")
+     print(f"  - Layers: {config.n_layer}")
+     print(f"  - Heads: {config.n_head}")
+     print(f"  - Embedding dim: {config.n_embd}")
+     print(f"  - MLA enabled: {config.use_mla}")
+     print(f"  - MLA KV heads: {config.mla_kv_heads}")
+     print(f"  - MoE experts: {config.moe_num_experts}")
+     print(f"  - Multi-token prediction: {config.multi_token_predict}")
+ 
+     return model
+ 
+ 
+ def train_model(model: DeepSeek, config: TrainingConfig):
+     """Train the DeepSeek model"""
+     print("[+] Starting training with config:")
+     print(f"  - Model size: {sum(p.numel() for p in model.parameters()):,} parameters")
+     print(f"  - Multi-token prediction: {config.multi_token_predict}")
+     print(f"  - MoE experts: {config.moe_num_experts}")
+     print(f"  - MLA enabled: {config.use_mla}")
+ 
+     # Setup device
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     model = model.to(device)
+ 
+     # Create optimizer
+     optimizer = torch.optim.AdamW(
+         model.parameters(),
+         lr=config.learning_rate,
+         weight_decay=config.weight_decay,
+         betas=(0.9, 0.95)
+     )
+ 
+     # Initialize trainer with individual parameters
+     trainer = DeepSeekTrainer(
+         model=model,
+         optimizer=optimizer,
+         device=device,
+         batch_size=config.batch_size,
+         max_iters=config.max_iters,
+         eval_interval=config.eval_interval,
+         eval_iters=config.eval_iters,
+         learning_rate=config.learning_rate,
+         weight_decay=config.weight_decay,
+         warmup_iters=config.warmup_iters,
+         lr_decay_iters=config.lr_decay_iters,
+         min_lr=config.min_lr,
+         checkpoint_dir=config.checkpoint_dir,
+         use_mixed_precision=config.use_mixed_precision
+     )
+ 
+     try:
+         # Start training
+         trainer.train()
+         print("[+] Training completed successfully!")
+ 
+         # Save final model
+         final_model_path = os.path.join(config.checkpoint_dir, "final_model.pt")
+         torch.save({
+             'model_state_dict': model.state_dict(),
+             'config': config,
+             'optimizer_state_dict': trainer.optimizer.state_dict(),
+         }, final_model_path)
+         print(f"[+] Final model saved to {final_model_path}")
+ 
+     except Exception as e:
+         print(f"[-] Training failed: {e}")
+         import traceback
+         traceback.print_exc()
+         raise
+ 
+ 
+ def main():
+     """Main training function"""
+     parser = argparse.ArgumentParser(description='Train DeepSeek model on children\'s stories')
+ 
+     # Model configuration
+     parser.add_argument('--n-layer', type=int, default=6, help='Number of layers')
+     parser.add_argument('--n-head', type=int, default=8, help='Number of attention heads')
+     parser.add_argument('--n-embd', type=int, default=512, help='Embedding dimension')
+     parser.add_argument('--block-size', type=int, default=1024, help='Context window size')
+ 
+     # Training configuration
+     parser.add_argument('--batch-size', type=int, default=12, help='Batch size')
+     parser.add_argument('--max-iters', type=int, default=20000, help='Maximum iterations')
+     parser.add_argument('--learning-rate', type=float, default=6e-4, help='Learning rate')
+     parser.add_argument('--eval-interval', type=int, default=1000, help='Evaluation interval')
+     parser.add_argument('--eval-iters', type=int, default=200, help='Number of evaluation iterations')
+     parser.add_argument('--weight-decay', type=float, default=0.1, help='Weight decay')
+     parser.add_argument('--warmup-iters', type=int, default=2000, help='Warmup iterations')
+     parser.add_argument('--lr-decay-iters', type=int, default=20000, help='Learning rate decay iterations')
+     parser.add_argument('--min-lr', type=float, default=6e-5, help='Minimum learning rate')
+ 
+     # Advanced features
+     parser.add_argument('--moe-experts', type=int, default=4, help='Number of MoE experts')
+     parser.add_argument('--multi-token', type=int, default=2, help='Number of future tokens to predict')
+     parser.add_argument('--no-compile', action='store_true', help='Disable model compilation')
+     parser.add_argument('--no-mixed-precision', action='store_true', help='Disable mixed precision')
+ 
+     # Resume training
+     parser.add_argument('--resume', type=str, help='Resume from checkpoint')
+ 
+     args = parser.parse_args()
+ 
+     # Create configuration
+     config = TrainingConfig(
+         n_layer=args.n_layer,
+         n_head=args.n_head,
+         n_embd=args.n_embd,
+         block_size=args.block_size,
+         batch_size=args.batch_size,
+         max_iters=args.max_iters,
+         learning_rate=args.learning_rate,
+         eval_interval=args.eval_interval,
+         eval_iters=args.eval_iters,
+         weight_decay=args.weight_decay,
+         warmup_iters=args.warmup_iters,
+         lr_decay_iters=args.lr_decay_iters,
+         min_lr=args.min_lr,
+         moe_num_experts=args.moe_experts,
+         multi_token_predict=args.multi_token,
+         compile_model=not args.no_compile,
+         use_mixed_precision=not args.no_mixed_precision
+     )
+ 
+     print("DeepSeek Children's Stories Training")
+     print("=" * 50)
+     print("Configuration:")
+     print(f"  - Model: {config.n_layer}L/{config.n_head}H/{config.n_embd}D")
+     print(f"  - MoE: {config.moe_num_experts} experts")
+     print(f"  - Multi-token: {config.multi_token_predict}")
+     print(f"  - Batch size: {config.batch_size}")
+     print(f"  - Max iterations: {config.max_iters}")
+     print(f"  - Learning rate: {config.learning_rate}")
+     print(f"  - Weight decay: {config.weight_decay}")
+     print(f"  - Warmup iterations: {config.warmup_iters}")
+     print(f"  - LR decay iterations: {config.lr_decay_iters}")
+     print(f"  - Min learning rate: {config.min_lr}")
+     print("=" * 50)
+ 
+     # Setup environment
+     setup_environment()
+ 
+     # Prepare data
+     data_files = prepare_data()
+ 
+     # Create model
+     model = create_model(config)
+ 
+     # Resume from checkpoint if specified
+     if args.resume:
+         print(f"Resuming from checkpoint: {args.resume}")
+         checkpoint = torch.load(args.resume, map_location='cpu')
+         model.load_state_dict(checkpoint['model'])
+         print("Checkpoint loaded successfully!")
+ 
+     # Train model
+     train_model(model, config)
+ 
+     print("Training completed successfully!")
+     print("Best model saved to: checkpoints/best_model.pt")
+ 
+ 
+ if __name__ == "__main__":
+     main()
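For reference, the same run can be assembled programmatically instead of through the CLI; a sketch mirroring `main()`, with illustrative argument values:

```python
# Programmatic equivalent of the CLI flags above (illustrative values).
from run_training import TrainingConfig, setup_environment, prepare_data, create_model, train_model

config = TrainingConfig(batch_size=8, max_iters=5000, multi_token_predict=2)
setup_environment()
prepare_data()
model = create_model(config)
train_model(model, config)
```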
src/training/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (10.8 kB)
 
src/training/trainer.py ADDED
@@ -0,0 +1,408 @@
+ """
+ DeepSeek Trainer for Children's Stories
+ Advanced training with MLA, MoE, and multi-token prediction
+ """
+ 
+ import torch
+ import numpy as np
+ from tqdm.auto import tqdm
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR
+ import matplotlib.pyplot as plt
+ import os
+ import datetime
+ import time
+ import shutil
+ import psutil
+ import math
+ import gc
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ from torch.distributed import init_process_group, destroy_process_group
+ from typing import Dict, List, Optional, Tuple
+ 
+ class DeepSeekTrainer:
+     def __init__(self, model, optimizer, device, batch_size, max_iters, eval_interval,
+                  eval_iters, learning_rate, weight_decay, warmup_iters, lr_decay_iters,
+                  min_lr, checkpoint_dir='checkpoints', use_mixed_precision=True):
+         self.model = model
+         self.optimizer = optimizer
+         self.device = device
+         self.batch_size = batch_size
+         self.max_iters = max_iters
+         self.eval_interval = eval_interval
+         self.eval_iters = eval_iters
+         self.learning_rate = learning_rate
+         self.weight_decay = weight_decay
+         self.warmup_iters = warmup_iters
+         self.lr_decay_iters = lr_decay_iters
+         self.min_lr = min_lr
+         self.checkpoint_dir = checkpoint_dir
+         self.use_mixed_precision = use_mixed_precision
+         self.best_loss = float('inf')
+ 
+         # Training state
+         self.current_iter = 0
+         self.train_losses = []
+         self.val_losses = []
+         self.learning_rates = []
+ 
+         # Create checkpoint directory if it doesn't exist
+         os.makedirs(checkpoint_dir, exist_ok=True)
+ 
+         # Initialize gradient scaler for mixed precision training
+         if use_mixed_precision and device == 'cuda':
+             self.scaler = torch.cuda.amp.GradScaler()
+         else:
+             self.scaler = None
+ 
+         # Initialize training metrics
+         self.metrics = {
+             'train_loss': [],
+             'val_loss': [],
+             'learning_rates': [],
+             'grad_norm': [],
+             'memory_usage': [],
+             'moe_aux_loss': [],
+             'multi_token_loss': []
+         }
+ 
+         # Load data
+         self.data = self.load_data()
+         self.n = len(self.data)
+ 
+     def load_data(self):
+         """Load the training data"""
+         try:
+             data_file = os.path.join('src', 'data', 'train.bin')
+             if not os.path.exists(data_file):
+                 raise FileNotFoundError(f"Training data file not found at {data_file}")
+ 
+             # Load data as numpy array first
+             data = np.memmap(data_file, dtype=np.uint16, mode='r')
+             # Convert to tensor
+             data = torch.from_numpy(data.copy())  # Make a copy to ensure it's writable
+             return data
+         except Exception as e:
+             print(f"Error loading data: {str(e)}")
+             raise
+ 
+     def get_batch(self, split):
+         """Get a batch of data"""
+         # NOTE: only train.bin is loaded, so the 'split' argument is currently
+         # ignored and 'val' batches are drawn from the same training data.
+         try:
+             # Generate random indices
+             ix = torch.randint(len(self.data) - self.model.config.block_size, (self.batch_size,))
+ 
+             # Get input sequences
+             x = torch.stack([self.data[i:i+self.model.config.block_size].long() for i in ix])
+             # Get target sequences (shifted by 1)
+             y = torch.stack([self.data[i+1:i+1+self.model.config.block_size].long() for i in ix])
+ 
+             # Move to device
+             x, y = x.to(self.device), y.to(self.device)
+             return x, y
+         except Exception as e:
+             print(f"Error in get_batch: {str(e)}")
+             raise
+ 
+     def get_lr(self, it):
+         """Get learning rate for current iteration"""
+         # 1) linear warmup for warmup_iters steps
+         if it < self.warmup_iters:
+             return self.learning_rate * it / self.warmup_iters
+         # 2) if it > lr_decay_iters, return min learning rate
+         if it > self.lr_decay_iters:
+             return self.min_lr
+         # 3) in between, use cosine decay down to min learning rate
+         decay_ratio = (it - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters)
+         assert 0 <= decay_ratio <= 1
+         coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
+         return self.min_lr + coeff * (self.learning_rate - self.min_lr)
+ 
+     def estimate_loss(self):
+         """Estimate average loss on train and validation batches"""
+         out = {}
+         self.model.eval()
+         for split in ['train', 'val']:
+             losses = torch.zeros(self.eval_iters)
+             for k in range(self.eval_iters):
+                 try:
+                     X, Y = self.get_batch(split)
+                     with torch.no_grad():
+                         if self.scaler is not None:
+                             with torch.cuda.amp.autocast():
+                                 logits, loss = self.model(X, Y)
+                         else:
+                             logits, loss = self.model(X, Y)
+                     losses[k] = loss.item()
+                 except Exception as e:
+                     print(f"Error during evaluation: {str(e)}")
+                     continue
+             out[split] = losses.mean()
+         self.model.train()
+         return out
+ 
+     def check_disk_space(self, required_space_mb=1000):
+         """Check if there's enough disk space for saving the model"""
+         try:
+             # Get disk usage statistics
+             disk_usage = psutil.disk_usage('/')
+             free_space_mb = disk_usage.free / (1024 * 1024)  # Convert to MB
+ 
+             if free_space_mb < required_space_mb:
+                 print(f"Warning: Low disk space. Only {free_space_mb:.2f}MB free, {required_space_mb}MB required")
+                 return False
+             return True
+         except Exception as e:
+             print(f"Warning: Could not check disk space: {e}")
+             return True  # Continue anyway if we can't check
+ 
+     def save_checkpoint(self, iter_num, loss, is_best=False):
+         """Save model checkpoint"""
+         try:
+             checkpoint = {
+                 'model': self.model.state_dict(),
+                 'optimizer': self.optimizer.state_dict(),
+                 'iter_num': iter_num,
+                 'loss': loss,
+                 'config': self.model.config,
+                 'train_losses': self.train_losses,
+                 'val_losses': self.val_losses,
+                 'learning_rates': self.learning_rates,
+                 'metrics': self.metrics,
+                 'best_loss': self.best_loss
+             }
+             checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_{iter_num}.pt')
+             torch.save(checkpoint, checkpoint_path)
+ 
+             if is_best:
+                 best_path = os.path.join(self.checkpoint_dir, 'best_model.pt')
+                 torch.save(checkpoint, best_path)
+                 print(f"Saved best model with loss {loss:.4f}")
+ 
+             print(f"Saved checkpoint to {checkpoint_path}")
+             return True
+         except Exception as e:
+             print(f"Error saving checkpoint: {str(e)}")
+             return False
+ 
+     def load_checkpoint(self, checkpoint_path):
+         """Load model checkpoint with error handling"""
+         try:
+             checkpoint = torch.load(checkpoint_path, map_location=self.device)
+             self.model.load_state_dict(checkpoint['model'])
+             self.optimizer.load_state_dict(checkpoint['optimizer'])
+             self.current_iter = checkpoint['iter_num']
+             self.best_loss = checkpoint['loss']
+             self.train_losses = checkpoint.get('train_losses', [])
+             self.val_losses = checkpoint.get('val_losses', [])
+             self.learning_rates = checkpoint.get('learning_rates', [])
+             self.metrics = checkpoint.get('metrics', self.metrics)
+             print(f"Successfully loaded checkpoint from iteration {self.current_iter}")
+             return True
+         except Exception as e:
+             print(f"Error loading checkpoint: {e}")
+             return False
+ 
+     def train(self):
+         """Train the DeepSeek model"""
+         print(f"DeepSeek Training started at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+         print(f"Model: {self.model.config.n_layer} layers, {self.model.config.n_head} heads, {self.model.config.n_embd} dims")
+         print(f"MLA: {self.model.config.mla_kv_heads} KV heads, MoE: {self.model.config.moe_num_experts} experts")
+         print(f"Multi-token prediction: {self.model.config.multi_token_predict} tokens")
+         start_time = time.time()
+         current_loss = None  # defined before the try so the emergency-save path can always see it
+ 
+         try:
+             # Initialize training
+             X, Y = self.get_batch('train')
+             best_loss = float('inf')
+ 
+             for iter_num in range(self.current_iter, self.max_iters):
+                 self.current_iter = iter_num
+ 
+                 # Determine and set the learning rate for this iteration
+                 lr = self.get_lr(iter_num)
+                 for param_group in self.optimizer.param_groups:
+                     param_group['lr'] = lr
+ 
+                 # Forward pass with mixed precision
+                 if self.scaler is not None:
+                     with torch.cuda.amp.autocast():
+                         logits, loss = self.model(X, Y)
+                 else:
+                     logits, loss = self.model(X, Y)
+ 
+                 # Backward pass
+                 if self.scaler is not None:
+                     self.scaler.scale(loss).backward()
+                     self.scaler.step(self.optimizer)
+                     self.scaler.update()
+                 else:
+                     loss.backward()
+                     self.optimizer.step()
+ 
+                 self.optimizer.zero_grad(set_to_none=True)
+ 
+                 # Get new batch
+                 X, Y = self.get_batch('train')
+ 
+                 # Track metrics
+                 current_loss = loss.item()
+                 self.train_losses.append(current_loss)
+                 self.learning_rates.append(lr)
+ 
+                 # Update best loss
+                 if current_loss < best_loss:
+                     best_loss = current_loss
+ 
+                 # Evaluation
+                 if iter_num % self.eval_interval == 0:
+                     losses = self.estimate_loss()
+                     self.val_losses.append(losses['val'])
+ 
+                     # Save checkpoint if it's the best so far
+                     if losses['val'] < self.best_loss:
+                         self.best_loss = losses['val']
+                         self.save_checkpoint(iter_num, losses['val'], is_best=True)
+ 
+                     # Regular checkpoint saving
+                     if iter_num % (self.eval_interval * 5) == 0:
+                         self.save_checkpoint(iter_num, losses['val'])
+ 
+                     # Print progress
+                     elapsed = time.time() - start_time
+                     print(f"iter {iter_num}: train_loss {current_loss:.4f}, val_loss {losses['val']:.4f}, "
+                           f"lr {lr:.2e}, time {elapsed:.2f}s")
+ 
+                     # Memory usage
+                     if self.device == 'cuda':
+                         memory_used = torch.cuda.memory_allocated() / 1024**3
+                         print(f"GPU memory: {memory_used:.2f} GB")
+ 
+                 # Memory cleanup
+                 if iter_num % 100 == 0:
+                     gc.collect()
+                     if self.device == 'cuda':
+                         torch.cuda.empty_cache()
+ 
+             # Final checkpoint
+             self.save_checkpoint(self.max_iters, current_loss)
+ 
+             # Plot training metrics
+             self.plot_metrics()
+ 
+             print(f"Training completed in {time.time() - start_time:.2f} seconds")
+ 
+         except Exception as e:
+             print(f"Error during training: {str(e)}")
+             # Save emergency checkpoint
+             if current_loss is not None:
+                 self.save_checkpoint(self.current_iter, current_loss)
+             raise
+ 
+     def plot_losses(self, train_losses, val_losses):
+         """Plot training and validation losses"""
+         plt.figure(figsize=(12, 4))
+ 
+         plt.subplot(1, 2, 1)
+         plt.plot(train_losses, label='Training Loss')
+         plt.plot(val_losses, label='Validation Loss')
+         plt.title('Training and Validation Loss')
+         plt.xlabel('Iteration')
+         plt.ylabel('Loss')
+         plt.legend()
+         plt.grid(True)
+ 
+         plt.subplot(1, 2, 2)
+         plt.plot(self.learning_rates)
+         plt.title('Learning Rate Schedule')
+         plt.xlabel('Iteration')
+         plt.ylabel('Learning Rate')
+         plt.grid(True)
+ 
+         plt.tight_layout()
+         plt.savefig('training_metrics.png', dpi=300, bbox_inches='tight')
+         plt.close()
+ 
+     def plot_metrics(self):
+         """Plot comprehensive training metrics"""
+         if not self.train_losses or not self.val_losses:
+             print("No metrics to plot")
+             return
+ 
+         fig, axes = plt.subplots(2, 2, figsize=(15, 10))
+ 
+         # Training and validation loss
+         axes[0, 0].plot(self.train_losses, label='Training Loss', alpha=0.7)
+         axes[0, 0].plot(self.val_losses, label='Validation Loss', alpha=0.7)
+         axes[0, 0].set_title('Training and Validation Loss')
+         axes[0, 0].set_xlabel('Iteration')
+         axes[0, 0].set_ylabel('Loss')
+         axes[0, 0].legend()
+         axes[0, 0].grid(True)
+ 
+         # Learning rate
+         axes[0, 1].plot(self.learning_rates)
+         axes[0, 1].set_title('Learning Rate Schedule')
+         axes[0, 1].set_xlabel('Iteration')
+         axes[0, 1].set_ylabel('Learning Rate')
+         axes[0, 1].grid(True)
+ 
+         # Memory usage
+         if self.metrics['memory_usage']:
+             axes[1, 0].plot(self.metrics['memory_usage'])
+             axes[1, 0].set_title('GPU Memory Usage')
+             axes[1, 0].set_xlabel('Iteration')
+             axes[1, 0].set_ylabel('Memory (GB)')
+             axes[1, 0].grid(True)
+ 
+         # Gradient norm
+         if self.metrics['grad_norm']:
+             axes[1, 1].plot(self.metrics['grad_norm'])
+             axes[1, 1].set_title('Gradient Norm')
+             axes[1, 1].set_xlabel('Iteration')
+             axes[1, 1].set_ylabel('Norm')
+             axes[1, 1].grid(True)
+ 
+         plt.tight_layout()
+         plt.savefig('deepseek_training_metrics.png', dpi=300, bbox_inches='tight')
+         plt.close()
+ 
+         print("Training metrics saved to deepseek_training_metrics.png")
+ 
+ 
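The schedule implemented by `get_lr()` is linear warmup followed by cosine decay to a floor. A standalone sketch (plain `math`, using the default values from the training config) evaluates it at a few iterations:

```python
import math

def lr_at(it, lr=6e-4, min_lr=6e-5, warmup=2000, decay=20000):
    """Same warmup + cosine schedule as DeepSeekTrainer.get_lr()."""
    if it < warmup:
        return lr * it / warmup
    if it > decay:
        return min_lr
    ratio = (it - warmup) / (decay - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))
    return min_lr + coeff * (lr - min_lr)

for it in (0, 1000, 2000, 11000, 20000):
    print(it, f"{lr_at(it):.2e}")
# 0 -> 0.00e+00, 1000 -> 3.00e-04 (mid-warmup), 2000 -> 6.00e-04 (peak),
# 11000 -> 3.30e-04 (cosine midpoint), 20000 -> 6.00e-05 (floor)
```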
+ def create_deepseek_trainer(model, config):
+     """Create a DeepSeek trainer with the given configuration"""
+     # Optimizer
+     optimizer = torch.optim.AdamW(
+         model.parameters(),
+         lr=config.learning_rate,
+         weight_decay=config.weight_decay,
+         betas=(0.9, 0.95)
+     )
+ 
+     # Device
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     model = model.to(device)
+ 
+     # Trainer
+     trainer = DeepSeekTrainer(
+         model=model,
+         optimizer=optimizer,
+         device=device,
+         batch_size=config.batch_size,
+         max_iters=config.max_iters,
+         eval_interval=config.eval_interval,
+         eval_iters=config.eval_iters,
+         learning_rate=config.learning_rate,
+         weight_decay=config.weight_decay,
+         warmup_iters=config.warmup_iters,
+         lr_decay_iters=config.lr_decay_iters,
+         min_lr=config.min_lr,
+         checkpoint_dir=config.checkpoint_dir,
+         use_mixed_precision=config.use_mixed_precision
+     )
+ 
+     return trainer
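A hedged end-to-end usage sketch for the trainer, assuming `src/data/train.bin` already exists (produced by the data processor) and the import layout used by `run_training.py`:

```python
from model.deepseek import DeepSeek, DeepSeekConfig
from run_training import TrainingConfig
from training.trainer import create_deepseek_trainer

config = TrainingConfig(max_iters=1000, eval_interval=200)   # short smoke run
model = DeepSeek(DeepSeekConfig(block_size=config.block_size))
trainer = create_deepseek_trainer(model, config)
trainer.train()                    # checkpoints land in config.checkpoint_dir
```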
training_metrics.png ADDED

Git LFS Details
  • SHA256: 73f8b42b8ab0aa737ed8a7f4c16b607e28ee36b32de6cdf170ec8bd92281d31b
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB