Update README.md
README.md CHANGED
@@ -137,7 +137,7 @@ Each rank performs the following operations:
1. **Decodes completions**
2. **Computes reward** for (prompt, completion) pairs
3. **Gathers rewards** from other ranks (because a given prompt may have replicas on other GPUs)
-4. **Normalizes rewards** by mean/std ⟹ This gives us advantages
+4. **Normalizes rewards** by mean/std ⟹ This gives us advantages $$A(s,a)$$
5. **Discards completions** for prompts it doesn't own (called alien prompts)
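
Steps 3-5 can be made concrete with a small PyTorch sketch. This is an illustrative sketch only, assuming an initialized `torch.distributed` process group and that each prompt's `num_generations` completions end up contiguous after the gather; `compute_advantages` and its arguments are hypothetical names, not this repo's actual API.

```python
import torch
import torch.distributed as dist

def compute_advantages(local_rewards: torch.Tensor, num_generations: int) -> torch.Tensor:
    """Illustrative sketch of steps 3-5 (not the repo's actual implementation)."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # 3. Gather rewards from all ranks so every rank sees the full group of
    #    completions for each prompt (a prompt's replicas may live on other GPUs).
    gathered = [torch.zeros_like(local_rewards) for _ in range(world_size)]
    dist.all_gather(gathered, local_rewards)
    rewards = torch.cat(gathered)  # shape: (world_size * local_batch,)

    # 4. Normalize each prompt's group of rewards by its mean/std -> advantages A(s, a).
    #    Assumes each prompt's num_generations completions are contiguous here.
    grouped = rewards.view(-1, num_generations)
    advantages = (grouped - grouped.mean(dim=1, keepdim=True)) / (
        grouped.std(dim=1, keepdim=True) + 1e-4
    )
    advantages = advantages.flatten()

    # 5. Keep only the slice this rank owns; advantages for "alien" prompts are discarded.
    local_batch = local_rewards.numel()
    return advantages[rank * local_batch : (rank + 1) * local_batch]
```
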
### Concrete Example: Multi-GPU Setup

@@ -196,19 +196,19 @@ $$\text{GRPO} = -\mathbb{E}_{(s,a)}\left[\frac{\pi(a|s)}{\pi_{\text{old}}(a|s)} A(s,a)\right]$$

1. **Concatenate** `prompt_ids + completion_ids`

-2. **Run forward pass through old policy** to compute
+2. **Run forward pass through old policy** to compute $$\pi_{\text{old}}(a|s)$$
   - This actually happens only once, at the first iteration, when we create the rollout

-3. **Run forward pass through ref policy** to compute
+3. **Run forward pass through ref policy** to compute $$\pi_{\text{ref}}(a|s)$$
   - This actually happens only once, at the first iteration, when we create the rollout
   - The ref model is the original model without LoRA adapters

-4. **Run forward pass through current policy** to compute
+4. **Run forward pass through current policy** to compute $$\pi(a|s)$$
   - Needed only if `num_iterations > 1`; otherwise it is the same as the old policy

-5. **Compute KL loss** between
+5. **Compute KL loss** between $$\pi(a|s)$$ and $$\pi_{\text{ref}}(a|s)$$

-6. **Compute advantage-weighted logprobs:**
+6. **Compute advantage-weighted logprobs:** $$\frac{\pi(a|s)}{\pi_{\text{old}}(a|s)} \times A(s,a)$$
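
The sketch below strings steps 1-6 together in simplified PyTorch (padding, masking, and importance-ratio clipping are omitted; `completion_logps`, `grpo_loss`, and the `beta` coefficient are illustrative names under those assumptions, not this repo's actual API):

```python
import torch
import torch.nn.functional as F

def completion_logps(model, input_ids, prompt_len):
    """Per-token log-probs of the completion portion of input_ids (illustrative helper)."""
    logits = model(input_ids).logits[:, :-1, :]           # logits that predict token t+1
    logps = F.log_softmax(logits, dim=-1)
    targets = input_ids[:, 1:].unsqueeze(-1)
    token_logps = torch.gather(logps, dim=-1, index=targets).squeeze(-1)
    return token_logps[:, prompt_len - 1:]                # keep only completion positions

def grpo_loss(policy, ref_policy, prompt_ids, completion_ids, old_logps, advantages, beta=0.04):
    # 1. Concatenate prompt and completion tokens.
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
    prompt_len = prompt_ids.size(1)

    # 4. pi(a|s): forward pass through the current policy. With num_iterations == 1
    #    this equals the old-policy pass (step 2), which is cached from rollout time.
    logps = completion_logps(policy, input_ids, prompt_len)

    # 3. pi_ref(a|s): reference model, i.e. the base weights with LoRA adapters disabled.
    with torch.no_grad():
        ref_logps = completion_logps(ref_policy, input_ids, prompt_len)

    # 5. Per-token KL penalty between pi and pi_ref (k3 estimator).
    kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1

    # 6. Advantage-weighted ratio pi / pi_old, negated so we can minimize.
    ratio = torch.exp(logps - old_logps)  # old_logps cached from the rollout (step 2)
    per_token_loss = -(ratio * advantages.unsqueeze(1) - beta * kl)
    return per_token_loss.mean()
```
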
## Workflow Summary