Update README.md
README.md
CHANGED
@@ -13,103 +13,146 @@ model-index:
results: []
---

<img src="https://sdmntprnortheu.oaiusercontent.com/files/00000000-f580-61f4-9d8f-e2ad1ad30cb1/raw?se=2025-09-28T13%3A44%3A27Z&sp=r&sv=2024-08-04&sr=b&scid=d18de0ac-b41e-5d89-82aa-2a8c74df25d6&skoid=f28c0102-4d9d-4950-baf0-4a8e5f6cf9d4&sktid=a48cca56-e6da-484e-a814-9c849652bcb3&skt=2025-09-27T15%3A59%3A48Z&ske=2025-09-28T15%3A59%3A48Z&sks=b&skv=2024-08-04&sig=CSrmTwUK5za43FjSFhOlkzGlLkqG2CDPpKYkYtSdV6g%3D" alt="TRLm Stage 3 Banner" width="800"/>
</p>

This stage focuses on **preference alignment** using **Direct Preference Optimization (DPO)** with 50k preference pairs.

---

## Model Details

- **Stage**: Post-training **Stage 3 (DPO)**
- **Objective**: Align model outputs with human-preferred reasoning and answers by contrasting **chosen** vs **rejected** completions.
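
For reference, Stage 3 optimizes the standard DPO objective from Rafailov et al. (2023): given a prompt $x$ with chosen completion $y_w$ and rejected completion $y_l$, the policy $\pi_\theta$ is trained to prefer $y_w$ over $y_l$ relative to a frozen reference model $\pi_{\mathrm{ref}}$ (presumably the Stage 2 checkpoint), with $\beta$ controlling how far the policy may drift from the reference:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}\!\left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
$$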

## 🎯 Intended Uses & Limitations

### Intended Uses

- Aligned reasoning assistant with structured `<think>` traces
- Multi-turn reasoning with preference-optimized outputs
- Safer, more useful responses for reasoning tasks

### Limitations

- Trained only on preference data → may inherit biases from source datasets
- Limited parameter count (135M) restricts knowledge breadth
- Still prone to hallucinations under complex reasoning chains

---

## Dataset

This model was trained on the dataset:
👉 [**Shekswess/trlm-dpo-stage-3-final-2**](https://huggingface.co/datasets/Shekswess/trlm-dpo-stage-3-final-2)

**Dataset summary**:

- **Entries**: 50,000 preference pairs
- **Source**: `scottgeng00/olmo-3-preference-mix-deltas_reasoning-yolo_scottmix-DECON-chfiltered`
- **Focus**: Preference alignment with **chosen vs rejected responses**

| Source Dataset | Split | Entries | % |
|----------------|-------|---------|---|
| scottgeng00/olmo-3-preference-mix-deltas_reasoning-yolo_scottmix-DECON-chfiltered | train | 50,000 | 100% |
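
To inspect the pairs yourself, here is a minimal sketch (the column names are an assumption: TRL-style preference sets usually carry `prompt`, `chosen`, and `rejected` fields, but this card does not list the schema):

```python
from datasets import load_dataset

# Load the Stage 3 preference pairs and peek at one example
ds = load_dataset("Shekswess/trlm-dpo-stage-3-final-2", split="train")
print(ds)            # features and row count (50,000 expected)
print(ds[0].keys())  # column names; prompt/chosen/rejected assumed
```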

---

## Training Procedure

### Training Hyperparameters

- **Learning rate**: 1e-5
- **Train batch size**: 32
- **Eval batch size**: 8
- **Gradient accumulation steps**: 4
- **Total effective batch size**: 128 (32 × 4)
- **Optimizer**: AdamW (betas=(0.9, 0.999), eps=1e-08)
- **LR Scheduler**: Cosine with minimum LR + warmup ratio 0.1
- **Epochs**: 1
- **Seed**: 42
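
These settings map naturally onto TRL's `DPOTrainer`. Below is a minimal sketch rather than the exact training script: the starting checkpoint, the DPO `beta`, and the minimum-LR fraction are assumptions (the card states none of them), and a recent TRL version that accepts `processing_class` is assumed.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

base = "Shekswess/trlm-stage-2-sft-final-2"  # assumed Stage 2 starting checkpoint
tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base)

train_dataset = load_dataset("Shekswess/trlm-dpo-stage-3-final-2", split="train")

args = DPOConfig(
    output_dir="trlm-stage-3-dpo",
    learning_rate=1e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,             # 32 × 4 = 128 effective batch size
    num_train_epochs=1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},  # illustrative; actual minimum LR not stated
    bf16=True,
    seed=42,
    beta=0.1,                                  # TRL's default; not stated in the card
)

trainer = DPOTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```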

### Framework Versions

- **Transformers**: 4.56.2
- **PyTorch**: 2.7.1+rocm7.0.0.git698b58a9
- **Datasets**: 4.0.0
- **Tokenizers**: 0.22.1

```python
from transformers import

model_name = "Shekswess/trlm-

# Load tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(

# Example
messages = [
    {"role": "user", "content":
]

# Apply chat template
text = tokenizer.apply_chat_template(

outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

# Tiny Reasoning Language Model (trlm-135)

![Tiny Reasoning Language Model Logo](https://cdn-uploads.huggingface.co/production/uploads/64e3d25d7f4c4f73bb310bd9/4rzdEAsLKB6PvUQp5bGFZ.png)

## Table of Contents

1. [Model Summary](#model-summary)
2. [Post-Training Pipeline](#post-training-pipeline)
3. [How to use](#how-to-use)
4. [Training](#training)
5. [Evaluation](#evaluation)
6. [Limitations](#limitations)
7. [Acknowledgements](#acknowledgements)
8. [License](#license)

---

## Model Summary

The **Tiny Reasoning Language Model (trlm-135)** is a **135M parameter** research prototype designed to study how small models can learn step-by-step reasoning.
It was built on top of [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) and fine-tuned through a **3-stage pipeline**:

* **[Stage 1 SFT](https://huggingface.co/Shekswess/trlm-stage-1-sft-final-2)**: general instruction tuning (non-reasoning).
* **[Stage 2 SFT](https://huggingface.co/Shekswess/trlm-stage-2-sft-final-2)**: reasoning traces with `<think>` tags (an illustrative trace follows this list).
* **[Stage 3 DPO](https://huggingface.co/Shekswess/trlm-stage-3-dpo-final-2)**: preference alignment for reasoning style.
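
Concretely, a Stage 2-style response has the following shape. The trace below is invented for illustration; only the `<think>` tag convention comes from the card:

```
<think>
The user wants a simple explanation. Start from falling objects, then name the force.
</think>
Gravity is the pull that draws objects toward one another...
```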

The **code** for everything can be found **[here](https://github.com/Shekswess/tiny-reasoning-language-model/blob/main/README.md)**.

---

## Post-Training Pipeline

<img width="1014" height="563" alt="image" src="https://github.com/user-attachments/assets/195ef389-6aa9-4527-b4f0-bea68c0841ae" />

---

## How to use

```bash
pip install -U transformers accelerate
```

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Shekswess/trlm-135m"
device = "cuda"  # or "cpu"

# Load tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
).to(device)

# Example prompt
prompt = "Give me a brief explanation of gravity in simple terms."
messages = [
    {"role": "user", "content": prompt}
]

# Apply chat template
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

> [!TIP]
> For reasoning-heavy tasks, set `temperature=0.6` and `top_p=0.95`.
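
Applied to the snippet above, that corresponds to the following call (note `do_sample=True`, without which `temperature` and `top_p` are ignored):

```python
# Sampled generation with the recommended settings; reuses `inputs` and `tokenizer`
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,   # temperature/top_p only take effect when sampling is enabled
    temperature=0.6,
    top_p=0.95,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```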

---

## Training

### Model

* **Architecture**: Decoder-only transformer (SmolLM2 backbone, which is in fact a Llama 3-based model).
* **Parameters**: ~135M.
* **Precision**: mixed precision (bfloat16) during training.
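
Since training ran in bfloat16, loading the weights the same way for inference is a natural fit (a minimal sketch; fall back to float32 on hardware without bf16 support):

```python
import torch
from transformers import AutoModelForCausalLM

# Load in bfloat16 to roughly halve memory versus float32
model = AutoModelForCausalLM.from_pretrained(
    "Shekswess/trlm-135m",
    torch_dtype=torch.bfloat16,
)
```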

### Software & Hardware

* **Training Frameworks**: PyTorch (ROCm), Hugging Face Transformers & TRL.
* **Hardware**: AMD MI300X (192GB VRAM, 224GB RAM).

**Special thanks to [@HotAisle](https://x.com/HotAisle)**

### Training Stages

1. **Stage 1 – SFT (non-reasoning)**
   * ~58k samples, everyday conversations & instruction following.
2. **Stage 2 – SFT (reasoning)**
   * ~78k samples with `<think>` segments.
3. **Stage 3 – DPO (alignment)**
   * ~50k preference pairs (chosen vs. rejected reasoning traces).

---

## Evaluation

Evaluation was done with `lm-eval-harness`:

| **Benchmark** | **Tiny Reasoning Language Model (trlm-135M)** | **SmolLM2-135M-Instruct** | **Improvement** |
|---------------|-----------------------------------------------|---------------------------|-----------------|
| **ARC Challenge** | **40.61** (avg) | 37.3 (avg) | **+3.31** |
| **BBH** | **36.80** (3-shot) | 28.2 (3-shot) | **+8.60** |
| **BoolQ** | **62.17** | – | N/A |
| **GSM8K** | **2.59** (5-shot) | 1.4 (5-shot) | **+1.19** |
| **IFEval** | **35.49** (avg) | 29.9 (avg) | **+5.59** |
| **MMLU** | **34.95** | 29.3 | **+5.65** |
| **PIQA** | **64.91** | 66.3 | **–1.39** |
| **HellaSwag** | – | 40.9 | N/A |
| **MT-Bench** | – | 19.8 | N/A |
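
To reproduce a subset of these scores, the harness's Python API can be used roughly as follows (a sketch: the task names and few-shot settings are inferred from the table, and harness versions can shift numbers slightly):

```python
import lm_eval

# Evaluate trlm-135m on a few of the benchmarks from the table above
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=Shekswess/trlm-135m,dtype=bfloat16",
    tasks=["arc_challenge", "boolq", "piqa", "mmlu"],
)
print(results["results"])
```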

---

## Limitations

* **Not production-ready**: hallucinations and logical errors are frequent.
* **Small size**: limited general knowledge and reasoning depth.
* **English-only**: multilingual capabilities not explored.

---

## Acknowledgements

- [@HotAisle](https://x.com/HotAisle) for providing the compute resources to train all three stages on an awesome AMD MI300X setup.
- [@mkurman88](https://x.com/mkurman88) for ideas, feedback, and code samples.
- [HuggingFaceTB team](https://huggingface.co/HuggingFaceTB) for the SmolLM2-135M-Instruct model and the Smoltalk2 dataset collection.
- [@scottgeng00](https://huggingface.co/scottgeng00) for the OLMo-3-Preference-Mix-Deltas dataset.
- [@eliebakouch](https://x.com/eliebakouch) for help with the tokenization.

---

## License

[Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)

---