TreeRPO-Qwen2.5-Math-1.5B
Summary:
A 1.5B-parameter math reasoning model fine-tuned with TreeRPO, a hierarchical extension of GRPO that assigns rewards to intermediate “thought” nodes rather than only to full completions. It achieves higher GSM8K accuracy with only ~10K supervised + RL examples and no learned reward model.
🔎 Full write-up (method, math, analysis):
TreeRPO: Hierarchical Credit Assignment for Reasoning in Language Models
Model Details
- Base model: Qwen/Qwen2.5-Math-1.5B
- Method: TreeRPO (a tree-structured extension of GRPO)
- Reward signal: Deterministic exact-match checker (binary). Interior node rewards are the mean of their descendant leaves' rewards; see the sketch after this list.
- Domain: Grade-school and intermediate math word problems (GSM8K style)
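
To make the reward signal concrete, here is a minimal sketch of TreeRPO-style credit assignment under the definition above: each leaf (a full completion) gets a binary exact-match reward, and each interior "thought" node gets the mean of its descendant leaves' rewards. `Node`, `exact_match_reward`, and `node_reward` are illustrative names, not the released training code.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    """One node in the sampling tree; leaves carry complete solutions."""
    answer: str | None = None              # final boxed answer (leaves only)
    children: list["Node"] = field(default_factory=list)

def exact_match_reward(pred: str, gold: str) -> float:
    """Deterministic binary checker: 1.0 iff the answer matches exactly."""
    return 1.0 if pred.strip() == gold.strip() else 0.0

def collect_leaves(node: Node) -> list[Node]:
    """All leaf completions in the subtree rooted at `node`."""
    if not node.children:
        return [node]
    return [leaf for child in node.children for leaf in collect_leaves(child)]

def node_reward(node: Node, gold: str) -> float:
    """Leaf: its own exact-match reward. Interior: mean over descendant leaves."""
    leaves = collect_leaves(node)
    return sum(exact_match_reward(l.answer or "", gold) for l in leaves) / len(leaves)
```

For a leaf, `collect_leaves` returns just the leaf itself, so the same mean formula covers both cases.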
Intended Use
Research on hierarchical RL for reasoning; math tutoring (with human oversight); or use as a baseline for deterministic pass/fail domains (with potential extension to code verified by unit tests).
Not intended for:
Open-ended or unsafe dialog, general factual QA, or high-stakes applications.
Evaluation (GSM8K Test Set, 1,319 problems)
| Model | Greedy (%) | Maj@8 (%) | Notes |
|---|---|---|---|
| Qwen2.5-Math-1.5B-Instruct | 84.8 | 89.5 | Reported settings |
| Qwen2.5-Math-1.5B-TreeRPO | 86.4 | 89.6 | Same decoding settings (see below) |
- Greedy: temperature = 0 (deterministic)
- Maj@8: 8 completions (temperature 0.7, top-p 0.8); majority vote on final boxed answer
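
The majority vote is taken over the final `\boxed{}` answers. A minimal sketch of that aggregation, assuming completions have already been generated; the regex and helper names are illustrative, not the evaluation harness:

```python
import re
from collections import Counter

def extract_boxed(text: str) -> str | None:
    """Pull the last \\boxed{...} answer out of a completion.
    Note: this simple regex does not handle nested braces (e.g. \\frac{1}{2})."""
    matches = re.findall(r"\\boxed\{([^{}]*)\}", text)
    return matches[-1].strip() if matches else None

def majority_vote(completions: list[str]) -> str | None:
    """Maj@k: most common boxed answer across k sampled completions."""
    answers = [a for a in (extract_boxed(c) for c in completions) if a is not None]
    return Counter(answers).most_common(1)[0][0] if answers else None
```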
How to Use
If your Transformers version supports chat templates (≥4.38), use:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "omrisap/TreeRPO-Qwen2.5-Math-1.5B"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a helpful math reasoning assistant. Provide step-by-step reasoning and put the final answer in \\boxed{}."},
    {"role": "user", "content": "If 3x + 5 = 17, what is x?"},
]

prompt_text = tok.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

inputs = tok(prompt_text, return_tensors="pt").to(model.device)
# Greedy decoding: use do_sample=False rather than temperature=0.0,
# which transformers rejects as a sampling temperature.
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tok.decode(outputs[0], skip_special_tokens=True))
```
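
For Maj@8-style decoding, sample several completions and vote over their boxed answers. A sketch reusing `model`, `tok`, and `inputs` from above with the evaluation's sampling settings; `majority_vote` is the illustrative helper sketched in the Evaluation section:

```python
# Draw 8 samples at temperature 0.7 / top-p 0.8, matching the Maj@8 setting.
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.8,
    num_return_sequences=8,
)
completions = tok.batch_decode(outputs, skip_special_tokens=True)
print(majority_vote(completions))  # most common \boxed{} answer
```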