Commit aa23f77 · drbh committed
Parent(s): eba2c2c

fix: extract expert device mesh for group from unused prehook
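The fix below recovers a DeviceMesh that an earlier wrapper stashed in a forward pre-hook's closure, instead of threading it through the module API. For context: CPython exposes a function's captured variables through __code__.co_freevars and __closure__, which line up index-for-index. A minimal, dependency-free sketch of the trick (all names here are illustrative, not from the repo):

def make_pre_hook(device_mesh):
    def pre_hook(module, args):
        # Referencing device_mesh keeps it alive as a free variable of
        # pre_hook, even though the hook never uses it for real work.
        _ = device_mesh
        return args
    return pre_hook


hook = make_pre_hook({"ep_size": 8})

# co_freevars names and __closure__ cells share one ordering.
idx = hook.__code__.co_freevars.index("device_mesh")
print(hook.__closure__[idx].cell_contents)  # -> {'ep_size': 8}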
torch-ext/megablocks/layers.py CHANGED
@@ -680,6 +680,17 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def get_device_mesh(model):
+    # Extract device_mesh from the child's unused pre_hook closure.
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure.
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the matching closure cell.
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except (StopIteration, AttributeError):
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +702,9 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
-
+
+        device_mesh = get_device_mesh(self)
+        expert_parallel_group = device_mesh.get_group() if device_mesh else None
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
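For a quick sanity check of the extraction outside a distributed run, the sketch below round-trips a stand-in mesh through a registered pre-hook. get_device_mesh is copied from the diff above; Wrapper and attach_mesh_hook are hypothetical names, and the mesh is a plain object because constructing a real torch.distributed.device_mesh.DeviceMesh requires an initialized process group.

import torch


def get_device_mesh(model):
    # Copied from the diff above.
    try:
        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
    except (StopIteration, AttributeError):
        return None


class Wrapper(torch.nn.Module):
    # Hypothetical stand-in for MegaBlocksMoeMLP: only the `experts`
    # child attribute matters to get_device_mesh.
    def __init__(self):
        super().__init__()
        self.experts = torch.nn.Linear(4, 4)


def attach_mesh_hook(module, device_mesh):
    # Hypothetical helper mirroring whatever registered the original hook.
    def pre_hook(mod, args):
        _ = device_mesh  # referenced so it is captured as a free variable
        return args
    module.register_forward_pre_hook(pre_hook)


model = Wrapper()
mesh = object()  # stand-in for a DeviceMesh
attach_mesh_hook(model.experts, mesh)
assert get_device_mesh(model) is mesh

One caveat on the new call site: in recent PyTorch releases, DeviceMesh.get_group() with no arguments only works for a one-dimensional mesh; a multi-dimensional mesh expects a mesh_dim argument.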