Shekswess committed on
Commit 98ab9d7 · verified · 1 Parent(s): 44da050

Upload 12 files

Files changed (5)
  1. README.md +136 -45
  2. config.json +1 -1
  3. model.safetensors +1 -1
  4. tokenizer.json +1 -8
  5. training_args.bin +2 -2
README.md CHANGED
@@ -1,67 +1,158 @@
  ---
  library_name: transformers
- model_name: trlm-135m
  tags:
- - generated_from_trainer
- - grpo
  - trl
- licence: license
  ---

- # Model Card for trlm-135m

- This model is a fine-tuned version of [None](https://huggingface.co/None).
- It has been trained using [TRL](https://github.com/huggingface/trl).

- ## Quick start

- ```python
- from transformers import pipeline

- question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
- generator = pipeline("text-generation", model="Shekswess/trlm-135m", device="cuda")
- output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
- print(output["generated_text"])
  ```

- ## Training procedure
-
- This model was trained with GRPO, a method introduced in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).

- ### Framework versions

- - TRL: 0.23.0
- - Transformers: 4.56.1
- - Pytorch: 2.8.0+cu126
- - Datasets: 4.0.0
- - Tokenizers: 0.22.0

- ## Citations

- Cite GRPO as:

- ```bibtex
- @article{shao2024deepseekmath,
-     title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
-     author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
-     year = 2024,
-     eprint = {arXiv:2402.03300},
- }
- ```

- Cite TRL as:
-
- ```bibtex
- @misc{vonwerra2022trl,
-     title = {{TRL: Transformer Reinforcement Learning}},
-     author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
-     year = 2020,
-     journal = {GitHub repository},
-     publisher = {GitHub},
-     howpublished = {\url{https://github.com/huggingface/trl}}
- }
- ```
 
  ---
  library_name: transformers
+ license: apache-2.0
+ base_model: Shekswess/trlm-stage-2-sft-final-2
  tags:
  - trl
+ - dpo
+ - preference-alignment
+ - reasoning
+ - generated_from_trainer
+ model-index:
+ - name: trlm-stage-3-dpo-final-2
+   results: []
  ---

+ # Tiny Reasoning Language Model (trlm-135m)

+ ![image/png](https://github.com/user-attachments/assets/5f453496-8180-4cf4-94da-26ebbe1159d4)

+ ## Table of Contents

+ 1. [Model Summary](#model-summary)
+ 2. [Post-Training Pipeline](#post-training-pipeline)
+ 3. [How to use](#how-to-use)
+ 4. [Training](#training)
+ 5. [Evaluation](#evaluation)
+ 6. [Limitations](#limitations)
+ 7. [Acknowledgements](#acknowledgements)
+ 8. [License](#license)
+
+ ---
+
+ ## Model Summary
+
+ The **Tiny Reasoning Language Model (trlm-135m)** is a **135M-parameter** research prototype designed to study how small models can learn step-by-step reasoning.
+ It was built on top of [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) and fine-tuned through a **3-stage pipeline**:
+
+ * **[Stage 1 SFT](https://huggingface.co/Shekswess/trlm-stage-1-sft-final-2)**: general instruction tuning (non-reasoning).
+ * **[Stage 2 SFT](https://huggingface.co/Shekswess/trlm-stage-2-sft-final-2)**: reasoning traces with `<think>` tags.
+ * **[Stage 3 DPO](https://huggingface.co/Shekswess/trlm-stage-3-dpo-final-2)**: preference alignment for reasoning style.
+
+ The code for all three stages is available **[here](https://github.com/Shekswess/tiny-reasoning-language-model/blob/main/README.md)**.
+
+ ---
+
+ ## Post-Training Pipeline
+ <img width="1014" height="563" alt="image" src="https://github.com/user-attachments/assets/195ef389-6aa9-4527-b4f0-bea68c0841ae" />
+
+ ---
+
+ ## How to use

+ ```bash
+ pip install -U transformers accelerate
  ```

+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "Shekswess/trlm-135m"
+ device = "cuda"  # or "cpu"
+
+ # Load tokenizer & model
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+ ).to(device)
+
+ # Example prompt
+ prompt = "Give me a brief explanation of gravity in simple terms."
+ messages = [
+     {"role": "user", "content": prompt}
+ ]
+
+ # Apply chat template
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True,
+ )
+
+ inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+ # Generate
+ outputs = model.generate(**inputs, max_new_tokens=256)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
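Since Stage 2 trains the model to wrap its reasoning in `<think>` tags, the decoded text can be split into a reasoning trace and a final answer. A minimal sketch, assuming the tags appear literally in the output; the `split_reasoning` helper is illustrative, not part of the repo:

```python
import re

def split_reasoning(text: str) -> tuple[str, str]:
    """Separate <think>...</think> segments from the remaining answer text."""
    thoughts = re.findall(r"<think>(.*?)</think>", text, flags=re.DOTALL)
    answer = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    return "\n".join(t.strip() for t in thoughts), answer

# Invented sample standing in for model output
sample = "<think>Gravity pulls masses together.</think>Gravity is the force that attracts objects toward each other."
reasoning, answer = split_reasoning(sample)
print(reasoning)  # Gravity pulls masses together.
print(answer)     # Gravity is the force that attracts objects toward each other.
```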

+ > [!TIP]
+ > For reasoning-heavy tasks, set `temperature=0.6` and `top_p=0.95`.

+ ---

+ ## Training

+ ### Model

+ * **Architecture**: decoder-only transformer (SmolLM2 backbone, which is in fact a Llama-3-based model).
+ * **Parameters**: ~135M.
+ * **Precision**: mixed precision (bfloat16) during training.

+ ### Software & Hardware

+ * **Training Frameworks**: PyTorch (ROCm), Hugging Face Transformers & TRL.
+ * **Hardware**: AMD MI300X (192 GB VRAM, 224 GB RAM).

+ **Special thanks to [@HotAisle](https://x.com/HotAisle)**

+ ### Training Stages
+
+ 1. **Stage 1 – SFT (non-reasoning)**
+    * ~58k samples, everyday conversations & instruction following.
+ 2. **Stage 2 – SFT (reasoning)**
+    * ~78k samples with `<think>` segments.
+ 3. **Stage 3 – DPO (alignment)**
+    * ~50k preference pairs (chosen vs. rejected reasoning traces).
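A Stage 3 preference pair can be pictured as a prompt with one chosen and one rejected completion, which is the column layout TRL's `DPOTrainer` expects. A hypothetical record (the strings are invented for illustration, not taken from the dataset):

```python
# Hypothetical preference record in the prompt/chosen/rejected layout
# used by TRL's DPOTrainer; the content below is invented.
pair = {
    "prompt": "What is 17 + 25?",
    "chosen": "<think>17 + 25: 17 + 20 = 37, then 37 + 5 = 42.</think>42",
    "rejected": "The answer is 52.",
}

# DPO pushes the policy to rank "chosen" above "rejected" relative to a
# frozen reference model; no scalar reward labels are needed.
print(sorted(pair))
```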
+ ---
+
+ ## Evaluation
+
+ Evaluation was done with `lm-eval-harness`:
+
+ | **Benchmark** | **trlm-135m** | **SmolLM2-135M-Instruct** | **Improvement** |
+ | ------------- | ------------- | ------------------------- | --------------- |
+ | **ARC Challenge** | **40.61** (avg) | 37.3 (avg) | **+3.31** |
+ | **BBH** | **36.80** (3-shot) | 28.2 (3-shot) | **+8.60** |
+ | **BoolQ** | **62.17** | – | N/A |
+ | **GSM8K** | **2.59** (5-shot) | 1.4 (5-shot) | **+1.19** |
+ | **IFEval** | **35.49** (avg) | 29.9 (avg) | **+5.59** |
+ | **MMLU** | **34.95** | 29.3 | **+5.65** |
+ | **PIQA** | **64.91** | 66.3 | **-1.39** |
+ | **HellaSwag** | – | 40.9 | N/A |
+ | **MT-Bench** | – | 19.8 | N/A |
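The improvement column is simply the per-benchmark difference between the two models. A quick sanity check over the rows where both scores exist:

```python
# Scores copied from the table above: (trlm-135m, SmolLM2-135M-Instruct).
scores = {
    "ARC Challenge": (40.61, 37.3),
    "BBH": (36.80, 28.2),
    "GSM8K": (2.59, 1.4),
    "IFEval": (35.49, 29.9),
    "MMLU": (34.95, 29.3),
    "PIQA": (64.91, 66.3),
}

# Recompute each delta; rounding matches the table's two decimal places.
for name, (trlm, smol) in scores.items():
    print(f"{name}: {trlm - smol:+.2f}")
```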
+
+ ---
+
+ ## Limitations
+
+ * **Not production-ready**: hallucinations and logical errors are frequent.
+ * **Small size**: limited general knowledge and reasoning depth.
+ * **English-only**: multilingual capabilities were not explored.
+
+ ---
+
+ ## Acknowledgements
+
+ - [@HotAisle](https://x.com/HotAisle) for providing the compute resources to train all three stages on an awesome AMD MI300X setup.
+ - [@mkurman88](https://x.com/mkurman88) for ideas, feedback, and code samples.
+ - [HuggingFaceTB team](https://huggingface.co/HuggingFaceTB) for the SmolLM2-135M-Instruct model and the Smoltalk2 dataset collection.
+ - [@scottgeng00](https://huggingface.co/scottgeng00) for the OLmO-3-Preference-Mix-Deltas dataset.
+ - [@eliebakouch](https://x.com/eliebakouch) for help with the tokenization.
+
+ ---
+
+ ## License
+
+ [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)

+ ---
config.json CHANGED
@@ -32,7 +32,7 @@
      "q4f16": "float16"
    }
  },
- "transformers_version": "4.56.1",
+ "transformers_version": "4.56.2",
  "use_cache": true,
  "vocab_size": 49154
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6df764ae38de74036a73950a26c4ce2eabe6461855cb96ee3668e6c6c0d1b15a
+ oid sha256:f72be34d84f9d95d2f1ba9e6aa5d352ec841ce1e7aed81998135ce5bf96e9a09
  size 269062856
tokenizer.json CHANGED
@@ -1,14 +1,7 @@
  {
    "version": "1.0",
    "truncation": null,
-   "padding": {
-     "strategy": "BatchLongest",
-     "direction": "Left",
-     "pad_to_multiple_of": null,
-     "pad_id": 2,
-     "pad_type_id": 0,
-     "pad_token": "<|im_end|>"
-   },
+   "padding": null,
    "added_tokens": [
      {
        "id": 0,
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7a4eaa81d993d68c33c29824d61c91baead1e98508a5a625eb153ada959ab125
3
- size 7185
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:247b792c868f8245ceddd15c2b2486a99317401202045e562ed34e945a36ed82
3
+ size 6865