TRL documentation
Reward Functions
This module contains some useful reward functions, primarily intended for use with the GRPOTrainer and RLOOTrainer.
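Custom reward functions follow the same calling convention as the built-ins below: they receive `completions` (a list of single-message lists) plus arbitrary `**kwargs`, and return one float per completion. As a minimal sketch, assuming nothing beyond that contract (the function name and keyword check here are illustrative, not part of TRL):

```python
# Hypothetical custom reward function (illustrative, not part of TRL).
# It follows the trainer contract: take `completions` and **kwargs,
# return a list with one float per completion.
def keyword_reward(completions, **kwargs):
    """Toy reward: 1.0 if the completion contains 'boxed', else 0.0."""
    rewards = []
    for completion in completions:
        # Each completion is a list of one message dict with a "content" key.
        text = completion[0]["content"]
        rewards.append(1.0 if "boxed" in text else 0.0)
    return rewards

completions = [
    [{"role": "assistant", "content": r"My answer is \boxed{42}"}],
    [{"role": "assistant", "content": "My answer is 42"}],
]
print(keyword_reward(completions))  # [1.0, 0.0]
```

Any function with this signature can be passed to a trainer alongside or instead of the built-in rewards.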
accuracy_reward
trl.rewards.accuracy_reward
accuracy_reward(completions: list, solution: list, **kwargs)
Parameters
- completions (list[list[dict[str, str]]]) — List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary containing the key "content" with the value being the text of the completion.
- solution (list[str]) — List of the raw-text solutions to the questions/problems/prompts.
- **kwargs — Additional keyword arguments. This function does not use them, but they are required in the function signature to ensure compatibility with trainers like GRPOTrainer.
Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If gold is not parseable → return None to skip the example.
Example:
>>> from trl.rewards import accuracy_reward
>>> solutions = [r"\frac{1}{3}", r"\frac{1}{3}"]
>>> completions = [
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
... ]
>>> accuracy_reward(completions, solutions)
[1.0, 0.0]
reasoning_accuracy_reward
trl.rewards.reasoning_accuracy_reward
reasoning_accuracy_reward(completions: list, solution: list, reasoning_delimiters: list[str] | None = None, **kwargs)
Parameters
- completions (
- completions (list[list[dict[str, str]]]) — List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary containing the key "content" with the value being the text of the completion.
- solution (list[str]) — List of the raw-text solutions to the questions/problems/prompts.
- reasoning_delimiters (list[str], optional) — List of strings indicating where the reasoning content ends. The final answer is assumed to be after the last occurrence of any of these delimiters. If None, defaults to ["</think>"].
- **kwargs — Additional keyword arguments. This function does not use them, but they are required in the function signature to ensure compatibility with trainers like GRPOTrainer.
Reward function that removes the reasoning content and checks if the final answer matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If gold is not parseable → return None to skip the example.
Example:
>>> from trl.rewards import reasoning_accuracy_reward
>>> reasoning_delimiters = ["</think>"]
>>> solutions = [r"\frac{1}{3}", r"\frac{1}{3}", r"\frac{1}{3}"]
>>> completions = [
... [
... {
... "role": "assistant",
... "content": r"<think> Reasoning content </think> The final answer is \boxed{\frac{1}{3}}",
... }
... ],
... [
... {
... "role": "assistant",
... "content": r"<think> Reasoning content </think> The final answer is \boxed{\frac{1}{2}}",
... }
... ],
... [
... {
... "role": "assistant",
... "content": r"<think> Reasoning content with partial answers \boxed{\frac{1}{3}} but no final answer",
... }
... ],
... ]
>>> reasoning_accuracy_reward(completions, solutions, reasoning_delimiters=reasoning_delimiters)
[1.0, 0.0, 0.0]
think_format_reward
trl.rewards.think_format_reward
think_format_reward(completions: list, **kwargs) → list[float]
Parameters
- completions (list[list[dict[str, str]]]) — List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary containing the key "content" with the value being the text of the completion.
- **kwargs — Additional keyword arguments. This function does not use them, but they are required in the function signature to ensure compatibility with trainers like GRPOTrainer.
Returns
list[float]
A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.
Reward function that checks if the reasoning process is enclosed within "<think>" and "</think>" tags. The
function returns a reward of 1.0 if the format is correct, otherwise 0.0.
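The check can be sketched with a regular expression. This is an illustrative reimplementation, not the exact TRL source: it assumes the completion must start with "<think>", close that tag exactly once, and may then contain the answer.

```python
import re

# Illustrative format check (a sketch, not the exact TRL implementation):
# the text must begin with "<think>", contain no nested "</think>" before
# the first close tag, and may hold anything after it. DOTALL lets "."
# span newlines inside the reasoning block.
_THINK_PATTERN = re.compile(r"^<think>(?:(?!</think>).)*</think>.*$", re.DOTALL)

def check_think_format(completions):
    return [
        1.0 if _THINK_PATTERN.match(completion[0]["content"]) else 0.0
        for completion in completions
    ]

completions = [
    [{"content": "<think>\nSome reasoning.\n</think>\nThe answer."}],
    [{"content": "No think tags at all."}],
]
print(check_think_format(completions))  # [1.0, 0.0]
```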
get_soft_overlong_punishment
trl.rewards.get_soft_overlong_punishment
get_soft_overlong_punishment(max_completion_len: int, soft_punish_cache: int)
Returns a reward function that penalizes overlong completions without rewarding shorter ones. Reference: Eq. (13) of the DAPO paper (https://huggingface.co/papers/2503.14476).
Example:
from trl.rewards import get_soft_overlong_punishment
soft_overlong_punishment = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
completion_ids = [[1] * 90] # simulating a completion with 90 tokens. 90 is between 80 and 100.
rewards = soft_overlong_punishment(completion_ids)
print(rewards) # [-0.5]
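The piecewise penalty behind this can be sketched as follows, assuming the same parameter names as above (the function here is a standalone illustration of Eq. (13), not the TRL source):

```python
# Sketch of the piecewise length penalty from Eq. (13) of the DAPO paper.
# `soft_punish_cache` is the width of the soft-punishment window just
# below `max_completion_len`.
def soft_overlong_penalty(length, max_completion_len, soft_punish_cache):
    soft_limit = max_completion_len - soft_punish_cache
    if length <= soft_limit:
        return 0.0  # within budget: no penalty
    if length <= max_completion_len:
        # linear ramp from 0 down to -1 across the soft window
        return (soft_limit - length) / soft_punish_cache
    return -1.0  # past the hard limit: maximum penalty

print(soft_overlong_penalty(90, 100, 20))   # -0.5
print(soft_overlong_penalty(50, 100, 20))   # 0.0
print(soft_overlong_penalty(150, 100, 20))  # -1.0
```

This matches the example above: a 90-token completion with max_completion_len=100 and soft_punish_cache=20 lands halfway into the soft window, giving -0.5.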