Update README.md
## Example

```py
from transformers import pipeline

formula = "9 + 3 * 5 = ?"  # only the A + B * C or A * B + C form is supported

# ... (the pipeline and prompt setup is not shown in this diff) ...

print(pipe(prompt)[0]["generated_text"][len(prompt):])
# <think>9 + 3 * 5 = 9 + 15 = 24</think><answer>24</answer>
```
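The lines that build `pipe` and `prompt` fall outside this diff. A minimal sketch of what that setup could look like, assuming the model is loaded from its Hub repository (the repository id below is a placeholder):

```py
import torch
from transformers import pipeline

# Placeholder repository id; replace it with the actual model repo.
pipe = pipeline(
    "text-generation",
    model="YOUR_MODEL_REPO_ID",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```

The `prompt` string is expected to mirror the training format shown in the training code further down: the Japanese system prompt, the question "次の ? に入る数値を計算して回答してください。" followed by the formula, and a trailing "アシスタント:" line.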
## Training information

- Device: 1x A100 80GB
- GPU hours: about 1 hour
- Base model: [Qwen/Qwen2.5-0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B)

Wandb log: https://wandb.ai/p1atdev/grpo-math-01/runs/ytv8wxll
## Training code

```py
import random
import re

import torch
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM
import wandb

# Japanese system prompt: answer as an assistant, think inside <think></think>,
# then give the final answer inside <answer></answer>; the user turn follows "ユーザー:".
SYSTEM_PROMPT = """命令:
あなたはアシスタントとして回答します。
ユーザーの質問に対して、<think></think>ブロック内で思考してから<answer></answer>でファイナルアンサーしてください。
具体的には、「<think>ここに思考過程</think><answer>ここに解答</answer>」という形です。
「ユーザー」の質問の後に、「アシスタント」が回答します。
ユーザー:
"""
MODEL_NAME = "Qwen/Qwen2.5-0.5B"


def generate_problem():
    # written by ChatGPT
    # Generate three random integers between 1 and 10
    a = random.randint(1, 10)
    b = random.randint(1, 10)
    c = random.randint(1, 10)

    # Randomly pick one of two patterns so that both addition and multiplication appear
    if random.randint(0, 1) == 0:
        # Pattern 1: addition then multiplication, e.g. a + b * c
        expression = f"{a} + {b} * {c}"
    else:
        # Pattern 2: multiplication then addition, e.g. a * b + c
        expression = f"{a} * {b} + {c}"

    # Compute the answer with Python's eval() (operator precedence applies)
    answer = eval(expression)

    return f"{expression} = ?", answer


def generate_random_pair(max_count: int):
    for i in range(max_count):
        formula, answer = generate_problem()
        # Prompt = system prompt + "Compute the number that goes into ?" + formula + "アシスタント:" (Assistant:)
        question = f"""{SYSTEM_PROMPT}
次の ? に入る数値を計算して回答してください。
{formula}

アシスタント:
"""
        yield {"id": i, "prompt": question, "ground_truth": answer}


# format reward
FORMAT_PATTERN = re.compile(r"^<think>.*?</think><answer>.*?</answer>$")


def format_reward_func(completions: list[str], **kwargs):
    """Reward function that checks if the completion has the expected format."""
    matches = [FORMAT_PATTERN.match(content) for content in completions]
    return [1.0 if match else 0.0 for match in matches]


# answer reward
ANSWER_PATTERN = re.compile(r"<answer>(\d+)</answer>")


def answer_reward_func(completions: list[str], ground_truth: list[str], **kwargs):
    # Capture the content inside <answer></answer>
    matches = [ANSWER_PATTERN.search(completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content matches the ground truth, 0 otherwise
    return [1.0 if c == str(gt) else 0.0 for c, gt in zip(contents, ground_truth)]


def main():
    ds = Dataset.from_generator(
        generate_random_pair, gen_kwargs={"max_count": 100000}
    )  # 100000 examples is more than this task needs
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    project_name = "YOUR_WANDB_PROJECT_NAME"  # replace with your wandb project name
    push_hub_name = "YOUR_PUSH_HUB_NAME"  # replace with your Hub repo id

    wandb.init(project=project_name)
    train_args = GRPOConfig(
        output_dir="./grpo-01",  # output path
        use_vllm=False,  # True to use vLLM
        overwrite_output_dir=True,
        num_train_epochs=10,
        num_generations=4,
        per_device_train_batch_size=16,
        # per_device_eval_batch_size=4,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        learning_rate=1e-4,  # maybe a bit high
        warmup_ratio=0.01,
        weight_decay=0.01,
        optim="adamw_8bit",
        adam_epsilon=1e-8,
        lr_scheduler_type="cosine_with_min_lr",
        lr_scheduler_kwargs={
            "min_lr": 5e-5,
            "num_cycles": 0.5,
        },
        # eval_strategy="steps",  # eval did not work well
        # eval_steps=10,
        save_steps=10,
        save_total_limit=2,
        logging_steps=1,
        logging_first_step=True,
        # load_best_model_at_end=True,
        # metric_for_best_model="eval_loss",
        torch_compile=False,  # compile does not work
        fp16=False,
        bf16=True,
        report_to=["wandb"],
        hub_model_id=push_hub_name,
        hub_private_repo=True,
        push_to_hub=True,
        save_safetensors=True,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds,
        # eval_dataset=ds["test"],
        reward_funcs=[format_reward_func, answer_reward_func],
        args=train_args,
    )

    trainer.train()


if __name__ == "__main__":
    main()
```
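Not part of the original script: a small sanity check of the problem generator and the two reward functions, handy before launching a full GRPO run. It assumes the functions from the script above are in scope; the expected outputs follow from how `FORMAT_PATTERN` and `ANSWER_PATTERN` behave on these strings.

```py
# Quick check of the reward logic (illustrative only, not in the original README).
print(generate_problem())
# e.g. ('7 + 2 * 4 = ?', 15): a formula string plus its integer answer

sample_completions = [
    "<think>9 + 3 * 5 = 9 + 15 = 24</think><answer>24</answer>",  # well-formed and correct
    "<think>9 + 3 * 5 = 60</think><answer>60</answer>",           # well-formed but wrong
    "The answer is 24.",                                           # wrong format
]
sample_ground_truth = [24, 24, 24]

print(format_reward_func(sample_completions))
# [1.0, 1.0, 0.0]
print(answer_reward_func(sample_completions, sample_ground_truth))
# [1.0, 0.0, 0.0]
```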