Update README.md
README.md
@@ -196,19 +196,23 @@ $$\text{GRPO} = -\mathbb{E}_{(s,a)}\left[\frac{\pi(a|s)}{\pi_{\text{old}}(a|s)}
|
|
196 |
|
197 |
1. **Concatenate** `prompt_ids + completion_ids`
|
198 |
|
199 |
-
2. **Run forward pass through old policy** to compute
|
200 |
- This actually happens only once at the first iteration when we create the rollout
|
201 |
|
202 |
-
|
203 |
- This actually happens only once at the first iteration when we create the rollout
|
204 |
- Ref model is the original model without LoRA adapters
|
205 |
|
206 |
-
|
207 |
- Needed only if `num_iterations > 1`; otherwise the same as old policy
|
208 |
|
209 |
-
|
210 |
|
211 |
-
|
212 |
|
213 |
## Workflow Summary
|
214 |
|
@@ -305,11 +309,6 @@ Mark wants 12 total pieces of fruit. He already has 3 apples and 4 bananas, whic
|
|
305 |
</answer>
|
306 |
```
|
307 |
|
308 |
-
## Performance
|
309 |
-
|
310 |
-
- **Dataset**: GSM8K test split (1,319 examples)
|
311 |
-
- **Evaluation Metric**: Exact match accuracy on final numerical answers
|
312 |
-
- **Performance**: [Insert actual accuracy from your evaluation]
|
313 |
|
314 |
## Technical Details
|
315 |
|
|
|
196 |
|
197 |
1. **Concatenate** `prompt_ids + completion_ids`
|
198 |
|
199 |
+
2. **Run forward pass through old policy** to compute
|
200 |
+
|
201 |
+
$$\pi_{\text{old}}(a|s)$$
|
202 |
+
|
203 |
+
|
204 |
- This actually happens only once at the first iteration when we create the rollout
|
205 |
|
206 |
+
4. **Run forward pass through ref policy** to compute $$\pi_{\text{ref}}(a|s)$$
|
207 |
- This actually happens only once at the first iteration when we create the rollout
|
208 |
- Ref model is the original model without LoRA adapters
|
209 |
|
210 |
+
5. **Run forward pass through current policy** to compute $$\pi(a|s)$$
|
211 |
- Needed only if `num_iterations > 1`; otherwise the same as old policy
|
212 |
|
213 |
+
6. **Compute KL loss** between $$\pi(a|s)$$ and $$\pi_{\text{ref}}(a|s)$$
|
214 |
|
215 |
+
7. **Compute advantage-weighted logprobs:** $$\frac{\pi(a|s)}{\pi_{\text{old}}(a|s)} \times A(s,a)$$
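
The steps above can be sketched in plain Python on per-token log-probabilities. This is a minimal illustration under stated assumptions, not the repository's implementation: the helper names `grpo_token_terms` and `grpo_loss`, the coefficient `beta`, and the sample numbers are hypothetical, the KL term uses the `exp(x) - x - 1` estimator often paired with GRPO (the README does not say which estimator it uses), and any clipping of the ratio is omitted.

```python
import math

def grpo_token_terms(logp_cur, logp_old, logp_ref, advantage):
    """Return per-token (ratio * advantage) values and per-token KL estimates."""
    weighted, kl = [], []
    for lc, lo, lr in zip(logp_cur, logp_old, logp_ref):
        ratio = math.exp(lc - lo)                 # pi(a|s) / pi_old(a|s)
        weighted.append(ratio * advantage)        # step 7
        delta = lr - lc                           # log(pi_ref(a|s) / pi(a|s))
        kl.append(math.exp(delta) - delta - 1.0)  # step 6, assumed estimator
    return weighted, kl

def grpo_loss(logp_cur, logp_old, logp_ref, advantage, beta=0.04):
    """Negative mean advantage-weighted ratio plus beta times mean KL (beta is hypothetical)."""
    weighted, kl = grpo_token_terms(logp_cur, logp_old, logp_ref, advantage)
    n = len(weighted)
    return -sum(weighted) / n + beta * sum(kl) / n

# Step 1: concatenate prompt and completion token ids (made-up ids).
prompt_ids = [101, 7, 8]
completion_ids = [42, 43]
input_ids = prompt_ids + completion_ids

# Steps 2 and 4: log-probs of the completion tokens, computed once at rollout time.
logp_old = [-1.2, -0.7]
logp_ref = [-1.0, -0.9]

# Step 5: with num_iterations == 1 the current policy equals the old policy.
logp_cur = list(logp_old)

weighted, kl = grpo_token_terms(logp_cur, logp_old, logp_ref, advantage=0.5)
loss = grpo_loss(logp_cur, logp_old, logp_ref, advantage=0.5)
```

When the current and old policies coincide, every ratio is exactly 1, so each weighted term reduces to the advantage itself; the ratio only departs from 1 once the policy has been updated on the same rollout, which is the `num_iterations > 1` case.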
|
216 |
|
217 |
## Workflow Summary
|
218 |
|
|
|
309 |
</answer>
|
310 |
```
|
311 |
|
|
312 |
|
313 |
## Technical Details
|
314 |
|