Shekswess committed · verified · commit 196ed3d · parent 7bbf7f8

Update README.md

Files changed (1): README.md (+111 -68)

README.md CHANGED
@@ -13,103 +13,146 @@ model-index:
  results: []
  ---

- <p align="center">
-   <img src="https://sdmntprnortheu.oaiusercontent.com/files/00000000-f580-61f4-9d8f-e2ad1ad30cb1/raw?se=2025-09-28T13%3A44%3A27Z&sp=r&sv=2024-08-04&sr=b&scid=d18de0ac-b41e-5d89-82aa-2a8c74df25d6&skoid=f28c0102-4d9d-4950-baf0-4a8e5f6cf9d4&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2025-09-27T15%3A59%3A48Z&ske=2025-09-28T15%3A59%3A48Z&sks=b&skv=2024-08-04&sig=CSrmTwUK5za43FjSFhOlkzGlLkqG2CDPpKYkYtSdV6g%3D" alt="TRLm Stage 3 Banner" width="800"/>
- </p>
-
- # 🧠 trlm-stage-3-dpo-final-2
-
- `trlm-stage-3-dpo-final-2` is the **Stage 3** post-training model for the **Tiny Reasoning Language Model (trlm)** project.
- This stage focuses on **preference alignment** using **Direct Preference Optimization (DPO)** with 50k preference pairs.

  ---

- ## 📖 Model Description
-
- - **Base Model**: [Shekswess/trlm-stage-2-sft-final-2](https://huggingface.co/Shekswess/trlm-stage-2-sft-final-2)
- - **Type**: Causal Language Model (decoder-only transformer)
- - **Stage**: Post-training **Stage 3 (DPO)**
- - **Objective**: Align model outputs with human-preferred reasoning and answers by contrasting **chosen** vs **rejected** completions.
-
- This stage improves the model’s **alignment**, **coherence**, and **reasoning stability**.
-
- ---
-
- ## 🎯 Intended Uses & Limitations
-
- ### Intended Uses
- - Aligned reasoning assistant with structured `<think>` traces
- - Multi-turn reasoning with preference-optimized outputs
- - Safer, more useful responses for reasoning tasks
-
- ### Limitations
- - Trained only on preference data → may inherit biases from source datasets
- - Limited parameter count (135M) restricts knowledge breadth
- - Still prone to hallucinations under complex reasoning chains

  ---

- ## 📊 Training Data
-
- This model was trained on the dataset:
- 👉 [**Shekswess/trlm-dpo-stage-3-final-2**](https://huggingface.co/datasets/Shekswess/trlm-dpo-stage-3-final-2)
-
- **Dataset summary**:
- - **Entries**: 50,000 preference pairs
- - **Source**: `scottgeng00/olmo-3-preference-mix-deltas_reasoning-yolo_scottmix-DECON-chfiltered`
- - **Focus**: Preference alignment with **chosen vs rejected responses**
-
- | Source Dataset | Split | Entries | % |
- |----------------|-------|---------|---|
- | scottgeng00/olmo-3-preference-mix-deltas_reasoning-yolo_scottmix-DECON-chfiltered | train | 50,000 | 100% |

  ---

- ## ⚙️ Training Procedure
-
- ### Training Hyperparameters
- - **Learning rate**: 1e-5
- - **Train batch size**: 32
- - **Eval batch size**: 8
- - **Gradient accumulation steps**: 4
- - **Total effective batch size**: 128
- - **Optimizer**: AdamW (betas=(0.9, 0.999), eps=1e-08)
- - **LR Scheduler**: Cosine with minimum LR + warmup ratio 0.1
- - **Epochs**: 1
- - **Seed**: 42
-
- ### Framework Versions
- - **Transformers**: 4.56.2
- - **PyTorch**: 2.7.1+rocm7.0.0.git698b58a9
- - **Datasets**: 4.0.0
- - **Tokenizers**: 0.22.1
-
- ---
-
- ## 🚀 Usage

  ```python
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
- model_name = "Shekswess/trlm-stage-3-dpo-final-2"

  # Load tokenizer & model
  tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name)

- # Example inference with preference-aligned reasoning
  messages = [
-     {"role": "user", "content": "Explain why the sky is blue in simple terms."}
  ]

  # Apply chat template
- text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
- inputs = tokenizer([text], return_tensors="pt")

  outputs = model.generate(**inputs, max_new_tokens=256)
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  ```

- ---

- Part of the Tiny Reasoning Language Model (trlm) post-training pipeline.
 
  results: []
  ---

+ # Tiny Reasoning Language Model (trlm-135)
+
+ ![image/png](https://github.com/user-attachments/assets/5f453496-8180-4cf4-94da-26ebbe1159d4)
+
+ ## Table of Contents
+
+ 1. [Model Summary](#model-summary)
+ 2. [Post-Training Pipeline](#post-training-pipeline)
+ 3. [How to use](#how-to-use)
+ 4. [Training](#training)
+ 5. [Evaluation](#evaluation)
+ 6. [Limitations](#limitations)
+ 7. [Acknowledgements](#acknowledgements)
+ 8. [License](#license)

  ---

+ ## Model Summary
+
+ The **Tiny Reasoning Language Model (trlm-135)** is a **135M-parameter** research prototype designed to study how small models can learn step-by-step reasoning.
+ It was built on top of [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) and fine-tuned through a **3-stage pipeline**:
+
+ * **[Stage 1 SFT](https://huggingface.co/Shekswess/trlm-stage-1-sft-final-2)**: general instruction tuning (non-reasoning).
+ * **[Stage 2 SFT](https://huggingface.co/Shekswess/trlm-stage-2-sft-final-2)**: reasoning traces with `<think>` tags.
+ * **[Stage 3 DPO](https://huggingface.co/Shekswess/trlm-stage-3-dpo-final-2)**: preference alignment for reasoning style.
+
+ The **code** for everything can be found **[here](https://github.com/Shekswess/tiny-reasoning-language-model/blob/main/README.md)**.

  ---

+ ## Post-Training Pipeline
+
+ <img width="1014" height="563" alt="image" src="https://github.com/user-attachments/assets/195ef389-6aa9-4527-b4f0-bea68c0841ae" />

  ---

+ ## How to use
+
+ ```bash
+ pip install -U transformers accelerate
+ ```
  ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "Shekswess/trlm-135m"
+ device = "cuda"  # or "cpu"

  # Load tokenizer & model
  tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+ ).to(device)

+ # Example prompt
+ prompt = "Give me a brief explanation of gravity in simple terms."
  messages = [
+     {"role": "user", "content": prompt}
  ]

  # Apply chat template
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True,
+ )
+
+ inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+ # Generate
  outputs = model.generate(**inputs, max_new_tokens=256)
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  ```

+ > [!TIP]
+ > For reasoning-heavy tasks, set `temperature=0.6` and `top_p=0.95` (see the sampling sketch below).
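+
+ A minimal sketch of those settings, reusing `model`, `tokenizer`, and `inputs` from the snippet above; the `<think>`/`</think>` splitting assumes the Stage 2 trace format described in this card:
+
+ ```python
+ # Sample with the recommended reasoning settings.
+ outputs = model.generate(
+     **inputs,
+     max_new_tokens=256,
+     do_sample=True,
+     temperature=0.6,
+     top_p=0.95,
+ )
+ # Keep only the newly generated tokens, not the prompt.
+ completion = tokenizer.decode(
+     outputs[0][inputs["input_ids"].shape[1]:],
+     skip_special_tokens=True,
+ )
+
+ # Illustrative split of the reasoning trace from the final answer,
+ # assuming the model emits a `<think> ... </think>` block first.
+ think, sep, answer = completion.partition("</think>")
+ if sep:
+     print("Reasoning trace:", think.replace("<think>", "").strip())
+     print("Answer:", answer.strip())
+ else:
+     print(completion)
+ ```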
+
+ ---
+
+ ## Training
+
+ ### Model
+
+ * **Architecture**: decoder-only transformer (SmolLM2 backbone, which is in fact a Llama-3-style architecture).
+ * **Parameters**: ~135M.
+ * **Precision**: mixed precision (bfloat16) during training.
+
+ ### Software & Hardware
+
+ * **Training Frameworks**: PyTorch (ROCm), Hugging Face Transformers & TRL.
+ * **Hardware**: AMD MI300X (192GB VRAM, 224GB RAM).
+
+ **Special thanks to [@HotAisle](https://x.com/HotAisle)**
+
+ ### Training Stages
+
+ 1. **Stage 1 – SFT (non-reasoning)**
+    * ~58k samples, everyday conversations & instruction following.
+ 2. **Stage 2 – SFT (reasoning)**
+    * ~78k samples with `<think>` segments.
+ 3. **Stage 3 – DPO (alignment)**
+    * ~50k preference pairs (chosen vs. rejected reasoning traces); a minimal training sketch follows below.
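+
+ For illustration, a minimal TRL `DPOTrainer` sketch of Stage 3, not the project's actual training script; the hyperparameters are taken from the stage-3 model card (lr 1e-5, 32 × 4 = 128 effective batch, cosine schedule with 0.1 warmup, 1 epoch, seed 42), while dataset loading and preprocessing are simplified:
+
+ ```python
+ # Sketch only: assumes the stage-2 checkpoint and stage-3 preference dataset.
+ from datasets import load_dataset
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from trl import DPOConfig, DPOTrainer
+
+ model = AutoModelForCausalLM.from_pretrained("Shekswess/trlm-stage-2-sft-final-2")
+ tokenizer = AutoTokenizer.from_pretrained("Shekswess/trlm-stage-2-sft-final-2")
+ dataset = load_dataset("Shekswess/trlm-dpo-stage-3-final-2", split="train")
+
+ config = DPOConfig(
+     output_dir="trlm-stage-3-dpo",
+     learning_rate=1e-5,
+     per_device_train_batch_size=32,
+     gradient_accumulation_steps=4,  # 32 x 4 = 128 effective batch size
+     num_train_epochs=1,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.1,
+     bf16=True,
+     seed=42,
+ )
+
+ trainer = DPOTrainer(
+     model=model,
+     args=config,
+     train_dataset=dataset,  # expects "prompt" / "chosen" / "rejected" columns
+     processing_class=tokenizer,  # `tokenizer=` on older TRL versions
+ )
+ trainer.train()
+ ```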
+
+ ---
+
+ ## Evaluation
+
+ Evaluation was done with `lm-eval-harness` (a reproduction sketch follows the table):
+
+ | **Benchmark**     | **trlm-135M**      | **SmolLM2-135M-Instruct** | **Δ vs. base** |
+ | ----------------- | ------------------ | ------------------------- | -------------- |
+ | **ARC Challenge** | **40.61** (avg)    | 37.3 (avg)                | **+3.31**      |
+ | **BBH**           | **36.80** (3-shot) | 28.2 (3-shot)             | **+8.60**      |
+ | **BoolQ**         | **62.17**          | –                         | N/A            |
+ | **GSM8K**         | **2.59** (5-shot)  | 1.4 (5-shot)              | **+1.19**      |
+ | **IFEval**        | **35.49** (avg)    | 29.9 (avg)                | **+5.59**      |
+ | **MMLU**          | **34.95**          | 29.3                      | **+5.65**      |
+ | **PIQA**          | **64.91**          | 66.3                      | **−1.39**      |
+ | **HellaSwag**     | –                  | 40.9                      | N/A            |
+ | **MT-Bench**      | –                  | 19.8                      | N/A            |
+
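+ A hypothetical reproduction sketch for one row, using the `lm-eval` Python API (v0.4-style `simple_evaluate`; the exact task variants and shot counts behind the table are not specified here, so treat both as assumptions):
+
+ ```python
+ # Evaluate trlm-135m on ARC Challenge with lm-eval-harness.
+ import lm_eval
+
+ results = lm_eval.simple_evaluate(
+     model="hf",
+     model_args="pretrained=Shekswess/trlm-135m,dtype=bfloat16",
+     tasks=["arc_challenge"],
+     batch_size=8,
+ )
+ print(results["results"]["arc_challenge"])
+ ```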
+
+ ---
+
+ ## Limitations
+
+ * **Not production-ready**: hallucinations and logical errors are frequent.
+ * **Small size**: limited general knowledge and reasoning depth.
+ * **English-only**: multilingual capabilities not explored.
+
+ ---
+
+ ## Acknowledgements
+
+ - [@HotAisle](https://x.com/HotAisle) for providing the compute resources to train all three stages on an awesome AMD MI300X setup.
+ - [@mkurman88](https://x.com/mkurman88) for ideas, feedback, and code samples.
+ - [HuggingFaceTB team](https://huggingface.co/HuggingFaceTB) for the SmolLM2-135M-Instruct model and the SmolTalk2 dataset collection.
+ - [@scottgeng00](https://huggingface.co/scottgeng00) for the OLMo-3-Preference-Mix-Deltas dataset.
+ - [@eliebakouch](https://x.com/eliebakouch) for help with the tokenization.
+
+ ---
+
+ ## License
+
+ [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)
+
+ ---