Fine-tuning SmolLM with Group Relative Policy Optimization (GRPO) by following the Methodologies

Group Relative Policy Optimization (GRPO) is a reinforcement learning technique designed to fine-tune language models by leveraging group-based rewards and policy optimization. It builds on concepts from Proximal Policy Optimization (PPO) but introduces a novel approach to reward calculation and policy updates by considering the relative performance of generated outputs within groups.
- Fine-tuning the SmolLM model using GRPO involves optimizing a surrogate loss derived from rewards based on key factors such as reasoning, accuracy, and formatting. The fine-tuning process follows these steps:
- Installing Required Packages
- Loading and Testing the Base Model
- Defining Helper Functions
- Defining Reward Functions
- Setting Up GRPO Configuration and Trainer
- Preparing the GSM8K Dataset
- Configuring the Trainer and Model
- Training the Model
- Inference with the Fine-Tuned Model
- Uploading the Model to the Hugging Face Hub
SmolLM2 135M Grpo Fine-tuning
Resource | Link |
---|---|
Fine-tuning Script 1 | SmolLM_x_Grpo.ipynb |
Fine-tuning Script 2 | SmolLM_x_Grpo.ipynb |
Fine-tuned Model | SmolLM2_135M_Grpo_Gsm8k |
Fine-tuned Checkpoint | SmolLM2_135M_Grpo_Checkpoint |
Method 1
Step 1: Install the Required Libraries
First, install the necessary packages. We are installing the latest versions of the Hugging Face Transformers, Accelerate, Datasets, and TRL libraries from their GitHub repositories. We also install PEFT for parameter-efficient fine-tuning.
!pip install -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/accelerate.git
!pip install -q datasets huggingface-hub trl
!pip install -q git+https://github.com/huggingface/peft.git
#!pip install flash-attn --no-build-isolation
Step 2: Import Libraries and Define Helper Functions
Import all required libraries and define helper functions for parsing and formatting model outputs. These functions will help extract the answer from the model’s XML-formatted response.
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfiga # Note: This might be a typo. In your code, you later call LoraConfig.
from trl import GRPOConfig, GRPOTrainer
# System prompt that instructs the model to use a specific XML format.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# XML chain-of-thought format template.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
# Function to extract the answer part from the XML response.
def extract_xml_answer(text: str) -> str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
return answer.strip()
# Function to extract an answer if it is provided with a "####" delimiter.
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
Step 3: Prepare the GSM8K Dataset
We use the GSM8K dataset (a collection of grade school math problems) from Hugging Face Hub. In the get_gsm8k_questions
function, we transform each example into a prompt with a system instruction and the user’s question. (The one-shot example is commented out but can be enabled if needed.)
# Function to load and process the GSM8K dataset.
def get_gsm8k_questions(split="train") -> Dataset:
data = load_dataset('openai/gsm8k', 'main')[split] # Load the GSM8K dataset.
data = data.map(lambda x: { # Process each example.
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
# Uncomment the following lines to include a one-shot example.
# {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
# {'role': 'assistant', 'content': XML_COT_FORMAT.format(
# reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
# answer="7"
# )},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
})
return data
# Load the processed dataset.
dataset = get_gsm8k_questions()
Step 4: Define Reward Functions
Several reward functions are defined to guide the training process. These functions evaluate different aspects of the model output such as correctness, formatting, and structural adherence to the XML format.
# Reward function to check correctness: compares the extracted answer from the response with the known answer.
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
# Reward function that checks if the response is a digit.
def int_reward_func(completions, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
# Reward function that checks if the response strictly follows the desired XML format.
def strict_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# Reward function with a softer check for the XML format.
def soft_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in responses]
# Function to count specific XML tokens and award a small reward for each.
def count_xml(text) -> float:
count = 0.0
if text.count("<reasoning>\n") == 1:
count += 0.125
if text.count("\n</reasoning>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1]) * 0.001
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
return count
# Reward function that uses the XML token count.
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
Step 5: Set Up the Model and Tokenizer
We select the SmolLM model (HuggingFaceTB/SmolLM2-135M-Instruct
) from the Hugging Face Hub. The model is loaded with a bfloat16
data type and moved to the GPU. The tokenizer is also loaded and its padding token is set to the end-of-sequence token.
# Choose the model name.
model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
# Alternatively, you can use:
# model_name = "Qwen/Qwen2.5-1.5B-Instruct"
# Set output directories and run name based on the chosen model.
if "SmolLM2" in model_name:
output_dir = "outputs/SmolLM2-135M-GRPO"
run_name = "SmolLM2-135M-GRPO"
else:
output_dir = "outputs/Qwen-1.5B-GRPO"
run_name = "Qwen-1.5B-GRPO-gsm8k"
# Load the model.
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
#attn_implementation="flash_attention_2",
device_map=None
).to("cuda")
# Load the tokenizer and ensure that the pad token is set.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
Step 6: Configure GRPO and PEFT
Next, we define the training configuration for GRPO as well as the PEFT (Parameter-Efficient Fine-Tuning) configuration using LoRA (Low-Rank Adaptation). Note that in the code below the PEFT configuration is created but not passed to the trainer (it is commented out). You can enable it by uncommenting the corresponding parameter.
# GRPO training configuration.
training_args = GRPOConfig(
output_dir=output_dir,
run_name=run_name,
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
logging_steps=1,
bf16=True,
per_device_train_batch_size=16, # Must be divisible by num_generations.
gradient_accumulation_steps=4,
num_generations=16, # Number of generations per prompt.
max_prompt_length=256,
max_completion_length=786,
num_train_epochs=1,
save_steps=100,
max_grad_norm=0.1,
report_to="none",
log_on_each_node=False,
)
# PEFT configuration using LoRA.
peft_config = LoraConfig(
r=16,
lora_alpha=64,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
Note: In the import section we used
LoraConfiga
which might be a typo. Ensure you import and useLoraConfig
correctly frompeft
.
Step 7: Initialize the GRPO Trainer
We now instantiate the GRPOTrainer with our model, tokenizer (passed as the processing_class
), reward functions, training configuration, and dataset. The reward functions are applied to each generation to provide fine-grained feedback.
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func
],
args=training_args,
train_dataset=dataset,
# peft_config=peft_config # Uncomment this line to enable LoRA-based parameter-efficient fine-tuning.
)
Step 8: Start the Training Process
Finally, call the train()
method on the trainer to begin fine-tuning the model using GRPO.
trainer.train()
Method 2
1. Installing Required Packages
First, install the latest versions of transformers
and accelerate
(directly from GitHub) along with datasets
and huggingface-hub
:
!pip install -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/accelerate.git
!pip install -q datasets huggingface-hub
2. Loading and Testing the Base Model
We load a pre-trained SmolLM model and perform a simple inference to verify the setup:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
checkpoint = "HuggingFaceTB/SmolLM2-360M-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
messages = [
{"role": "system", "content": "Please respond in this specific format ONLY:\n<reasoning>\n input your reasoning behind your answer in between these reasoning tags.\n</reasoning>\n<answer>\nyour answer in between these answer tags.\n</answer>\n"},
{"role": "user", "content": "How to add two numbers in Python?\n"},
]
input_text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=256, temperature=0.2, top_p=0.9, do_sample=True, use_cache=False)
print(tokenizer.decode(outputs[0]))
3. Defining Helper Functions
We define helper functions for processing prompts and responses, including extracting user queries, assistant responses, and XML-formatted answers. These functions will be used both during training and inference.
import re
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import defaultdict
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup
# Reasoning Instruction
SYSTEM_PROMPT = """
A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <thinking> </thinking> and
<answer> </answer> tags, respectively, i.e., <thinking> reasoning process here </thinking><answer> answer here </answer>.
Response Format rules:
- Always start your response with <thinking> tag and end with </answer>.
- Do not include any text or commentary before the opening <thinking> tag or after the closing </answer> tag.
- Do not include any text or commentary between the closing </thinking> tag and the opening <answer> tag.
For example, your response follow this format:
<thinking>
[Your detailed chain-of-thought goes here]
</thinking>
<answer>
[Your final answer goes here]
</answer>
"""
# Helpers
def get_user_prompt(prompt: str) -> str:
match = re.search(r"<\|im_start\|>user\s*(.*?)\s*<\|im_end\|>", prompt, re.DOTALL)
if match:
return match.group(1).strip()
lines = prompt.splitlines()
result = []
for line in lines:
if not line.strip().lower().startswith("system"):
if line.strip().lower().startswith("user"):
result.append(line.strip()[4:].strip())
else:
result.append(line)
return "\n".join(result).strip()
def get_assistant_response(text: str) -> str:
match = re.search(r"<\|im_start\|>assistant\s*(.*?)\s*<\|im_end\|>", text, re.DOTALL)
if match:
return match.group(1).strip()
lines = text.splitlines()
result = []
capture = False
for line in lines:
stripped = line.strip()
if stripped.lower().startswith("assistant"):
capture = True
continue
if capture:
result.append(line)
return "\n".join(result).strip()
def extract_xml_answer(text: str) -> str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
return answer.strip()
def extract_hash_answer(text: str) -> str:
if "####" not in text:
return text.strip()
return text.split("####", 1)[1].strip()
def count_xml(text: str) -> float:
count = 0.0
if text.count("<thinking>\n") == 1:
count += 0.225
if text.count("\n</thinking>\n") == 1:
count += 0.225
if text.count("\n<answer>\n") == 1:
count += 0.225
count -= len(text.split("\n</answer>")[-1]) * 0.001
if text.count("\n</answer>\n") == 1:
count += 0.225
count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
return count
def inference(prompt: str, model_path: str) -> str:
device = config.device
model_infer = AutoModelForCausalLM.from_pretrained(model_path).to(device)
tokenizer_infer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer_infer(prompt, return_tensors="pt", max_length=config.max_prompt_length, truncation=False)
outputs = model_infer.generate(
inputs["input_ids"].to(device),
attention_mask=inputs["attention_mask"].to(device),
max_new_tokens=config.max_completion_length,
do_sample=True,
pad_token_id=tokenizer_infer.eos_token_id,
temperature=config.temperature,
num_return_sequences=1,
use_cache=False
)
full_text = tokenizer_infer.decode(outputs[0])
user_question = get_user_prompt(prompt)
assistant_response = get_assistant_response(full_text)
extracted_answer = extract_xml_answer(assistant_response)
return f"{'='*10} Inference {'='*10}\nQuestion:\n{user_question}\n\nModel Response:\n{assistant_response}\n\nExtracted:\n{extracted_answer}\n{'='*12} End {'='*12}\n"
4. Defining Reward Functions
These functions assign rewards based on various criteria: the length and quality of the reasoning (inside <thinking>
tags), accuracy of the final answer, format adherence, XML tag counts, and whether the answer is an integer. These rewards will guide the GRPO updates.
# Rewards
def reasoning_reward(prompts, completions, answer, **kwargs) -> list:
rewards = []
transition_words = ["first", "next", "then", "because", "wait", "aha", "therefore", "finally", "in summary"]
pattern = r"<\s*thinking\s*>(.*?)<\s*/\s*thinking\s*>"
for comp in completions:
match = re.search(pattern, comp, re.DOTALL | re.IGNORECASE)
if match:
reasoning_text = match.group(1).strip()
words = reasoning_text.split()
reward = 0.0
# base reward if at least 25 words in between <thinking> </thinking> tags
if len(words) >= 25:
reward += 0.25
lower_text = reasoning_text.lower()
# transition words reward (case-insensitive)
transition_count = sum(1 for word in transition_words if word in lower_text)
if transition_count > 0:
reward += 0.5
# bonus reward if there are at least 30 words
if len(words) >= 50:
reward += 0.35
rewards.append(reward)
else:
rewards.append(0.0)
return rewards
def accuracy_reward(prompts, completions, answer, num_generated_samples_to_view=False, q_num=None, **kwargs) -> list:
q = prompts[0]
user_question = get_user_prompt(q)
assistant_responses = [get_assistant_response(r) for r in completions]
extracted_responses = [extract_xml_answer(get_assistant_response(r)) for r in completions]
if num_generated_samples_to_view:
print(f"{'='*15} Sample {q_num} {'='*15}\nQuestion:\n{user_question}\n\nAnswer:\n{answer[0]}\n\nResponse:\n{assistant_responses[0]}\n\nExtracted:\n{extracted_responses[0]}\n{'='*18} End {'='*18}\n")
return [2.0 if r.strip() == a.strip() else 0.0 for r, a in zip(extracted_responses, answer)]
def soft_format_reward(completions, **kwargs) -> list:
pattern = r"<thinking>.*?</thinking>\s*<answer>.*?</answer>"
return [0.5 if re.search(pattern, comp, re.DOTALL) else 0.0 for comp in completions]
def strict_format_reward(completions, **kwargs) -> list:
pattern = r"^<thinking>\n.*?\n</thinking>\n<answer>\n.*?\n</answer>\n$"
return [1.0 if re.fullmatch(pattern, comp) else 0.0 for comp in completions]
def xmlcount_reward(prompts, completions, answer, **kwargs) -> list:
return [count_xml(comp) * 0.5 for comp in completions]
def int_reward(completions, **kwargs) -> list:
return [0.5 if get_assistant_response(comp).strip().isdigit() else 0.0 for comp in completions]
5. Setting Up GRPO Configuration and Trainer
We now define a configuration class (GRPOConfig
) to hold our training parameters and a GRPOTrainer
class that implements the training loop using GRPO. The trainer handles generating completions, calculating rewards, computing the surrogate loss with a KL-penalty, and updating the model.
# GRPO Config
class GRPOConfig:
def __init__(self, **kwargs):
self.output_dir = kwargs.get("output_dir", "outputs")
self.run_name = kwargs.get("run_name", "custom_grpo")
self.learning_rate = kwargs.get("learning_rate", 1e-5)
self.weight_decay = kwargs.get("weight_decay", 0.01)
self.warmup_steps = kwargs.get("warmup_steps", 50)
self.num_generations = kwargs.get("num_generations", 1)
self.max_prompt_length = kwargs.get("max_prompt_length", 256)
self.max_completion_length = kwargs.get("max_completion_length", 256)
self.num_train_epochs = kwargs.get("num_train_epochs", 1)
self.gradient_accumulation_steps = kwargs.get("gradient_accumulation_steps", 1)
self.clip_epsilon = kwargs.get("clip_epsilon", 0.2)
self.beta = kwargs.get("beta", 0.01)
self.logging_steps = kwargs.get("logging_steps", 1)
self.save_steps = kwargs.get("save_steps", 50)
self.max_steps = kwargs.get("max_steps", 1000)
self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
self.temperature = kwargs.get("temperature", 0.2)
self.num_generated_samples_to_view = kwargs.get("num_generated_samples_to_view", 10)
self.bf16 = kwargs.get("bf16", True)
self.per_device_train_batch_size = kwargs.get("per_device_train_batch_size", 4)
self.use_flash_attn_2 = kwargs.get("use_flash_attn_2", False)
self.use_vllm = kwargs.get("use_vllm", False)
self.vllm_device = kwargs.get("vllm_device", "cuda:0")
self.vllm_gpu_memory_utilization = kwargs.get("vllm_gpu_memory_utilization", 0.8)
self.vllm_dtype = kwargs.get("vllm_dtype", "float16")
self.vllm_max_model_len = kwargs.get("vllm_max_model_len", 512)
# GRPO Trainer
class GRPOTrainer:
def __init__(self, model, tokenizer, reward_funcs, config, train_dataset):
self.dataloader = DataLoader(train_dataset, batch_size=config.per_device_train_batch_size, shuffle=True, collate_fn=lambda x: x)
self.model = model.to(config.device)
self.tokenizer = tokenizer
self.reward_funcs = reward_funcs
self.config = config
self.train_dataset = train_dataset
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
total_steps = (len(train_dataset) // config.per_device_train_batch_size) * config.num_train_epochs
self.scheduler = get_cosine_schedule_with_warmup(self.optimizer,
num_warmup_steps=config.warmup_steps,
num_training_steps=total_steps)
self.ref_model = AutoModelForCausalLM.from_pretrained(model.config._name_or_path)
self.ref_model.to(config.device)
self.ref_model.eval()
for param in self.ref_model.parameters():
param.requires_grad = False
self.step = 0
self._metrics = defaultdict(list)
self.scaler = torch.cuda.amp.GradScaler() if config.device.startswith("cuda") else None
def get_per_token_logps(self, model, full_ids, attention_mask, num_logits_to_keep):
outputs = model(input_ids=full_ids, attention_mask=attention_mask)
logits = outputs.logits[:, :-1, :] # Exclude the last logit
logits_slice = logits[:, -num_logits_to_keep:, :]
token_ids = full_ids[:, -num_logits_to_keep:]
log_probs = torch.log_softmax(logits_slice, dim=-1)
token_log_probs = log_probs.gather(dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)
return token_log_probs
def compute_loss(self, input_ids, generation_output, advantages, old_logps, attention_mask):
num_logits_to_keep = generation_output.shape[1] - input_ids.shape[1]
full_ids = generation_output
# Compute current log probabilities from the updated model
per_token_logps = self.get_per_token_logps(self.model, full_ids, attention_mask, num_logits_to_keep)
with torch.no_grad():
ref_per_token_logps = self.get_per_token_logps(self.ref_model, full_ids, attention_mask, num_logits_to_keep)
# KL divergence per token (using Schulman et al.'s approximation)
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# Compute mask for valid tokens via EOS detection
completion_ids = full_ids[:, input_ids.shape[1]:]
is_eos = (completion_ids == self.tokenizer.eos_token_id)
batch_size, seq_len = is_eos.size()
device = input_ids.device
eos_idx = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
for i in range(batch_size):
nonzero = torch.nonzero(is_eos[i], as_tuple=False)
if nonzero.numel() > 0:
eos_idx[i] = nonzero[0, 0]
sequence_indices = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
mask = (sequence_indices <= eos_idx.unsqueeze(1)).float()
# Calculate policy ratio using stored old log probabilities
ratio = torch.exp(per_token_logps - old_logps)
clipped_ratio = torch.clamp(ratio, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon)
# Clipped surrogate objective
surrogate_loss = -torch.min(ratio * advantages.unsqueeze(1), clipped_ratio * advantages.unsqueeze(1))
# Add KL penalty term
per_token_loss = surrogate_loss + self.config.beta * per_token_kl
loss = ((per_token_loss * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-8)).mean()
mean_kl = (per_token_kl * mask).sum(dim=1).mean().item()
completion_length = mask.sum(dim=1).mean().item()
return loss, mean_kl, completion_length
def evaluate_rewards(self, prompt, completions, gt_answer):
rewards_dict = {}
for func in self.reward_funcs:
if func.__name__ in ["accuracy_reward", "xmlcount_reward", "reasoning_reward"]:
r = func([prompt] * len(completions), completions, [gt_answer] * len(completions))
else:
r = func(completions)
rewards_dict[func.__name__] = r
combined_rewards = [sum(rewards_dict[func_name][i] for func_name in rewards_dict)
for i in range(len(completions))]
return combined_rewards, rewards_dict
def train(self):
self.model.train()
accumulation_counter = 0
for epoch in range(self.config.num_train_epochs):
for batch in self.dataloader:
if self.step >= self.config.max_steps:
break
example = batch[0]
prompts = example["prompts"]
gt_answer = example["answer"]
prompt_text = self.tokenizer.apply_chat_template(prompts, tokenize=False)
inputs = self.tokenizer(prompt_text, return_tensors="pt", max_length=self.config.max_prompt_length, truncation=False)
input_ids = inputs.input_ids.to(self.config.device)
attention_mask = inputs.attention_mask.to(self.config.device)
with torch.autocast(
device_type=self.config.device,
enabled=(self.scaler is not None),
dtype=(torch.bfloat16 if self.config.bf16 else torch.float16)
):
generation_output = self.model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=self.config.max_completion_length,
do_sample=True,
temperature=self.config.temperature,
num_return_sequences=self.config.num_generations,
pad_token_id=self.tokenizer.eos_token_id,
use_cache=False
)
generation_output = generation_output.to(self.config.device)
completions = [self.tokenizer.decode(seq, skip_special_tokens=True) for seq in generation_output]
completions = [c.replace(prompt_text, "").strip() if prompt_text in c else c for c in completions]
num_gens = len(completions)
view_flag = (self.step < self.config.num_generated_samples_to_view)
acc_rewards = accuracy_reward([prompt_text]*num_gens, completions, [gt_answer]*num_gens,
num_generated_samples_to_view=view_flag, q_num=self.step)
combined_rewards, reward_dict = self.evaluate_rewards(prompt_text, completions, gt_answer)
rewards_tensor = torch.tensor(combined_rewards, device=self.config.device, dtype=torch.float)
reward_avg = rewards_tensor.mean().item()
reward_std = rewards_tensor.std().item() if rewards_tensor.numel() > 1 else 0.0
reasoning_rewards = reward_dict.get("reasoning_reward", [0.0]*len(completions))
reasoning_reward_avg = sum(reasoning_rewards) / len(reasoning_rewards)
if self.config.num_generations > 1:
rewards_grouped = rewards_tensor.view(-1, self.config.num_generations)
mean_rewards = rewards_grouped.mean(dim=1)
std_rewards = rewards_grouped.std(dim=1) + 1e-4
advantages = (rewards_tensor - mean_rewards.repeat_interleave(self.config.num_generations)) / std_rewards.repeat_interleave(self.config.num_generations)
else:
advantages = rewards_tensor
advantages = torch.clamp(advantages, -5.0, 5.0)
num_logits_to_keep = generation_output.shape[1] - input_ids.shape[1]
old_logps = self.get_per_token_logps(self.model, generation_output, attention_mask, num_logits_to_keep).detach()
loss, mean_kl, completion_length = self.compute_loss(input_ids, generation_output, advantages, old_logps, attention_mask)
loss = loss / self.config.gradient_accumulation_steps
if self.scaler is not None:
self.scaler.scale(loss).backward()
else:
loss.backward()
accumulation_counter += 1
if accumulation_counter % self.config.gradient_accumulation_steps == 0:
if self.scaler is not None:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
self.scaler.step(self.optimizer)
self.scaler.update()
else:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
accumulation_counter = 0
self._metrics["loss"].append(loss.item() * self.config.gradient_accumulation_steps)
self._metrics["completion_length"].append(completion_length)
self._metrics["reward"].append(reward_avg)
self._metrics["reward_std"].append(reward_std)
self._metrics["accuracy_reward"].append(sum(acc_rewards))
self._metrics["reasoning_reward"].append(reasoning_reward_avg)
self._metrics["kl"].append(mean_kl)
# Print without reasoning reward
print(f"Step {self.step} | Loss: {loss.item()*self.config.gradient_accumulation_steps:.4f} | Reward: {reward_avg:.4f} | Reward Std: {reward_std:.4f} | Completion Length: {completion_length:.4f} | KL: {mean_kl:.4f}\n")
self.step += 1
if self.step % self.config.save_steps == 0:
checkpoint_path = os.path.join(self.config.output_dir, f"checkpoint-{self.step}")
os.makedirs(checkpoint_path, exist_ok=True)
self.model.save_pretrained(checkpoint_path)
self.tokenizer.save_pretrained(checkpoint_path)
print(f"Checkpoint saved to {checkpoint_path}\n")
test_messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": "Which is heavier 1k of steel or 1kg of wool?"}
]
test_prompt = self.tokenizer.apply_chat_template(test_messages, tokenize=False)
inf_result = inference(test_prompt, checkpoint_path)
print(inf_result)
if self.step >= self.config.max_steps:
break
if self.step >= self.config.max_steps:
break
final_model_path = os.path.join(self.config.output_dir, "final_model")
os.makedirs(final_model_path, exist_ok=True)
self.model.save_pretrained(final_model_path)
self.tokenizer.save_pretrained(final_model_path)
print(f"Final model saved to {final_model_path}")
plt.figure(figsize=(14, 10))
plt.subplot(3, 2, 1)
plt.plot(self._metrics["accuracy_reward"], label="Accuracy", color="blue")
plt.title("Accuracy vs Steps")
plt.xlabel("Steps")
plt.ylabel("Accuracy")
plt.legend()
plt.subplot(3, 2, 2)
plt.plot(self._metrics["reward"], label="Reward", color="green")
plt.title("Reward vs Steps")
plt.xlabel("Steps")
plt.ylabel("Reward")
plt.legend()
plt.subplot(3, 2, 3)
plt.plot(self._metrics["reward_std"], label="Reward Std", color="orange")
plt.title("Reward Std vs Steps")
plt.xlabel("Steps")
plt.ylabel("Reward Std")
plt.legend()
plt.subplot(3, 2, 4)
plt.plot(self._metrics["kl"], label="KL Penalty", color="red")
plt.title("KL Penalty vs Steps")
plt.xlabel("Steps")
plt.ylabel("KL Penalty")
plt.legend()
plt.subplot(3, 2, 5)
plt.plot(self._metrics["completion_length"], label="Avg Completion Length", color="purple")
plt.title("Avg Completion Length vs Steps")
plt.xlabel("Steps")
plt.ylabel("Completion Length")
plt.legend()
# plt.subplot(3, 2, 6)
# plt.plot(self._metrics["reasoning_reward"], label="Reasoning Reward", color="brown")
# plt.title("Reasoning Reward vs Steps")
# plt.xlabel("Steps")
# plt.ylabel("Reasoning Reward")
# plt.legend()
plt.tight_layout()
plt.show()
6. Preparing the GSM8K Dataset
For our training data, we use the GSM8K dataset (a collection of math questions). We reformat the data so that each example contains a prompt (with the system prompt and user question) and the corresponding answer.
# GSM8K Dataset & Chat Temp
def get_gsm8k_data(split="train") -> Dataset:
data = load_dataset('openai/gsm8k', 'main')[split]
data = data.map(lambda x: {
'prompts': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
})
return data
dataset = get_gsm8k_data()
7. Configuring the Trainer and Model
Set up the GRPO configuration and load the SmolLM model (using a smaller 135M variant) for fine-tuning. Also, specify the reward functions to be used.
# Trainer and Config
config = GRPOConfig(
output_dir="outputs/SmolLM2_135M_Grpo_Gsm8k",
run_name="smollm2_135m_grpo_gsm8k_reasoner",
learning_rate=5e-6,
weight_decay=0.01,
warmup_steps=100,
num_generations=2,
max_prompt_length=256,
max_completion_length=200,
num_train_epochs=1,
gradient_accumulation_steps=1,
clip_epsilon=0.2,
beta=0.01,
logging_steps=1,
save_steps=250,
max_steps=500,
temperature=0.2,
num_generated_samples_to_view=250,
bf16=True,
per_device_train_batch_size=1,
# use_flash_attn_2=True, # Enable Flash Attention 2 (GPU only)
# use_vllm=True, # use vLLM (GPU only)
# vllm_device="cuda:0", # vLLM device config (GPU only)
# vllm_gpu_memory_utilization=0.3 # vLLM GPU memory utilization (GPU only)
)
model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2" if config.use_flash_attn_2 else None,
use_cache=False
).to("cuda" if torch.cuda.is_available() else "cpu")
tokenizer.pad_token = tokenizer.eos_token
reward_functions = [reasoning_reward, accuracy_reward, soft_format_reward, strict_format_reward, int_reward, xmlcount_reward]
trainer = GRPOTrainer(model, tokenizer, reward_functions, config, dataset)
8. Training the Model
Start the training process. The trainer will generate multiple completions, compute rewards, update the model with the GRPO loss (including KL-penalty), and save checkpoints along the way.
# Train
trainer.train()
9. Inference with the Fine-Tuned Model
After training, test the fine-tuned model on a new question.
sample = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": "If there are 12 cookies in a dozen and you have 5 dozen, how many cookies do you have?"}
]
final_prompt = tokenizer.apply_chat_template(sample, tokenize=False)
print(inference(final_prompt, os.path.join(config.output_dir, "final_model")))
10. Uploading the Model to the Hugging Face Hub
Finally, log into your Hugging Face account and push the trained model to a repository.
# Login Hf and Push to Repo
from huggingface_hub import notebook_login
notebook_login()
# Import the HfApi class from the huggingface_hub library.
from huggingface_hub import HfApi
api = HfApi()
repo_id = f"prithivMLmods/SmolLM2_135M_Grpo_Checkpoint"
try:
# Attempt to create a new repository on the Hugging Face Model Hub using the specified repo_id.
api.create_repo(repo_id)
print(f"Repo {repo_id} created")
except:
print(f"Repo {repo_id} already exists")
api.upload_folder(
folder_path="outputs/SmolLM2_135M_Grpo_Gsm8k/final_model", # The path to the folder to be uploaded
path_in_repo=".", # The path where the folder will be stored in the repository
repo_id=repo_id, # The ID of the repository where the folder will be uploaded
repo_type="model", # The type of the repository (in this case, a model repository)
revision="main" # Revision name
)
Conclusion
GRPO is a powerful reinforcement learning technique for fine-tuning language models. By leveraging group-based relative rewards and KL-divergence regularization, it enables stable and efficient learning while encouraging high-quality, structured outputs. Its flexibility makes it suitable for a wide range of tasks, from reasoning to instruction following.
By following these steps and using the provided code, you can customize SmolLM to enhance reasoning, format adherence, and overall answer accuracy using the GRPO approach.
Happy fine-tuning! 🤗