drbh committed
Commit 89e2950 · 1 Parent: aa23f77

feat: support shared experts layer and tests

- tests/test_mb_moe_shared_expert.py +139 -0
- tests/test_mb_moe_shared_expert_multi.py +200 -0
- torch-ext/megablocks/layers.py +267 -3
tests/test_mb_moe_shared_expert.py    ADDED

@@ -0,0 +1,139 @@
+import torch
+import megablocks
+from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
+
+
+def test_megablocks_moe_mlp_with_shared_expert_import():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+    assert hasattr(mlp, 'shared_up_proj_weight')
+    assert hasattr(mlp, 'shared_down_proj_weight')
+    assert hasattr(mlp, 'set_shared_expert_weights')
+
+
+def test_set_shared_expert_weights():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float32
+
+    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
+    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
+    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
+    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)
+
+    mlp.set_shared_expert_weights(
+        up_proj_weight=up_proj_weight,
+        down_proj_weight=down_proj_weight,
+        up_proj_bias=up_proj_bias,
+        down_proj_bias=down_proj_bias,
+        weighted_sum=True,
+        activation_fn=torch.nn.functional.gelu
+    )
+
+    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
+    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
+    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
+    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
+    assert mlp.shared_expert_weighted_sum == True
+    assert mlp.shared_activation_fn == torch.nn.functional.gelu
+
+
+def test_create_shared_expert_weights():
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float32
+
+    def init_method(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+        init_method=init_method
+    )
+
+    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
+    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
+    assert up_proj_weight.device.type == device.type
+    assert down_proj_weight.device.type == device.type
+    assert up_proj_weight.dtype == dtype
+    assert down_proj_weight.dtype == dtype
+    assert up_proj_bias is None
+    assert down_proj_bias is None
+
+
+def test_shared_expert_weights_none_by_default():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    assert mlp.shared_up_proj_weight is None
+    assert mlp.shared_down_proj_weight is None
+    assert mlp.shared_up_proj_bias is None
+    assert mlp.shared_down_proj_bias is None
+    assert mlp.shared_expert_weighted_sum == False
+    assert mlp.shared_activation_fn is None
+
+
+def test_inheritance_from_megablocks_moe_mlp():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    from megablocks.layers import MegaBlocksMoeMLP
+    assert isinstance(mlp, MegaBlocksMoeMLP)
+    assert hasattr(mlp, 'forward')
+
+
+def test_shared_expert_weights_custom_init():
+    hidden_size = 64
+    shared_expert_hidden_size = 128
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    dtype = torch.float16
+
+    def custom_init(tensor):
+        torch.nn.init.constant_(tensor, 0.5)
+
+    def custom_output_init(tensor):
+        torch.nn.init.constant_(tensor, 0.1)
+
+    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+        init_method=custom_init,
+        output_layer_init_method=custom_output_init
+    )
+
+    assert torch.all(up_proj_weight == 0.5)
+    assert torch.all(down_proj_weight == 0.1)
+    assert up_proj_weight.dtype == dtype
+    assert down_proj_weight.dtype == dtype
+
+
+def test_shared_expert_weights_dimensions():
+    mlp = MegaBlocksMoeMLPWithSharedExpert()
+
+    batch_size = 4
+    seq_len = 16
+    hidden_size = 128
+    shared_expert_hidden_size = 256
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
+    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
+
+    mlp.set_shared_expert_weights(
+        up_proj_weight=up_proj_weight,
+        down_proj_weight=down_proj_weight
+    )
+
+    x = torch.randn(seq_len, batch_size, hidden_size, device=device)
+
+    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
+    expected_down_output_shape = (seq_len, batch_size, hidden_size)
+
+    assert up_proj_weight.shape[1] == x.shape[-1]
+    assert down_proj_weight.shape[0] == x.shape[-1]
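For orientation, here is a minimal standalone sketch (plain PyTorch only, same illustrative sizes as the dimension test above) of the shape flow those final asserts describe: torch.nn.functional.linear applies the transposed weight, so the (256, 128) up projection maps a hidden size of 128 up to 256 and the (128, 256) down projection maps it back.

import torch
import torch.nn.functional as F

seq_len, batch_size, hidden_size, shared_expert_hidden_size = 16, 4, 128, 256

up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size)
down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size)
x = torch.randn(seq_len, batch_size, hidden_size)

h = F.gelu(F.linear(x, up_proj_weight))    # (16, 4, 256), the expected_up_output_shape
y = F.linear(h, down_proj_weight)          # (16, 4, 128), the expected_down_output_shape
assert h.shape == (seq_len, batch_size, shared_expert_hidden_size)
assert y.shape == (seq_len, batch_size, hidden_size)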
    	
tests/test_mb_moe_shared_expert_multi.py    ADDED

@@ -0,0 +1,200 @@
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import os
+import pytest
+from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
+
+
+def run_distributed_shared_expert_test(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12356"
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    dist.init_process_group(
+        backend="gloo",
+        rank=rank,
+        world_size=world_size,
+    )
+
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 128
+    shared_expert_hidden_size = 192
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def simple_init(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=shared_expert_hidden_size,
+        device=torch.device(device),
+        dtype=torch.float32,
+        init_method=simple_init
+    )
+
+    model.set_shared_expert_weights(
+        up_proj_weight=shared_up_proj_weight,
+        down_proj_weight=shared_down_proj_weight,
+        up_proj_bias=shared_up_proj_bias,
+        down_proj_bias=shared_down_proj_bias,
+        weighted_sum=True,
+        activation_fn=torch.nn.functional.gelu
+    )
+
+    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
+    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
+    assert model.shared_expert_weighted_sum == True, f"Weighted sum not set correctly on rank {rank}"
+
+    print(f"Rank {rank}: Shared expert setup test passed!")
+
+    dist.destroy_process_group()
+
+
+def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12357"
+    os.environ["RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+
+    dist.init_process_group(
+        backend="gloo",
+        rank=rank,
+        world_size=world_size,
+    )
+
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    hidden_size = 64
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def simple_init(tensor):
+        torch.nn.init.xavier_uniform_(tensor)
+
+    shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
+        hidden_size=hidden_size,
+        shared_expert_hidden_size=96,
+        device=torch.device(device),
+        dtype=torch.float32,
+        init_method=simple_init
+    )
+
+    model.set_shared_expert_weights(
+        up_proj_weight=shared_up_proj_weight,
+        down_proj_weight=shared_down_proj_weight,
+        weighted_sum=False,
+        activation_fn=torch.nn.functional.relu
+    )
+
+    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
+    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
+    assert model.shared_expert_weighted_sum == False, f"Weighted sum not set correctly on rank {rank}"
+    assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}"
+
+    print(f"Rank {rank}: Weighted sum setup test passed!")
+
+    dist.destroy_process_group()
+
+
+@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
+def test_shared_expert_distributed_functionality(world_size):
+    if world_size == 1:
+        # Single process test
+        model = MegaBlocksMoeMLPWithSharedExpert()
+
+        hidden_size = 128
+        shared_expert_hidden_size = 192
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        def simple_init(tensor):
+            torch.nn.init.xavier_uniform_(tensor)
+
+        shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
+            hidden_size=hidden_size,
+            shared_expert_hidden_size=shared_expert_hidden_size,
+            device=torch.device(device),
+            dtype=torch.float32,
+            init_method=simple_init
+        )
+
+        model.set_shared_expert_weights(
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            weighted_sum=True,
+            activation_fn=torch.nn.functional.gelu
+        )
+
+        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
+        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
+        assert model.shared_expert_weighted_sum == True, "Weighted sum not set correctly"
+
+        print("Single process shared expert setup test passed!")
+    else:
+        # Multi-process test
+        mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)
+        print("Multi-process shared expert test completed successfully!")
+
+
+@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
+def test_shared_expert_distributed_weighted_sum(world_size):
+    if world_size == 1:
+        # Single process test
+        model = MegaBlocksMoeMLPWithSharedExpert()
+
+        hidden_size = 64
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        def simple_init(tensor):
+            torch.nn.init.xavier_uniform_(tensor)
+
+        shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
+            hidden_size=hidden_size,
+            shared_expert_hidden_size=96,
+            device=torch.device(device),
+            dtype=torch.float32,
+            init_method=simple_init
+        )
+
+        model.set_shared_expert_weights(
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            weighted_sum=False,
+            activation_fn=torch.nn.functional.relu
+        )
+
+        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
+        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
+        assert model.shared_expert_weighted_sum == False, "Weighted sum not set correctly"
+        assert model.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"
+
+        print("Single process weighted sum setup test passed!")
+    else:
+        # Multi-process test
+        mp.spawn(run_distributed_shared_expert_weighted_sum_test, args=(world_size,), nprocs=world_size, join=True)
+        print("Multi-process shared expert weighted sum test completed successfully!")
+
+
+def test_shared_expert_single_process():
+    model = MegaBlocksMoeMLPWithSharedExpert()
+
+    assert model.shared_up_proj_weight is None
+    assert model.shared_down_proj_weight is None
+    assert hasattr(model, 'set_shared_expert_weights')
+
+    print("Single process shared expert basic test passed!")
+
+
+if __name__ == "__main__":
+    test_shared_expert_single_process()
+    print("Single process test passed!")
+
+    os.environ['WORLD_SIZE'] = '2'
+    test_shared_expert_distributed_functionality()
+    print("Distributed functionality test passed!")
+
+    test_shared_expert_distributed_weighted_sum()
+    print("Distributed weighted sum test passed!")
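For reference, a minimal sketch (outside the committed tests; the import path is assumed) of how mp.spawn drives the per-rank helpers above: spawn passes the process index as the first positional argument, so each worker receives its rank followed by the args tuple.

import torch.multiprocessing as mp

# Assumes the helpers are importable under this path; adjust to your layout.
from tests.test_mb_moe_shared_expert_multi import run_distributed_shared_expert_test

if __name__ == "__main__":
    world_size = 2  # the helper binds the hard-coded MASTER_PORT 12356 on localhost
    mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)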
    	
torch-ext/megablocks/layers.py    CHANGED

@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]


+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []

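As a reading aid for the combine helper added above: with shared_expert_weighted_sum=True the shared expert contributes 1/(moe_top_k + 1) of the blend and the routed experts contribute moe_top_k/(moe_top_k + 1). A tiny standalone check with illustrative stand-in tensors:

import torch

moe_top_k = 4
shared_out = torch.ones(2, 3)     # stand-in for the shared-expert output
expert_out = torch.zeros(2, 3)    # stand-in for the routed-expert output

combined = shared_out * (1.0 / (moe_top_k + 1)) + expert_out * (moe_top_k / (moe_top_k + 1))
assert torch.allclose(combined, torch.full((2, 3), 0.2))  # shared stream carries 1/5 of the blend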
@@ -680,6 +740,125 @@ def moe_forward(
     return x, expert_weights, router_scores


+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
 def get_device_mesh(model):
     # Extract device_mesh from child's unused pre_hook closure
     try:
@@ -687,7 +866,7 @@ def get_device_mesh(model):
         hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
         # Extract the device_mesh from the closure
         return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
-    except:
+    except Exception:
         return None


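The HACK noted above relies on CPython closure introspection. A tiny self-contained illustration (toy objects only, no torch) of the pattern get_device_mesh uses to read a value captured in a hook's closure:

def make_hook(device_mesh):
    def hook(module, args):
        _ = device_mesh  # referenced so it is captured in the closure
        return None
    return hook

h = make_hook(device_mesh={"mesh": "example"})
idx = h.__code__.co_freevars.index("device_mesh")
assert h.__closure__[idx].cell_contents == {"mesh": "example"}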
@@ -703,8 +882,11 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

-
-
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once

@@ -734,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
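To tie the pieces together, a compact usage sketch mirroring the tests above (CPU, illustrative sizes; the router and experts submodules that forward() reads are assumed to be attached by the surrounding model integration, so forward() itself is not called here):

import torch
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights

mlp = MegaBlocksMoeMLPWithSharedExpert()

# Biases come back as None by default.
up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=128,
    shared_expert_hidden_size=256,
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.xavier_uniform_,
)

mlp.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    weighted_sum=True,                        # blend 1/(k+1) shared vs k/(k+1) routed
    activation_fn=torch.nn.functional.gelu,
)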
