p1atdev committed
Commit 72dca67 · verified · 1 Parent(s): 9ede090

Update README.md

Files changed (1)
  1. README.md +153 -2
README.md CHANGED
@@ -33,7 +33,7 @@ expected output:
  
  ## Example
  
- ```
+ ```py
  from transformers import pipeline
  
  formula = "9 + 3 * 5 = ?" # only the A + B * C and A * B + C forms are supported
@@ -52,4 +52,155 @@ prompt = f"""\
  
  print(pipe(prompt)[0]["generated_text"][len(prompt):])
  # <think>9 + 3 * 5 = 9 + 15 = 24</think><answer>24</answer>
- ```
+ ```
+ 
+ ## Training information
+ 
+ - Device: 1x A100 80G
+ - GPU hours: about 1
+ - Base model: [Qwen/Qwen2.5-0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B)
+ 
+ Wandb log: https://wandb.ai/p1atdev/grpo-math-01/runs/ytv8wxll
+ 
+ ## Training code
+ 
+ ```py
+ import random
+ import re
+ 
+ import torch
+ from datasets import Dataset
+ from trl import GRPOConfig, GRPOTrainer
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import wandb
+ 
+ # Japanese system prompt. Roughly: "Instructions: you answer as the assistant.
+ # For the user's question, think inside a <think></think> block, then give the
+ # final answer in <answer></answer>, i.e. '<think>reasoning here</think><answer>answer here</answer>'.
+ # After the 'user' question, the 'assistant' replies. User:"
+ SYSTEM_PROMPT = """命令:
+ あなたはアシスタントとして回答します。
+ ユーザーの質問に対して、<think></think>ブロック内で思考してから<answer></answer>でファイナルアンサーしてください。
+ 具体的には、「<think>ここに思考過程</think><answer>ここに解答</answer>」という形です。
+ 「ユーザー」の質問の後に、「アシスタント」が回答します。
+ ユーザー:
+ """
+ MODEL_NAME = "Qwen/Qwen2.5-0.5B"
+ 
+ def generate_problem():
+     # written by ChatGPT
+     # draw three random integers between 1 and 10
+     a = random.randint(1, 10)
+     b = random.randint(1, 10)
+     c = random.randint(1, 10)
+ 
+     # pick one of two patterns at random so the expression mixes addition and multiplication
+     if random.randint(0, 1) == 0:
+         # pattern 1: addition + multiplication, e.g. a + b * c
+         expression = f"{a} + {b} * {c}"
+     else:
+         # pattern 2: multiplication + addition, e.g. a * b + c
+         expression = f"{a} * {b} + {c}"
+ 
+     # compute the answer with Python's eval() (follows operator precedence)
+     answer = eval(expression)
+ 
+     return f"{expression} = ?", answer
+ 
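+ # Example draws (illustrative values, not taken from an actual run):
+ #   generate_problem() -> ("7 + 2 * 4 = ?", 15)
+ #   generate_problem() -> ("3 * 5 + 6 = ?", 21)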
+ 
+ def generate_random_pair(max_count: int):
+     for i in range(max_count):
+         formula, answer = generate_problem()
+         # the Japanese line asks the model to compute the number that fills the "?";
+         # "アシスタント:" marks the assistant's turn
+         question = f"""{SYSTEM_PROMPT}
+ 次の ? に入る数値を計算して回答してください。
+ {formula}
+ 
+ アシスタント:
+ """
+         yield {"id": i, "prompt": question, "ground_truth": answer}
+ 
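+ # Illustrative record shape (values are hypothetical):
+ #   {"id": 0, "prompt": "<SYSTEM_PROMPT + question>", "ground_truth": 14}
+ # GRPOTrainer samples completions for the "prompt" column; extra columns such as
+ # "ground_truth" are forwarded to the reward functions as keyword arguments.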
+ 
+ # format reward
+ FORMAT_PATTERN = re.compile(r"^<think>.*?</think><answer>.*?</answer>$")
+ 
+ def format_reward_func(completions: list[str], **kwargs):
+     """Reward function that checks if the completion has a specific format."""
+     matches = [FORMAT_PATTERN.match(content) for content in completions]
+     return [1.0 if match else 0.0 for match in matches]
+ 
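+ # Illustrative sanity check (not part of the original script):
+ #   format_reward_func(completions=[
+ #       "<think>9 + 15 = 24</think><answer>24</answer>",  # well-formed
+ #       "The answer is 24",                                # missing tags
+ #   ]) -> [1.0, 0.0]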
+ 
+ # answer reward
+ ANSWER_PATTERN = re.compile(r"<answer>(\d+)</answer>")
+ 
+ def answer_reward_func(completions: list[str], ground_truth: list[str], **kwargs):
+     # capture the content inside <answer></answer>
+     matches = [ANSWER_PATTERN.search(completion) for completion in completions]
+     contents = [match.group(1) if match else "" for match in matches]
+     # reward 1 if the content is the same as the ground truth, 0 otherwise
+     return [1.0 if c == str(gt) else 0.0 for c, gt in zip(contents, ground_truth)]
+ 
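+ # Illustrative sanity check (not part of the original script):
+ #   answer_reward_func(
+ #       completions=["<think>...</think><answer>24</answer>", "<answer>25</answer>"],
+ #       ground_truth=[24, 24],
+ #   ) -> [1.0, 0.0]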
+ 
+ def main():
+     ds = Dataset.from_generator(generate_random_pair, gen_kwargs={"max_count": 100000})  # 100000 is more than this task actually needs
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         attn_implementation="flash_attention_2",
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+     )
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     tokenizer.pad_token = tokenizer.eos_token
+ 
+     project_name = "YOUR_WANDB_PROJECT_NAME"  # replace with your wandb project name
+     push_hub_name = "YOUR_PUSH_HUB_NAME"  # replace with your Hub repo id
+ 
+     wandb.init(project=project_name)
+     train_args = GRPOConfig(
+         output_dir="./grpo-01",  #! output path
+         use_vllm=False,  # True to use vLLM
+         overwrite_output_dir=True,
+         num_train_epochs=10,
+         num_generations=4,
+         per_device_train_batch_size=16,
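+         # with num_generations=4, each device batch of 16 completions covers
+         # 4 distinct prompts per step; GRPO normalizes rewards within each
+         # group of 4 generations to compute relative advantages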
+         # per_device_eval_batch_size=4,
+         gradient_accumulation_steps=1,
+         gradient_checkpointing=True,
+         learning_rate=1e-4,  # maybe a bit high
+         warmup_ratio=0.01,
+         weight_decay=0.01,
+         optim="adamw_8bit",
+         adam_epsilon=1e-8,
+         lr_scheduler_type="cosine_with_min_lr",
+         lr_scheduler_kwargs={
+             "min_lr": 5e-5,
+             "num_cycles": 0.5,
+         },
+         # eval_strategy="steps",  # eval did not work well
+         # eval_steps=10,
+         save_steps=10,
+         save_total_limit=2,
+         logging_steps=1,
+         logging_first_step=True,
+         # load_best_model_at_end=True,
+         # metric_for_best_model="eval_loss",
+         torch_compile=False,  # compile does not work
+         fp16=False,
+         bf16=True,
+         report_to=["wandb"],
+         hub_model_id=push_hub_name,
+         hub_private_repo=True,
+         push_to_hub=True,
+         save_safetensors=True,
+     )
+ 
+     trainer = GRPOTrainer(
+         model=model,
+         processing_class=tokenizer,
+         train_dataset=ds,
+         # eval_dataset=ds["test"],
+         reward_funcs=[format_reward_func, answer_reward_func],
+         args=train_args,
+     )
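+     # both reward functions score every completion and GRPOTrainer sums them,
+     # so a completion can earn up to 2.0 (1.0 for format + 1.0 for a correct answer)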
+ 
+     trainer.train()
+ 
+ 
+ if __name__ == "__main__":
+     main()
+ ```