Llama 3.1-8B Mathematical Reasoning (GRPO)
Full-Post-Training Instruction
Please visit our notebook for a full walkthrough on this project.
Model Description
This model is a fine-tuned version of meta-llama/Meta-Llama-3.1-8B-Instruct trained using Group Relative Policy Optimization (GRPO) to enable autonomous mathematical reasoning capabilities. Unlike traditional supervised fine-tuning approaches, this model learned to reason through reinforcement learning without requiring pre-labeled reasoning examples or teacher model distillation.
Key Features
- Structured Reasoning: Generates step-by-step mathematical reasoning enclosed in
<thinking>
tags - Formatted Answers: Provides final numerical answers in
<answer>
tags - Autonomous Learning: Trained using reward functions rather than human-annotated reasoning examples
- Mathematical Focus: Optimized for grade school math word problems requiring multi-step reasoning
Training Details
Training Pipeline
GRPO Trainer Workflow
While setting up a GRPO config is straightforward, understanding what each hyperparameter means and how the underlying training logic works is essential. This guide walks through the GRPO training loop, highlighting the key code paths in trl/trainer/grpo_trainer.py.
GRPO Trainer in Context
GRPO is built on top of the TRL Trainer class, but it customizes key parts of the training pipeline by overriding several critical methods:
compute_loss
prepare_input
These functions are the heart of GRPO's training logic.
Training Workflow
train_step Process
In each train_step
:
Trainer.train_step()
in the parent class first callsprepare_input()
in the GRPO trainer- Then it calls
compute_loss()
with the output ofprepare_input()
Input Preparation Logic
High-Level Behavior
- Called once per
self.num_iterations
to generate fresh completions - Buffering is used to reuse completions across multiple gradient updates
Example Configuration:
num_iterations = 4
: generate completions every 4 global steps. Each global step can consist of k grad accumulation stepsgradient_accumulation_steps = 5
: buffer stores 5 sets of completions- If
num_iterations = 1
, completions are not reused, and buffering may be unnecessary
Core Logic
@profiling_decorator
def _prepare_inputs(self, inputs: dict) -> dict:
mode = "eval" if self.control.should_evaluate else "train"
if mode == "train":
if self.state.global_step % self.num_iterations == 0:
inputs = self._generate_and_score_completions(inputs)
self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
else:
inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
self._step += 1
else:
inputs = self._generate_and_score_completions(inputs)
return inputs
Prompt Construction
Conversations are stored in the "prompt" field. The apply_chat_template()
function is used to convert this into a generation-ready string:
prompt = tokenizer.apply_chat_template(
example["prompt"],
tools=tools,
tokenize=False,
add_generation_prompt=True
)
Inference and Completion Generation
Non-vLLM Path
For the non-vLLM implementation path:
- Each rank tokenizes its prompts and truncates them by
max_prompt_length
- Each rank sends its text prompts to the on-device LLM
Example Output: The last 5 tokens for 6 generations made on GPU0. With batched inference, we append EOS (128009) for early completed answers:
prompt_completion_ids[:,-5:]
# Output tensor:
tensor([[128009, 128009, 128009, 128009, 128009],
[128009, 128009, 128009, 128009, 128009],
[128009, 128009, 128009, 128009, 128009],
[ 279, 1176, 220, 605, 14741],
[ 220, 430, 568, 11021, 220],
[128009, 128009, 128009, 128009, 128009]], device='cuda:0')
Note: You can check the code for vLLM path which is more efficient.
Completion Masking
Completion masks identify valid (non-EOS) completions. Each rank computes completion masks based on the first occurrence of EOS.
Example Mask: The following is the mask for the completions shown above:
completion_mask[:,-5:]
# Output tensor:
tensor([[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[0, 0, 0, 0, 0]], device='cuda:0', dtype=torch.int32)
Reward Computation and Normalization
Each rank performs the following operations:
- Decodes completions
- Computes reward for (prompt, completion) pairs
- Gathers rewards from other ranks (because it's possible for a given prompt to have its replica across GPUs)
- Normalizes rewards by mean/std ⟹ This gives us advantages $A(s,a)$
- Discards completions for prompts it doesn't own (called alien prompts)
Concrete Example: Multi-GPU Setup
Setup:
- 3 GPUs (Ranks)
- Batch size per GPU: 8
- Number of generations per prompt: 6
Total Prompts Processed per Iteration:
- Total batch size = 3 GPUs × 8 prompts = 24 prompts
- Since each prompt needs 6 completions, we can only afford to have: 24 / 6 = 4 unique prompts
- Each prompt is replicated 6 times across all ranks
Prompt Distribution Example on GPU 0 (Rank 0):
- 8 total prompts:
- 6 replicas of Prompt #1
- 2 replicas of Prompt #2
Note: Other ranks may hold the remaining replicas of Prompt #2 and additional prompts.
Reward Normalization Process:
For Prompt #1:
- All 6 completions are on GPU 0
- ✅ Rank 0 can compute mean and std of rewards locally
For Prompt #2:
- Rank 0 has only 2 out of 6 completions
- ❌ It cannot compute accurate reward statistics alone
- ✅ Needs to gather the remaining 4 rewards from other ranks
Important: That's why all-gather is needed so each rank has access to the required replicas.
Logprob Computation
- Rollouts (prompts + completions) are stored per grad accumulation step (
step % G
) during the first iteration - These buffered inputs are reused for the remaining iterations within that
num_iterations
window - Buffering occurs only during the first
num_iterations
step (i.e., atglobal_step % num_iterations == 0
)
Key Points:
- Logprobs for old policy and ref policy are computed during the first iteration only (when completions are freshly generated)
- Logprobs for current policy are computed in every gradient accumulation step, including when using buffered inputs
Note: Old and current policy logprobs will be identical when
num_iterations = 1
.
Loss Computation
GRPO Loss Formula
compute_loss() Breakdown
Steps:
Concatenate
prompt_ids + completion_ids
Run forward pass through old policy to compute $\pi_{\text{old}}(a|s)$
- This actually happens only once at the first iteration when we create the rollout
Run forward pass through ref policy to compute $\pi_{\text{ref}}(a|s)$
- This actually happens only once at the first iteration when we create the rollout
- Ref model is the original model without LoRA adapters
Run forward pass through current policy to compute $\pi(a|s)$
- Needed only if
num_iterations > 1
; otherwise the same as old policy
- Needed only if
Compute KL loss between $\pi(a|s)$ and $\pi_{\text{ref}}(a|s)$
Compute advantage-weighted logprobs: $\frac{\pi(a|s)}{\pi_{\text{old}}(a|s)} \times A(s,a)$
Workflow Summary
Component | What It Does | Why It Matters |
---|---|---|
GRPO Trainer Class | Extends the TRL Trainer, overrides prepare_input and compute_loss |
Customizes the loss and input preparation for rollout reuse and reward learning |
train_step | Calls prepare_input then compute_loss |
Controls how and when rollouts and gradients are processed |
prepare_input() | Generates and buffers rollouts once every num_iterations steps |
Allows reuse of expensive rollouts across multiple updates |
Prompt Construction | Uses apply_chat_template to create generation-ready input |
Ensures the model sees correctly formatted conversational prompts |
Inference | Uses on-device or vLLM backend to generate completions | Provides actions for which logprobs and rewards are computed |
Completion Masking | Identifies valid (non-EOS) completions | Ensures reward/logprob computation only applies to meaningful tokens |
Reward Normalization | Aggregates and normalizes rewards across GPUs for each prompt | Yields correct advantage estimates across distributed setup |
Logprob Computation | Computes logprobs of old, ref, and current policy | Needed for KL and advantage-weighted loss; reused if num_iterations > 1 |
compute_loss() | Combines KL divergence and advantage-weighted logprob ratio | Drives the optimization update direction for the policy |
- Up to here, we are only walking through the key code paths in
trl/trainer/grpo_trainer.py
to help understand the GRPO training workflow.*
Training Method
- Algorithm: Group Relative Policy Optimization (GRPO)
- Base Model: meta-llama/Meta-Llama-3.1-8B-Instruct
- Fine-tuning: LoRA with NF4 quantization
- Dataset: GSM8K (Grade School Math 8K)
Training Configuration
- LoRA rank: 8
- LoRA alpha: 32
- Learning rate: 2e-4
- Batch size: 32 (effective)
- Generations per prompt: 2-4
- Max completion length: 128-256 tokens
- Temperature: 0.7
Reward Functions
The model was trained using multiple custom reward functions:
- Format compliance: Ensures proper
<thinking>
and<answer>
tag structure - Answer correctness: Validates numerical accuracy against ground truth
- Integer validation: Confirms answers are properly formatted integers
- Output cleanliness: Penalizes extraneous text after answer tags
Usage
Required Format
The model expects a specific system prompt to maintain the structured output format:
SYSTEM_PROMPT = """
Respond in the following format and make sure the entire response is wrapped in <reasoning> and <answer> tags with no other text outside of these tags:
<thinking>
your reasoning goes here and do not use newlines so that the entire reasoning becomes a single paragraph
</thinking>
<answer>
your answer goes here
</answer>
"""
Example Usage
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("your-username/llama3-grpo-math-reasoning")
tokenizer = AutoTokenizer.from_pretrained("your-username/llama3-grpo-math-reasoning")
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": "Mark is looking to buy a total of 12 pieces of fruit at the store. He has already chosen 3 apples. He has also selected a bunch of bananas containing 4 bananas. How many oranges does he need to pick out to have 12 total pieces of fruit?"}
]
input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=128,
temperature=0.7,
top_p=0.9,
do_sample=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
Expected Output:
<thinking>
Mark wants 12 total pieces of fruit. He already has 3 apples and 4 bananas, which gives him 3 + 4 = 7 pieces of fruit. To reach 12 total pieces, he needs 12 - 7 = 5 more pieces of fruit, which would be oranges.
</thinking>
<answer>
5
</answer>
Performance
- Dataset: GSM8K test split (1,319 examples)
- Evaluation Metric: Exact match accuracy on final numerical answers
- Performance: [Insert actual accuracy from your evaluation]
Technical Details
Model Architecture
- Base: Llama 3.1-8B-Instruct
- Parameters: ~8B total, ~21M trainable (LoRA)
- Quantization: NF4 with double quantization
- Precision: bfloat16
GRPO Specifics
- Group size: 2-4 completions per prompt
- KL coefficient (β): 0.02
- Advantage normalization: Group-based mean/std normalization
- Policy updates: Advantage-weighted importance sampling
Limitations
- Domain: Primarily trained on grade school mathematics; may not generalize to advanced mathematical concepts
- Format dependency: Requires specific system prompt for optimal performance
- Language: English only
- Reasoning scope: Optimized for multi-step arithmetic rather than complex mathematical proofs
Ethical Considerations
- Educational use: Designed to assist with mathematical learning, not replace mathematical education
- Accuracy: While trained for correctness, users should verify mathematical solutions for critical applications
- Bias: Inherits potential biases from the base Llama model and GSM8K dataset
Citation
If you use this model, please cite:
@misc{llama3-grpo-math-reasoning,
title={Llama 3.1-8B Mathematical Reasoning via Group Relative Policy Optimization},
author={[Your Name]},
year={2025},
howpublished={\\url{https://huggingface.co/your-username/llama3-grpo-math-reasoning}}
}
Acknowledgments
- Meta AI for the Llama 3.1 base model
- OpenAI for the GSM8K dataset
- Hugging Face TRL library for GRPO implementation
- The reinforcement learning and mathematical reasoning research communities
License
This model inherits the license from meta-llama/Meta-Llama-3.1-8B-Instruct. Please refer to the original model's license for usage terms.
- Downloads last month
- 24
Model tree for lordChipotle/Llama3GRPOReasoning
Base model
meta-llama/Llama-3.1-8B