Update README.md
README.md (CHANGED)
@@ -137,7 +137,7 @@ Each rank performs the following operations:
 1. **Decodes completions**
 2. **Computes reward** for (prompt, completion) pairs
 3. **Gathers rewards** from other ranks (because it's possible for a given prompt to have its replica across GPUs)
-4. **Normalizes rewards** by mean/std ⟹ This gives us advantages
+4. **Normalizes rewards** by mean/std ⟹ This gives us advantages \\(A(s,a)\\)
 5. **Discards completions** for prompts it doesn't own (called alien prompts)

 ### Concrete Example: Multi-GPU Setup
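The normalization in step 4 is the core of GRPO's advantage estimate: the rewards of one prompt's completions are compared only against each other. A minimal sketch, assuming rewards have already been gathered across ranks and ordered so that each prompt's `num_generations` completions are contiguous (`compute_advantages` and the shapes below are illustrative, not code from this repo):

```python
import torch

def compute_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    """Group-normalize rewards so each completion is scored relative to the
    other completions sampled for the same prompt."""
    grouped = rewards.view(-1, num_generations)        # (num_prompts, G)
    mean = grouped.mean(dim=1, keepdim=True)           # per-prompt mean reward
    std = grouped.std(dim=1, keepdim=True)             # per-prompt reward std
    advantages = (grouped - mean) / (std + eps)        # A(s, a) for every completion
    return advantages.view(-1)

# Toy example: 2 prompts x 4 completions, e.g. gathered from 2 GPUs
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0,   # prompt 0
                        0.5, 0.5, 1.0, 0.0])  # prompt 1
print(compute_advantages(rewards, num_generations=4))
```

After this, each rank keeps only the advantage slice for the prompts it owns and drops the alien ones (step 5).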
@@ -201,16 +201,16 @@ $$\text{GRPO} = -\mathbb{E}_{(s,a)}\left[\frac{\pi(a|s)}{\pi_{\text{old}}(a|s)}

    - This actually happens only once at the first iteration when we create the rollout

-4. **Run forward pass through ref policy** to compute
+4. **Run forward pass through ref policy** to compute \\(\pi_{\text{ref}}(a|s)\\)
    - This actually happens only once at the first iteration when we create the rollout
    - Ref model is the original model without LoRA adapters

-5. **Run forward pass through current policy** to compute
+5. **Run forward pass through current policy** to compute \\(\pi(a|s)\\)
    - Needed only if `num_iterations > 1`; otherwise the same as old policy

-6. **Compute KL loss** between
+6. **Compute KL loss** between \\(\pi(a|s)\\) and \\(\pi_{\text{ref}}(a|s)\\)

-7. **Compute advantage-weighted logprobs:**
+7. **Compute advantage-weighted logprobs:** \\(\frac{\pi(a|s)}{\pi_{\text{old}}(a|s)} \times A(s,a)\\)

 ## Workflow Summary

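To make steps 4-6 in that hunk concrete, here is a hedged sketch of the per-token log-probabilities and the per-token KL estimate. It ignores padding, attention masks, and batching details, and the estimator shown (exp(d) - d - 1) is one common low-variance choice rather than necessarily the exact one used in this repo:

```python
import torch
import torch.nn.functional as F

def per_token_logps(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Log-probability of each realized token under the policy that produced `logits`."""
    log_probs = F.log_softmax(logits[:, :-1], dim=-1)   # position t predicts token t+1
    targets = input_ids[:, 1:].unsqueeze(-1)            # realized next tokens
    return torch.gather(log_probs, dim=-1, index=targets).squeeze(-1)

def per_token_kl(cur_logps: torch.Tensor, ref_logps: torch.Tensor) -> torch.Tensor:
    """Non-negative per-token estimate of KL(pi || pi_ref): exp(d) - d - 1, with d = ref - cur."""
    delta = ref_logps - cur_logps
    return torch.exp(delta) - delta - 1

# Toy shapes standing in for the forward passes through the ref and current policies
B, T, V = 2, 6, 11
ids = torch.randint(0, V, (B, T))
cur_lp = per_token_logps(torch.randn(B, T, V), ids)   # current policy (LoRA adapters active)
ref_lp = per_token_logps(torch.randn(B, T, V), ids)   # ref policy (adapters disabled)
print(per_token_kl(cur_lp, ref_lp).shape)             # (B, T-1), every entry >= 0
```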
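Building on the helpers in the previous sketch, step 7 and the loss implied by the formula above combine as follows. Clipping, completion masks, and the exact averaging convention vary between implementations and are omitted; `beta` is an illustrative name for the KL coefficient:

```python
import torch

def grpo_per_token_loss(cur_lp, old_lp, ref_lp, advantages, beta=0.04):
    """-(ratio * A) + beta * KL, computed per token.
    cur_lp / old_lp / ref_lp: (batch, seq) per-token logprobs; advantages: (batch,)."""
    ratio = torch.exp(cur_lp - old_lp)            # pi(a|s) / pi_old(a|s), per token
    delta = ref_lp - cur_lp
    kl = torch.exp(delta) - delta - 1             # same KL estimator as above
    return -(ratio * advantages.unsqueeze(1) - beta * kl)

# With num_iterations == 1 the old policy equals the current one, so the ratio is exactly 1
logps = torch.log(torch.rand(2, 5))
adv = torch.tensor([0.7, -0.7])
print(grpo_per_token_loss(logps, logps, torch.log(torch.rand(2, 5)), adv).mean())
```

A mean over the valid completion tokens then gives the scalar loss that gets backpropagated.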