drbh committed · Commit e0fb143 · Parent(s): 5c51af4

fix: re upload build

This view is limited to 50 files because it contains too many changes.
- build +0 -1
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py +202 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py +10 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py +33 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py +54 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py +101 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py +26 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py +42 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py +337 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py +52 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py +244 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py +103 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py +587 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py +507 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py +94 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py +116 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +32 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_3bdb4b8_dirty.abi3.so +3 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +9 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py +6 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py +2 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py +557 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py +23 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py +35 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py +33 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py +31 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py +1001 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py +35 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +63 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py +37 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py +59 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py +52 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py +38 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py +27 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py +78 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py +415 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py +55 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py +98 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +66 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py +149 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py +10 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py +36 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py +14 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py +72 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py +38 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py +85 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py +39 -0
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py +9 -0
    	
    build
    DELETED

@@ -1 +0,0 @@
-/nix/store/clckh64l8yhprqcbs4vkm27lfac37j6w-torch-ext-bundle
    	
    build/torch27-cxx11-cu118-x86_64-linux/megablocks/__init__.py
    ADDED

@@ -0,0 +1,202 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from ._ops import ops
+
+from .grouped_gemm import backend as gg_backend
+from .grouped_gemm import ops as gg_ops
+
+
+from ._layers.arguments import Arguments
+from ._layers.dmoe import ParallelDroplessMLP, dMoE
+from ._layers.glu import SparseGLU
+from ._layers.mlp import MLP, SparseMLP
+from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+from . import layers
+
+# This section contains the direct kernel exports (not included in the original code)
+def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+    """
+    Compute exclusive cumulative sum along the specified dimension.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    result = ops.exclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
+
+
+def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+    """
+    Compute inclusive cumulative sum along the specified dimension.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    result = ops.inclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
+
+
+def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+    """
+    Compute histogram of input tensor values.
+
+    Args:
+        x: Input tensor
+        num_bins: Number of histogram bins
+
+    Returns:
+        Histogram tensor with counts for each bin
+    """
+    return ops.histogram(x, num_bins)
+
+
+def indices(
+    padded_bins: torch.Tensor,
+    block_size: int,
+    output_block_rows: int,
+    output_block_columns: int,
+) -> torch.Tensor:
+    """
+    Construct indices from padded bins for sparse operations.
+
+    Args:
+        padded_bins: Tensor containing bin boundaries
+        block_size: Size of each block
+        output_block_rows: Number of rows in output blocks
+        output_block_columns: Number of columns in output blocks
+
+    Returns:
+        Tensor containing constructed indices
+    """
+    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+def replicate_forward(
+    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Forward pass of replicate operation - replicate values according to bin sizes.
+
+    Args:
+        x: Input tensor with values to replicate
+        bins: Tensor containing bin sizes
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    return ops.replicate_forward(x, bins, out)
+
+
+def replicate_backward(
+    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Backward pass of replicate operation - reduce gradients back to bins.
+
+    Args:
+        grad: Gradient tensor to reduce
+        bins: Tensor containing bin sizes
+        out: Output tensor (modified in-place)
+
+    Returns:
+        The output tensor
+    """
+    return ops.replicate_backward(grad, bins, out)
+
+
+def sort(
+    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+) -> torch.Tensor:
+    """
+    Radix sort with index tracking.
+
+    Args:
+        x: Input tensor to sort
+        end_bit: Number of bits to consider in sorting
+        x_out: Output tensor for sorted values
+        iota_out: Output tensor for sorted indices
+
+    Returns:
+        The sorted values tensor
+    """
+    return ops.sort(x, end_bit, x_out, iota_out)
+
+
+# Convenience functions for common use cases
+def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+    """
+    Compute cumulative sum with automatic output allocation.
+
+    Args:
+        x: Input tensor
+        dim: Dimension along which to compute cumsum (default: last dimension)
+        exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+    Returns:
+        New tensor containing the cumulative sum
+    """
+    out = torch.empty_like(x)
+    if exclusive:
+        return exclusive_cumsum(x, dim, out)
+    else:
+        return inclusive_cumsum(x, dim, out)
+
+
+def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Sort tensor and return both sorted values and indices.
+
+    Args:
+        x: Input tensor to sort
+        end_bit: Number of bits to consider in sorting
+
+    Returns:
+        Tuple of (sorted_values, sorted_indices)
+    """
+    x_out = torch.empty_like(x)
+    iota_out = torch.empty_like(x)
+    sort(x, end_bit, x_out, iota_out)
+    return x_out, iota_out
+
+
+# Export public API
+__all__ = [
+    "MyReplacementLayer",
+    # Direct kernel exports
+    "exclusive_cumsum",
+    "inclusive_cumsum",
+    "histogram",
+    "indices",
+    "replicate_forward",
+    "replicate_backward",
+    "sort",
+    "cumsum",
+    "argsort",
+    # Original exports
+    "Arguments",
+    "ParallelDroplessMLP",
+    "dMoE",
+    "SparseGLU",
+    "MLP",
+    "SparseMLP",
+    "MoE",
+    "ParallelMLP",
+    "get_load_balancing_loss",
+]
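Editor's note (not part of the commit): for orientation, a minimal sketch of how the direct kernel exports above could be called once this build is installed. The shapes, dtypes, and bit widths below are illustrative assumptions; the underlying ops are CUDA kernels, so a GPU is assumed.

# Illustrative only: assumes the megablocks build above is installed and a CUDA device is available.
import torch
import megablocks

x = torch.randint(0, 8, (16,), dtype=torch.int32, device="cuda")

# Convenience wrappers allocate their own output tensors.
prefix = megablocks.cumsum(x, dim=0, exclusive=True)   # exclusive prefix sum
counts = megablocks.histogram(x, num_bins=8)           # values per bucket
values, order = megablocks.argsort(x, end_bit=3)       # radix sort + permutation

# The lower-level exports write into caller-provided buffers instead.
out = torch.empty_like(x)
megablocks.inclusive_cumsum(x, 0, out)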
    	
    build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/__init__.py
    ADDED

@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+# from megablocks.layers.dmoe import dMoE
+from .moe import MoE
+
+__all__ = [
+    'MoE',
+    # 'dMoE',
+]
    	
    build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/activation_fn.py
    ADDED

@@ -0,0 +1,33 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Union
+
+import torch
+from ..stk import Matrix
+
+
+def act_fn(
+    x: Matrix,
+    function: Callable,
+    return_grad_fn: bool = False,
+    **kwargs,
+) -> Union[tuple[Matrix, Any] | Matrix]:
+    assert isinstance(x, Matrix)
+    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+        if return_grad_fn:
+            x.data.requires_grad = True
+        out = function(x.data, **kwargs)
+        y = Matrix(
+            x.size(),
+            out,
+            x.row_indices,
+            x.column_indices,
+            x.offsets,
+            x.column_indices_t,
+            x.offsets_t,
+            x.block_offsets_t,
+        )
+        if return_grad_fn:
+            return y, out.backward
+        return y
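Editor's note (not part of the commit): `act_fn` applies an element-wise activation to the `data` payload of an stk `Matrix` and, when `return_grad_fn=True`, re-enables grad on that payload so the activation can be rematerialized later. A hedged sketch of the same control flow on a plain tensor, with a hypothetical helper name, may make the pattern clearer.

# Illustrative only: mirrors act_fn's grad-re-enable pattern on a dense tensor
# rather than an stk Matrix; assumes `x` is a leaf tensor.
import torch
import torch.nn.functional as F

def act_with_grad_fn(x: torch.Tensor, return_grad_fn: bool = False):
    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
        if return_grad_fn:
            x.requires_grad = True
        out = F.gelu(x, approximate="tanh")
        if return_grad_fn:
            # out.backward can be invoked later to backpropagate through the
            # activation without keeping extra state around now.
            return out, out.backward
        return out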
    	
    build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/all_to_all.py
    ADDED

@@ -0,0 +1,54 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.distributed as dist
+
+
+class AllToAllOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+        out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+        ctx.input_shape = x.shape
+        ctx.output_split_sizes = output_split_sizes
+        ctx.input_split_sizes = input_split_sizes
+        ctx.group = group
+        handle = dist.all_to_all_single(
+            out,
+            x,
+            output_split_sizes=output_split_sizes,
+            input_split_sizes=input_split_sizes,
+            group=group,
+            async_op=async_op,
+        )
+        return out, handle
+
+    @staticmethod
+    def backward(ctx, grad, _):
+        if ctx.needs_input_grad[0]:
+            out = torch.empty(
+                ctx.input_shape,
+                device=grad.device,
+                dtype=grad.dtype,
+            )
+            dist.all_to_all_single(
+                out,
+                grad,
+                output_split_sizes=ctx.input_split_sizes,
+                input_split_sizes=ctx.output_split_sizes,
+                group=ctx.group,
+            )
+            return out, None, None, None, None
+        return None, None, None, None, None
+
+
+def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+    return AllToAllOp.apply(
+        x,
+        output_split_sizes,
+        input_split_sizes,
+        group,
+        async_op,
+    )
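Editor's note (not part of the commit): a hedged sketch of how the `all_to_all` wrapper above is typically driven. It assumes `torch.distributed` is already initialized and that `group` spans the expert-parallel ranks; the split-size bookkeeping and the helper name are placeholders.

# Illustrative only: requires an initialized process group and CUDA tensors.
from megablocks._layers.all_to_all import all_to_all

def exchange_tokens(x, send_counts, recv_counts, group):
    # send_counts[i]: rows this rank sends to rank i;
    # recv_counts[i]: rows this rank receives from rank i.
    out, handle = all_to_all(
        x,
        output_split_sizes=recv_counts,
        input_split_sizes=send_counts,
        group=group,
        async_op=True,
    )
    handle.wait()  # other work could be overlapped before waiting
    return out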
    	
    build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/arguments.py
    ADDED

@@ -0,0 +1,101 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+# import megablocks.grouped_gemm_util as grouped_gemm
+from .. import grouped_gemm_util as grouped_gemm
+
+# Type annotation for in-place Tensor initialization function.
+InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+_ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+@dataclasses.dataclass
+class Arguments:
+    # Model arguments.
+    hidden_size: int = 1024
+    ffn_hidden_size: int = 4096
+    num_layers: int = 1
+    bias: bool = True
+    return_bias: bool = True
+    activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+    # MoE arguments.
+    moe_num_experts: int = 1
+    moe_top_k: int = 1
+    moe_capacity_factor: int = 1
+    moe_normalize_expert_weights: Optional[Union[int, float]] = None
+    moe_loss_weight: float = 0.1
+    moe_jitter_eps: Optional[float] = None
+    moe_lbl_in_fp32: bool = False
+
+    # Parallelism arguments.
+    moe_expert_model_parallelism: bool = False
+    expert_parallel_group: Optional[dist.ProcessGroup] = None
+    pipeline_model_parallel_size: int = 1
+    num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+    # Compute arguments.
+    memory_optimized_mlp: bool = False
+    mlp_type: str = 'mlp'
+    mlp_impl: str = 'sparse'
+
+    # Initialization arguments.
+    fp16: bool = True
+    bf16: bool = False
+    device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+    init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+    output_layer_init_method: InitFn = init_method
+
+    # Benchmarking arguments.
+    uniform_expert_assignment: bool = False
+
+    # shared expert arguments
+    shared_expert: bool = False  # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
+    fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,)  # kwargs for custom fc layers
+    remat_act_fn: bool = True  # enable act fn to be rematerialized instead of stored
+    shared_expert_hidden_size: Optional[
+        int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+    # Router Z-loss arguments
+    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
+    moe_zloss_in_fp32: bool = False
+
+    def __post_init__(self):
+        # Sparse MLP is not supported with triton >=3.2.0
+        # TODO: Remove this once sparse is supported with triton >=3.2.0
+        if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+                if triton.__version__ >= '3.2.0':
+                    raise ValueError(
+                        'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+                    )
+            except ImportError:
+                raise ImportError('Triton is required for sparse MLP implementation')
+
+        if self.__getattribute__('mlp_impl') == 'grouped':
+            grouped_gemm.assert_grouped_gemm_is_available()
+
+        if self.shared_expert_hidden_size is None:
+            self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+def from_megatron(megatron_args: Any):
+    args = Arguments()
+    for field in dataclasses.fields(args):
+        if hasattr(megatron_args, field.name):
+            setattr(args, field.name, getattr(megatron_args, field.name))
+    return args
    	
    build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/common.py
    ADDED

@@ -0,0 +1,26 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+from .arguments import Arguments
+
+
+def dtype(args: Arguments):
+    if args.fp16:
+        return torch.float16
+    elif args.bf16:
+        return torch.bfloat16
+    return None
+
+
+def cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor
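Editor's note (not part of the commit): a brief hedged illustration of the helper above. Inside an autocast region it casts the input to the active autocast dtype (useful before custom ops autocast does not rewrite); outside autocast it returns the tensor unchanged. The shapes and device are placeholders.

# Illustrative only: assumes a CUDA device.
import torch
from megablocks._layers.common import cast_if_autocast_enabled

x = torch.randn(4, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = cast_if_autocast_enabled(x)   # cast to the active autocast dtype
    assert y.dtype == torch.bfloat16

z = cast_if_autocast_enabled(x)       # no autocast context: returned unchanged
assert z.dtype == torch.float32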
    	
    build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmlp_registry.py
    ADDED

@@ -0,0 +1,42 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+from . import glu, mlp
+from .arguments import Arguments
+
+MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
+
+_REGISTRY = {
+    'mlp': {
+        'grouped': mlp.GroupedMLP,
+        'sparse': mlp.SparseMLP,
+    },
+    'glu': {
+        'grouped': glu.GroupedGLU,
+        'sparse': glu.SparseGLU,
+    },
+}
+
+
+def get(args: Arguments) -> MlpType:
+    """Returns an MLP for use in a dMoE instance.
+
+    Uses the provided arguments to instantiate the appropriate
+    MLP instance. This only contains MLPs for use in dMoEs
+    (ie. only for the dropless versions of MoEs).
+
+    Args:
+        args: propagated Arguments dataclass.
+
+    Returns:
+        An instantiated MLP constructed using the input args.
+    """
+    if args.mlp_type not in _REGISTRY:
+        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
+
+    if args.mlp_impl not in _REGISTRY[args.mlp_type]:
+        raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
+
+    return _REGISTRY[args.mlp_type][args.mlp_impl](args)
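Editor's note (not part of the commit): the registry above simply maps the (mlp_type, mlp_impl) pair to a constructor and calls it with the `Arguments` instance. A hedged sketch of that lookup, with arbitrary settings:

# Illustrative only: constructing the MLP allocates expert weights, so a CUDA
# device and the grouped_gemm backend are assumed.
from megablocks import Arguments
from megablocks._layers import dmlp_registry

args = Arguments(mlp_type="glu", mlp_impl="grouped", moe_num_experts=8, moe_top_k=2)
mlp = dmlp_registry.get(args)   # resolves to glu.GroupedGLU(args) per _REGISTRY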
    	
    build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/dmoe.py
    ADDED

@@ -0,0 +1,337 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import torch
+
+# try:
+#     import stk.ops
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+#     )
+
+# import megablocks.ops as ops
+# # from megablocks.ops import ops
+# from megablocks.layers import common, dmlp_registry, moe, mpu
+# from megablocks.layers.arguments import Arguments
+
+from .. import stk
+from .. import ops
+from . import common, dmlp_registry, moe, mpu
+from .arguments import Arguments
+
+def promote_scalar(x):
+    return x.view(1) if not len(x.size()) else x
+
+
+class ParallelDroplessMLP(moe.ParallelMLP):
+
+    def __init__(self, args: Arguments):
+        super(ParallelDroplessMLP, self).__init__(args)
+        self.hidden_size = args.hidden_size
+        self.ffn_hidden_size = mpu.features_per_rank(args)
+        self.blocking = 128
+        self.mlp = dmlp_registry.get(args)
+
+        # Calculate the number of bits needed to represent the column indices
+        # in the intermediate sparse matrix.
+        max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
+        self.transpose_sort_end_bit = max(
+            int(np.ceil(np.log2(max_column_index))),
+            1,
+        )
+
+    def sparse_transpose(self, size, row_indices, column_indices, offsets):
+        block_columns = size[1] // self.blocking
+
+        # Sort row indices by column indices to get the transposed matrix's
+        # column indices.
+        #
+        # NOTE: Our sort operation uses the same width indices as the input values.
+        # To avoid overflow when we have large activation matrices we cast to
+        # 32-bit before sorting.
+        _, gather_indices = ops.sort(
+            column_indices.int(),
+            self.transpose_sort_end_bit,
+        )
+
+        # There are a constant number of blocks in every row of the sparse matrix.
+        # A block's offset is:
+        #
+        # row_index * blocks_per_row + column_index % blocks_per_row
+        #
+        # Once we have the block offsets ordered for transposition we can divide
+        # by blocks_per_row to get the transposed column indices.
+        column_indices_t = row_indices.gather(0, gather_indices.long())
+        block_offsets_t = gather_indices.int()
+
+        zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
+        nnz_per_column = ops.histogram(column_indices, block_columns)
+        nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
+        if nnz_per_column.dim() == 0:
+            # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
+            nnz_per_column = nnz_per_column.unsqueeze(0)
+        offsets_t = torch.cat([zero, nnz_per_column])
+        return column_indices_t, offsets_t, block_offsets_t
+
+    def topology(self, x, padded_bins):
+        padded_tokens, _ = x.size()
+        assert padded_tokens % self.blocking == 0
+        if self.ffn_hidden_size % self.blocking != 0:
+            raise ValueError(
+                f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
+                f'the block size {self.blocking}. Please update your configuration.',
+            )
+
+        # Offsets for the sparse matrix. All rows have the
+        # same number of nonzero blocks dictated by the
+        # dimensionality of a single expert.
+        block_rows = padded_tokens // self.blocking
+        blocks_per_row = self.ffn_hidden_size // self.blocking
+        offsets = torch.arange(
+            0,
+            block_rows * blocks_per_row + 1,
+            blocks_per_row,
+            dtype=torch.int32,
+            device=x.device,
+        )
+
+        # Indices for the sparse matrix. The indices for
+        # the intermediate matrix are dynamic depending
+        # on the mapping of tokens to experts.
+        column_indices = ops.topology(
+            padded_bins,
+            self.blocking,
+            block_rows,
+            blocks_per_row,
+        )
+
+        # TODO(tgale): This is unused. Remove the need for this in stk.
+        # For now, use meta init to save the device memory.
+        data = torch.empty(
+            column_indices.numel(),
+            self.blocking,
+            self.blocking,
+            dtype=common.dtype(self.args),
+            device='meta',
+        )
+        shape = (
+            padded_tokens,
+            self.ffn_hidden_size * mpu.experts_per_rank(self.args),
+        )
+        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
+        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
+            shape,
+            row_indices,
+            column_indices,
+            offsets,
+        )
+        return stk.Matrix(
+            shape,
+            data,
+            row_indices,
+            column_indices,
+            offsets,
+            column_indices_t,
+            offsets_t,
+            block_offsets_t,
+        )
+
+    def indices_and_padded_bins(self, top_experts):
+        # Sort the expert ids to produce the scatter/gather
+        # indices for the permutation.
+        top_experts = top_experts.int()
+        bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
+
+        # Histogram the expert ids to identify the number of
+        # tokens routed to each expert.
+        tokens_per_expert = ops.histogram(top_experts, self.num_experts)
+
+        # Round the token counts up to the block size used in
+        # the matrix multiplications. Calculate the starting
+        # position of each bin.
+        padded_tokens_per_expert = ops.round_up(
+            tokens_per_expert,
+            self.blocking,
+        )
+        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+        padded_bins = promote_scalar(padded_bins)
+
+        # Calculate the bin bounds for the sorted tokens.
+        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+        bins = promote_scalar(bins)
+        return indices, bin_ids, bins, padded_bins, tokens_per_expert
+
+    def sparse_forward_once(self, x, expert_weights, top_experts):
+        # x: [sl, bs, hs]
+        # expert_weights: [sl * bs, top-k]
+        # top_experts: [sl * bs, top-k]
+        expert_weights = expert_weights.flatten()
+        top_experts = top_experts.flatten()
+        with torch.no_grad():
+            indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
+
+        # Route the tokens for MoE computation.
+        x = x.view(-1, x.shape[-1])
+        x = ops.padded_gather(
+            x,
+            indices,
+            bin_ids,
+            bins,
+            padded_bins,
+            self.top_k,
+        )
+
+        # Create the sparse matrix topology.
+        with torch.no_grad():
+            topo = self.topology(x, padded_bins)
+
+        # Perform the expert computation.
                    x = self.mlp(x, topo)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                    # Un-route the data for the MoE output.
         | 
| 195 | 
            +
                    x = ops.padded_scatter(
         | 
| 196 | 
            +
                        x,
         | 
| 197 | 
            +
                        indices,
         | 
| 198 | 
            +
                        bin_ids,
         | 
| 199 | 
            +
                        expert_weights,
         | 
| 200 | 
            +
                        bins,
         | 
| 201 | 
            +
                        padded_bins,
         | 
| 202 | 
            +
                        self.top_k,
         | 
| 203 | 
            +
                    )
         | 
| 204 | 
            +
                    return x, tokens_per_expert
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                # For use in the base-class parallel_forward_once.
         | 
| 207 | 
            +
                def sparse_permute_and_compute(
         | 
| 208 | 
            +
                    self,
         | 
| 209 | 
            +
                    x,
         | 
| 210 | 
            +
                    tokens_per_expert,
         | 
| 211 | 
            +
                    indices,
         | 
| 212 | 
            +
                    bin_ids,
         | 
| 213 | 
            +
                    expert_weights,
         | 
| 214 | 
            +
                    bins,
         | 
| 215 | 
            +
                    expert_capactiy,  # unused
         | 
| 216 | 
            +
                    top_k,
         | 
| 217 | 
            +
                ):
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    # Round the token counts up to the block size used in the matrix
         | 
| 220 | 
            +
                    # multiplication. Calculate the starting position of each bin.
         | 
| 221 | 
            +
                    padded_tokens_per_expert = ops.round_up(
         | 
| 222 | 
            +
                        tokens_per_expert,
         | 
| 223 | 
            +
                        self.blocking,
         | 
| 224 | 
            +
                    )
         | 
| 225 | 
            +
                    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
         | 
| 226 | 
            +
                    padded_bins = promote_scalar(padded_bins)
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    # Route the tokens for MoE computation.
         | 
| 229 | 
            +
                    x = x.view(-1, x.shape[-1])
         | 
| 230 | 
            +
                    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    # Create the sparse matrix topology.
         | 
| 233 | 
            +
                    with torch.no_grad():
         | 
| 234 | 
            +
                        topo = self.topology(x, padded_bins)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    # Perform the expert computation.
         | 
| 237 | 
            +
                    x = self.mlp(x, topo)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                    # Un-route the data for the MoE output.
         | 
| 240 | 
            +
                    return ops.padded_scatter(
         | 
| 241 | 
            +
                        x,
         | 
| 242 | 
            +
                        indices,
         | 
| 243 | 
            +
                        bin_ids,
         | 
| 244 | 
            +
                        expert_weights,
         | 
| 245 | 
            +
                        bins,
         | 
| 246 | 
            +
                        padded_bins,
         | 
| 247 | 
            +
                        top_k,
         | 
| 248 | 
            +
                    )
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                def grouped_forward_once(self, x, expert_weights, top_experts):
         | 
| 251 | 
            +
                    # x: [sl, bs, hs]
         | 
| 252 | 
            +
                    # expert_weights: [sl * bs, top-k]
         | 
| 253 | 
            +
                    # top_experts: [sl * bs, top-k]
         | 
| 254 | 
            +
                    expert_weights = expert_weights.flatten()
         | 
| 255 | 
            +
                    top_experts = top_experts.flatten()
         | 
| 256 | 
            +
                    with torch.no_grad():
         | 
| 257 | 
            +
                        indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    out = self.grouped_permute_and_compute(
         | 
| 260 | 
            +
                        x,
         | 
| 261 | 
            +
                        tokens_per_expert,
         | 
| 262 | 
            +
                        indices,
         | 
| 263 | 
            +
                        bin_ids,
         | 
| 264 | 
            +
                        expert_weights,
         | 
| 265 | 
            +
                        bins,
         | 
| 266 | 
            +
                        -1,  # unused
         | 
| 267 | 
            +
                        self.args.moe_top_k,
         | 
| 268 | 
            +
                    )
         | 
| 269 | 
            +
                    return out, tokens_per_expert
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                def grouped_permute_and_compute(
         | 
| 272 | 
            +
                    self,
         | 
| 273 | 
            +
                    x,
         | 
| 274 | 
            +
                    tokens_per_expert,
         | 
| 275 | 
            +
                    indices,
         | 
| 276 | 
            +
                    bin_ids,
         | 
| 277 | 
            +
                    expert_weights,
         | 
| 278 | 
            +
                    bins,
         | 
| 279 | 
            +
                    expert_capactiy,  # unused
         | 
| 280 | 
            +
                    top_k,
         | 
| 281 | 
            +
                ):
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    # Route the tokens for MoE computation.
         | 
| 284 | 
            +
                    x = x.view(-1, x.shape[-1])
         | 
| 285 | 
            +
                    x = ops.gather(x, indices, bin_ids, bins, top_k)
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    # Perform the expert computation.
         | 
| 288 | 
            +
                    x = self.mlp(x, tokens_per_expert)
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    # Un-route the data for the MoE output.
         | 
| 291 | 
            +
                    return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                def forward_once(self, x, expert_weights, top_experts):
         | 
| 294 | 
            +
                    if self.args.mlp_impl == 'sparse':
         | 
| 295 | 
            +
                        return self.sparse_forward_once(x, expert_weights, top_experts)
         | 
| 296 | 
            +
                    else:
         | 
| 297 | 
            +
                        return self.grouped_forward_once(x, expert_weights, top_experts)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                def permute_and_compute(
         | 
| 300 | 
            +
                    self,
         | 
| 301 | 
            +
                    x,
         | 
| 302 | 
            +
                    tokens_per_expert,
         | 
| 303 | 
            +
                    indices,
         | 
| 304 | 
            +
                    bin_ids,
         | 
| 305 | 
            +
                    expert_weights,
         | 
| 306 | 
            +
                    bins,
         | 
| 307 | 
            +
                    expert_capactiy,
         | 
| 308 | 
            +
                    top_k,
         | 
| 309 | 
            +
                ):
         | 
| 310 | 
            +
                    if self.args.mlp_impl == 'sparse':
         | 
| 311 | 
            +
                        return self.sparse_permute_and_compute(
         | 
| 312 | 
            +
                            x,
         | 
| 313 | 
            +
                            tokens_per_expert,
         | 
| 314 | 
            +
                            indices,
         | 
| 315 | 
            +
                            bin_ids,
         | 
| 316 | 
            +
                            expert_weights,
         | 
| 317 | 
            +
                            bins,
         | 
| 318 | 
            +
                            expert_capactiy,
         | 
| 319 | 
            +
                            top_k,
         | 
| 320 | 
            +
                        )
         | 
| 321 | 
            +
                    else:
         | 
| 322 | 
            +
                        return self.grouped_permute_and_compute(
         | 
| 323 | 
            +
                            x,
         | 
| 324 | 
            +
                            tokens_per_expert,
         | 
| 325 | 
            +
                            indices,
         | 
| 326 | 
            +
                            bin_ids,
         | 
| 327 | 
            +
                            expert_weights,
         | 
| 328 | 
            +
                            bins,
         | 
| 329 | 
            +
                            expert_capactiy,
         | 
| 330 | 
            +
                            top_k,
         | 
| 331 | 
            +
                        )
         | 
| 332 | 
            +
             | 
| 333 | 
            +
             | 
| 334 | 
            +
            class dMoE(moe.MoE):
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                def _init_experts_mlp(self, args: Arguments):
         | 
| 337 | 
            +
                    return ParallelDroplessMLP(args)
         | 
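For reference, a minimal sketch of how the dMoE layer above is driven, mirroring the memory_test.py harness later in this commit; the Arguments fields shown come from that harness and the import path is assumed from this build tree:

    import torch
    from megablocks._layers import arguments, dmoe  # assumed import path

    args = arguments.Arguments(
        hidden_size=1024,
        ffn_hidden_size=4096,
        moe_num_experts=8,
        moe_top_k=2,
        fp16=False,
        bf16=True,
        device=torch.cuda.current_device(),
    )
    layer = dmoe.dMoE(args).cuda()

    # forward_once dispatches to the sparse or grouped expert path based on
    # args.mlp_impl; the harness unpacks a 2-tuple whose first element is the output.
    x = torch.randn(8, 2048, 1024, device='cuda', dtype=torch.bfloat16)
    out, _ = layer(x)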
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/gelu.py
    ADDED
    
@@ -0,0 +1,52 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# try:
#     import stk
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
#     )

from .. import stk

import torch
import torch.nn.functional as F


@torch.jit.script
def _gelu_backward_inplace(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
    return g.mul_(ff)


def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
    # NOTE: The two sparse matrices must have the same topology.
    if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
        return stk.Matrix(
            x.size(),
            _gelu_backward_inplace(grad.data, x.data),
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
    return _gelu_backward_inplace(grad, x)


def gelu(x: stk.Matrix):
    assert isinstance(x, stk.Matrix)
    return stk.Matrix(
        x.size(),
        F.gelu(x.data, approximate='tanh'),
        x.row_indices,
        x.column_indices,
        x.offsets,
        x.column_indices_t,
        x.offsets_t,
        x.block_offsets_t,
    )
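A quick sanity check for the hand-written tanh-GELU derivative above: for plain tensors, gelu_backward_ takes the dense in-place fallback, so its output can be compared against autograd (a sketch; the import path is assumed from this build tree):

    import torch
    import torch.nn.functional as F
    from megablocks._layers.gelu import gelu_backward_  # assumed import path

    x = torch.randn(64, requires_grad=True)
    (reference,) = torch.autograd.grad(F.gelu(x, approximate='tanh').sum(), x)

    # ones * d/dx gelu(x), computed by the fused backward formula.
    manual = gelu_backward_(torch.ones_like(x), x.detach())
    assert torch.allclose(reference, manual, atol=1e-5)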
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/glu.py
    ADDED
    
@@ -0,0 +1,244 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# import stk.ops
# try:
#     import stk.ops
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
#     )

from .. import stk

import torch

# from megablocks import grouped_gemm_util as gg
# from megablocks.layers import common, mpu
# from megablocks.layers.activation_fn import act_fn
# from megablocks.layers.arguments import Arguments
# from megablocks.layers.mlp import (
#     SharedMLP,
#     SparseMLP,
#     create_dmoe_expert_weights,
#     resolve_dtensor,
# )

from .. import grouped_gemm_util as gg
from . import common, mpu
from .activation_fn import act_fn
from .arguments import Arguments
from .mlp import (
    SharedMLP,
    SparseMLP,
    create_dmoe_expert_weights,
    resolve_dtensor,
)


class SparseGLU(SparseMLP):

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.v1 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        with torch.no_grad():
            self.v1.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.init_method,
                ),
            )

        mpu.set_expert_model_parallel_attributes(
            self.v1,
            self._should_set_parallelism_attribute,
        )

    def forward(self, x, topo):
        if self.args.memory_optimized_mlp:
            raise NotImplementedError(
                'Memory optimized implementation not yet supported with GLU with sparse kernels.',
            )

        w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)
        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)

        # Compute the GLU.
        x1 = stk.ops.sdd(x, w1.t(), topo)
        x2 = stk.ops.sdd(x, v1.t(), topo)

        activation_fn_out = act_fn(x1, self.args.activation_fn)
        x1 = stk.ops.mul(activation_fn_out, x2)

        return stk.ops.dsd(x1, w2)


class MemoryOptimizedGroupedGLU(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
        # Cast inputs using ctx dtype from AMP
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            v1 = v1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")

        # Layer 0: x @ w1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
        v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)

        # GeLU.
        activation_fn_out = activation_fn(sdd_out) * v1_out

        # Layer 1: x @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack saved tensors
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, v1, w2 = saved_tensors[:3]
        batch_sizes = saved_tensors[3]
        x = saved_tensors[4]
        sdd_out, v1_out = saved_tensors[5:7]

        # Rematerialize activation_fn output.
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            v1_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out) * v1_out
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        assert activation_grad_fn is not None
        activation_grad_fn(dactivation_fn_out)
        dsdd_out = sdd_out.grad
        dv1_out = v1_out.grad

        # Compute dw1.
        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)

        # Compute dv1.
        dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        dx = ddsd_out
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
        dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
        return dx, dw1, dv1, dw2, None, None


memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply


class GroupedGLU(SparseGLU):

    def forward(self, x, tokens_per_expert):
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        w1, v1, w2 = (
            self.scale_grad(self.w1),
            self.scale_grad(self.v1),
            self.scale_grad(self.w2),
        )
        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)

        # Re-shape the weights for the grouped GEMMs.
        ne = mpu.experts_per_rank(self.args)
        w1 = w1.view(ne, -1, self.args.hidden_size)
        v1 = v1.view(ne, -1, self.args.hidden_size)
        w2 = w2.view(ne, -1, self.args.hidden_size)

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_glu(
                x,
                w1,
                v1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        # Compute the MLP.
        assert gg.ops is not None
        x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
        x1 = self.args.activation_fn(x1) * x2
        return gg.ops.gmm(x1, w2, batch_sizes)


class SharedGLU(SharedMLP):
    """GLU for shared expert.

    Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
    """

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.gate_proj = args.fc_cls(
            args.hidden_size,
            self.args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
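For intuition, the grouped GEMMs in GroupedGLU compute, per expert e, (act(x_e @ w1_e.T) * (x_e @ v1_e.T)) @ w2_e over that expert's contiguous slice of tokens. A dense reference sketch (illustrative only, not library code):

    import torch
    import torch.nn.functional as F

    def dense_glu_reference(x, w1, v1, w2, batch_sizes, act=F.gelu):
        # x: [tokens, hidden]; w1, v1, w2: [num_experts, ffn, hidden].
        # batch_sizes: tokens routed to each expert, in expert order
        # (e.g. tokens_per_expert.tolist()).
        outs, start = [], 0
        for e, n in enumerate(batch_sizes):
            xe = x[start:start + n]
            outs.append((act(xe @ w1[e].t()) * (xe @ v1[e].t())) @ w2[e])
            start += n
        return torch.cat(outs, dim=0)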
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/memory_test.py
    ADDED
    
@@ -0,0 +1,103 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import gc

import torch
import torch.distributed as dist

# from megablocks.layers import arguments, dmoe
from . import arguments, dmoe

_TESTS = ((8, 2048, 4096, 4096, 32, 4),)


def get_tensors():
    ptrs = set()
    out = []
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            if not obj.is_contiguous() or obj.data_ptr() in ptrs:
                continue
            out.append(obj)
            ptrs.add(obj.data_ptr())
    return out


def test_memory(
    group,
    batch_size,
    sequence_length,
    hidden_size,
    ffn_hidden_size,
    num_experts,
    top_k,
):
    args = arguments.Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_expert_model_parallelism=True,
        expert_parallel_group=group,
        fp16=False,
        bf16=True,
        device=torch.cuda.current_device(),
    )
    layer = dmoe.dMoE(args).cuda()

    x = torch.randn((batch_size, sequence_length, hidden_size),
                    device=torch.cuda.current_device(),
                    dtype=torch.bfloat16).requires_grad_(True)
    torch.cuda.empty_cache()

    # Run forward + backward.
    # with torch.autograd.detect_anomaly():
    out, _ = layer(x)
    out.mean().backward()

    # Report peak memory.
    mem = torch.cuda.max_memory_allocated()
    print('Max Memory Allocated = {:0.0f}MB'.format(mem / 1e6))
    print('Max Memory Reserved = {:0.0f}MB'.format(torch.cuda.max_memory_reserved() / 1e6,),)

    # Calculate weight and gradient memory usage.
    weight_memory = 2 * (
        layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
    )

    def grad_numel(x):
        if x.grad is not None:
            return x.grad.numel()
        return 0

    grad_memory = 2 * (
        grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
    )
    weight_memory += grad_memory

    print('Weight Memory Allocated = {:0.0f}MB'.format(weight_memory / 1e6))
    print('Activation Memory Allocated = {:0.0f}MB'.format((mem - weight_memory) / 1e6,),)

    # Manually calculate GPU memory usage from the garbage
    # collector.
    gc.collect()
    total = 0
    tensors = get_tensors()
    tensors = sorted(tensors, key=lambda x: -x.numel())
    for i, t in enumerate(tensors):
        total += t.numel()
        print(f'{i}: {t.shape}, {t.numel() * 2}')
    del tensors

    print('Total Bytes Found = {:0.0f}MB'.format(total * 2 / 1e6))


if __name__ == '__main__':
    assert dist.is_available()
    group = dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank(group)
    torch.cuda.set_device(local_rank)

    for args in _TESTS:
        test_memory(group, *args)
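One way to launch this harness is one process per GPU via torchrun, with megablocks importable as a package so the relative imports resolve (an assumed invocation; the module path follows this build tree):

    torchrun --nproc_per_node=8 -m megablocks._layers.memory_test

Each rank then reports peak, weight, and activation memory for the single configuration in _TESTS.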
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mlp.py
    ADDED
    
@@ -0,0 +1,587 @@
| 1 | 
            +
            # Copyright 2024 Databricks
         | 
| 2 | 
            +
            # SPDX-License-Identifier: Apache-2.0
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from typing import Any
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # try:
         | 
| 7 | 
            +
            #     import stk
         | 
| 8 | 
            +
            #     import stk.backend.triton_kernels
         | 
| 9 | 
            +
            #     import stk.ops
         | 
| 10 | 
            +
            # except ImportError:
         | 
| 11 | 
            +
            #     import warnings
         | 
| 12 | 
            +
            #     warnings.warn(
         | 
| 13 | 
            +
            #         'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
         | 
| 14 | 
            +
            #     )
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from .. import stk
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
            from packaging import version
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            # from megablocks import grouped_gemm_util as gg
         | 
| 22 | 
            +
            # from megablocks.layers import common, gelu, mpu
         | 
| 23 | 
            +
            # from megablocks.layers.activation_fn import act_fn
         | 
| 24 | 
            +
            # from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            from .. import grouped_gemm_util as gg
         | 
| 27 | 
            +
            from . import common, gelu, mpu
         | 
| 28 | 
            +
            from .activation_fn import act_fn
         | 
| 29 | 
            +
            from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            class ScaleGradient(torch.autograd.Function):
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                @staticmethod
         | 
| 34 | 
            +
                @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
         | 
| 35 | 
            +
                def forward(ctx: Any, x: torch.Tensor, scale: float):
         | 
| 36 | 
            +
                    ctx.scale = scale
         | 
| 37 | 
            +
                    return x
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                @staticmethod
         | 
| 40 | 
            +
                @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
         | 
| 41 | 
            +
                def backward(ctx: torch.Tensor, grad: torch.Tensor):
         | 
| 42 | 
            +
                    return grad * ctx.scale, None
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            scale_gradient = ScaleGradient.apply
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            def resolve_dtensor(weight: torch.Tensor):
         | 
| 49 | 
            +
                if version.parse(torch.__version__) >= version.parse('2.0.0'):
         | 
| 50 | 
            +
                    from torch.distributed._tensor import DTensor
         | 
| 51 | 
            +
                    if isinstance(weight, DTensor):
         | 
| 52 | 
            +
                        return weight.to_local()
         | 
| 53 | 
            +
                return weight
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
def create_moe_expert_weights(
    args: Arguments,
    num_experts: int,
    ffn_hidden_size: int,
    hidden_size: int,
    init_method: InitFn,
):
    # Create the entire weight matrix such that the sampled weights will
    # not vary between data parallelism and expert model parallelism for
    # the same random seed.
    master_weights = torch.empty(
        num_experts,
        ffn_hidden_size,
        hidden_size,
        device=args.device,
        dtype=common.dtype(args),
    )
    init_method(master_weights)

    if not args.moe_expert_model_parallelism:
        return master_weights

    # Calculate the amount of sharding in each dimension.
    expert_sharding_degree = mpu.expert_sharding_degree(args)
    hidden_sharding_degree = mpu.hidden_sharding_degree(args)

    # Calculate the experts per rank.
    #
    # NOTE: We assign ranks to be expert parallel before going
    # tensor parallel.
    rank = mpu.get_expert_parallel_rank(args)
    expert_rank = rank % expert_sharding_degree
    num_experts_per_rank = num_experts // expert_sharding_degree
    start_expert = expert_rank * num_experts_per_rank
    end_expert = (expert_rank + 1) * num_experts_per_rank

    # Calculate the rows per rank.
    row_rank = rank // expert_sharding_degree
    num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
    start_row = row_rank * num_rows_per_rank
    end_row = (row_rank + 1) * num_rows_per_rank

    # Slice the weight matrix to get the chunk for this rank.
    with torch.no_grad():
        weights = master_weights[start_expert:end_expert, start_row:end_row]
    return weights

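# Illustrative sketch of the sharding arithmetic above (assumed values, not taken
# from the original file): with num_experts=8, expert_sharding_degree=4 and
# hidden_sharding_degree=2, rank 5 gets
#
#   expert_rank = 5 % 4  = 1  -> experts [2, 4)
#   row_rank    = 5 // 4 = 1  -> rows    [ffn_hidden_size // 2, ffn_hidden_size)
#
# so each rank holds a contiguous (experts, rows) slice of the master weights.
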
class MLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
        experts_per_rank = mpu.experts_per_rank(args)

        self.w1 = torch.nn.Parameter(
            torch.empty(
                experts_per_rank,
                args.hidden_size,
                mpu.features_per_rank(args),
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                experts_per_rank,
                mpu.features_per_rank(args),
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        mpu.set_expert_model_parallel_attributes(
            self.w1,
            args.moe_expert_model_parallelism,
        )
        mpu.set_expert_model_parallel_attributes(
            self.w2,
            args.moe_expert_model_parallelism,
        )

        # Initialize the parameters for the MLP.
        #
        # NOTE: It is important that we create the weight tensors prior
        # to creating the master weights and slicing out the piece for
        # this rank. If the master weights are created first the PyTorch
        # caching allocator appears to use the same memory block for these
        # and the slice which causes large increases in our peak memory
        # usage.
        with torch.no_grad():
            w1 = create_moe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.init_method,
            )
            self.w1.copy_(w1.transpose(1, 2).contiguous())
            self.w2.copy_(
                create_moe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.output_layer_init_method,
                ),
            )

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
            self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)

    def scale_grad(self, w):
        if self.gradient_scale is None:
            return w
        return scale_gradient(w, self.gradient_scale)

    def forward(self, x):
        w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
        w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
        x = torch.bmm(x, w1)
        x = self.args.activation_fn(x)
        return torch.bmm(x, w2)

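# Shape sketch for MLP.forward (assuming the input has already been grouped per
# expert, e.g. by a binned gather in the MoE layer):
#   x  : [experts_per_rank, tokens_per_expert, hidden_size]
#   w1 : [experts_per_rank, hidden_size, ffn_per_rank]
#   w2 : [experts_per_rank, ffn_per_rank, hidden_size]
# so both torch.bmm calls are per-expert batched GEMMs and the output keeps the
# shape of `x`.
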
def create_dmoe_expert_weights(
    args: Arguments,
    num_experts: int,
    rows: int,
    columns: int,
    init_method: InitFn,
):
    weights = create_moe_expert_weights(
        args,
        num_experts,
        rows,
        columns,
        init_method,
    )
    return weights.view([-1, columns])

class MemoryOptimizedMLP(torch.autograd.Function):
    """Sparse MLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, topo, activation_fn):
        # Cast inputs using ctx dtype from AMP.
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        topo_tensors = (
            topo.row_indices,
            topo.column_indices,
            topo.offsets,
            topo.column_indices_t,
            topo.offsets_t,
            topo.block_offsets_t,
        )

        # Layer 0: x @ w1.t().
        sdd_out = stk.ops.sdd(x, w1.t(), topo)

        # Apply the activation function (GeLU by default).
        activation_fn_out = act_fn(sdd_out, activation_fn)

        # Layer 1: x @ w2.
        dsd_out = stk.ops.dsd(activation_fn_out, w2)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.shape = topo.shape
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.data.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack saved tensors.
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, w2 = saved_tensors[:2]
        topo_tensors = saved_tensors[2:8]
        x = saved_tensors[8]
        sdd_out_data = saved_tensors[9]

        # Rematerialize the activation function output.
        activation_fn = ctx.activation_fn
        sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
        activation_fn_out, activation_grad_fn = act_fn(
            sdd_out,
            activation_fn,
            return_grad_fn=True,
        )

        # Compute dw2 with recomputed activation_fn output.
        dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        stk.backend.triton_kernels.sdd(
            ddsd_out,
            w2.t(),
            dactivation_fn_out.shape,
            dactivation_fn_out.data,
            dactivation_fn_out.offsets,
            dactivation_fn_out.row_indices,
            dactivation_fn_out.column_indices,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        if activation_fn is DEFAULT_ACTIVATION_FN:
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out.data)
            dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)

        # Compute dw1.
        dw1 = stk.ops.dsd(dsdd_out.t(), x)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        stk.backend.triton_kernels.dsd(
            dsdd_out.shape,
            dsdd_out.data,
            dsdd_out.offsets,
            dsdd_out.row_indices,
            dsdd_out.column_indices,
            dsdd_out.offsets_t,
            dsdd_out.column_indices_t,
            dsdd_out.block_offsets_t,
            False,
            w1,
            ddsd_out,
        )
        dx = ddsd_out
        return dx, dw1, dw2, None, None


memory_optimized_mlp = MemoryOptimizedMLP.apply

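# Sketch of what the memory-optimized path trades (illustrative summary): only `x`
# and the pre-activation `sdd_out` are saved for backward, the activation output is
# recomputed, and the `activation_fn_out` / `ddsd_out` buffers are reused for the
# gradient tensors. Roughly one fewer block-sparse intermediate stays alive per
# layer, at the cost of recomputing the activation during the backward pass.
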

class SparseMLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)

        self.w1 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )

        # Initialize the parameters for the MLP.
        #
        # NOTE: It is important that we create the weight tensors prior
        # to creating the master weights and slicing out the piece for
        # this rank. If the master weights are created first the PyTorch
        # caching allocator appears to use the same memory block for these
        # and the slice which causes large increases in our peak memory
        # usage.
        with torch.no_grad():
            self.w1.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.init_method,
                ),
            )
            self.w2.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.output_layer_init_method,
                ),
            )

        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
        mpu.set_expert_model_parallel_attributes(
            self.w1,
            self._should_set_parallelism_attribute,
        )
        mpu.set_expert_model_parallel_attributes(
            self.w2,
            self._should_set_parallelism_attribute,
        )

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
            self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)

    def scale_grad(self, w):
        if self.gradient_scale is None:
            return w
        return scale_gradient(w, self.gradient_scale)

    def forward(self, x, topo):
        w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
        w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
        if self.args.memory_optimized_mlp:
            return memory_optimized_mlp(
                x,
                w1,
                w2,
                topo,
                self.args.activation_fn,
            )

        # Compute the MLP.
        x = stk.ops.sdd(x, w1.t(), topo)
        activation_fn_out = act_fn(x, self.args.activation_fn)
        return stk.ops.dsd(activation_fn_out, w2)

class MemoryOptimizedGroupedMLP(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
        # Cast inputs using ctx dtype from AMP.
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        # Layer 0: x @ w1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)

        # activation_fn
        activation_fn_out = activation_fn(sdd_out)

        # Layer 1: x @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx: Any, ddsd_out: torch.Tensor):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack saved tensors.
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, w2 = saved_tensors[:2]
        batch_sizes = saved_tensors[2]
        x = saved_tensors[3]
        sdd_out = saved_tensors[4]

        # Rematerialize activation_fn output.
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out)
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        if activation_fn is DEFAULT_ACTIVATION_FN:
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out)
            dsdd_out = sdd_out.grad

        # Compute dw1.
        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
        dx = ddsd_out
        return dx, dw1, dw2, None, None


memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply

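# Shape sketch for the grouped GEMMs above (illustrative, values assumed):
#   x           : [sum(batch_sizes), hidden_size]   tokens grouped by expert
#   w1, w2      : [num_experts, ffn_hidden_size, hidden_size]
#   batch_sizes : [num_experts] int64 on CPU, tokens routed to each expert
# gmm(x, w1, batch_sizes, trans_b=True) runs one GEMM per expert over that
# expert's contiguous slice of rows in `x`.
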
class GroupedMLP(SparseMLP):

    def forward(self, x, tokens_per_expert):
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))

        # Re-shape the weights for the grouped GEMMs.
        ne = mpu.experts_per_rank(self.args)
        w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
        w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_mlp(
                x,
                w1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        # Compute the MLP.
        assert gg.ops is not None
        x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x = self.args.activation_fn(x)
        return gg.ops.gmm(x, w2, batch_sizes)

class SharedMLP(torch.nn.Module):
    """MLP for the shared expert.

    Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class.
    """

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        self.fc_kwargs: dict[str, Any] = {
            'bias': args.bias,
            'device': args.device,
        }
        self.fc_kwargs.update(args.fc_kwargs)

        self.up_proj = args.fc_cls(
            args.hidden_size,
            args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )
        self.act = args.activation_fn
        self.down_proj = args.fc_cls(
            args.shared_expert_hidden_size,
            args.hidden_size,
            **self.fc_kwargs,
        )
        self.down_proj._is_residual = True  # a flag for llm-foundry init

    def add_experts_sharedexpert(
        self,
        shared_expert_out: torch.Tensor,
        expert_out: torch.Tensor,
    ) -> torch.Tensor:
        # Helper function to add expert output to shared expert output
        # with optional weighted sum.
        if self.args.shared_expert_weighted_sum:
            # Enable a weighted sum for the shared expert output,
            # weighted by the number of experts used.
            t_experts = self.args.moe_top_k + 1
            sh_mlp_out = shared_expert_out / t_experts
            return sh_mlp_out.add(
                expert_out,
                alpha=(self.args.moe_top_k / t_experts),
            )

        return shared_expert_out + expert_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))
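
# Illustrative arithmetic for the weighted sum above (assumed moe_top_k=4):
# t_experts = 5, so the combined output is
#   shared_expert_out * (1 / 5) + expert_out * (4 / 5)
# i.e. the shared expert counts as one additional expert alongside the top-k
# routed experts.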
    	
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/moe.py
ADDED

@@ -0,0 +1,507 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist

# import megablocks.ops as ops
# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
# from megablocks.layers.all_to_all import all_to_all
# from megablocks.layers.arguments import Arguments

from ..ops import (
    sort,
    histogram,
    inclusive_cumsum,
    exclusive_cumsum,
    binned_gather,
    binned_scatter,
    gather,
    scatter,
    repeat,
    replicate,
)

from . import common, mlp, mpu, router, sharedexpert_registry
from .arguments import Arguments
from .all_to_all import all_to_all

_LOAD_BALANCING_LOSS = []

def save_load_balancing_loss(loss):
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.append(loss)


def get_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    return _LOAD_BALANCING_LOSS


def clear_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.clear()

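# Typical bookkeeping cycle for these helpers (illustrative sketch): each MoE layer
# calls save_load_balancing_loss((tokens_per_expert, expert_scores)) during its
# forward pass, the training loop adds batched_load_balancing_loss(args) to the
# task loss once per step, and then calls clear_load_balancing_loss() before the
# next forward pass so contributions do not accumulate across steps.
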
def batched_load_balancing_loss(args: Arguments):
    if args.moe_loss_weight == 0:
        return 0.0

    # tokens_per_expert[i].shape = (num_experts)
    # expert_scores[i].shape = (tokens, num_experts)
    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
    num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
    if args.num_layers_per_virtual_pipeline_stage is not None:
        num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage

    if len(tokens_per_expert) != num_layers_per_pipeline_stage:
        raise ValueError(
            f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
            f'but found {len(tokens_per_expert)}.\nnum_layers = '
            f'{args.num_layers}\npipeline_model_parallel_size = '
            f'{args.pipeline_model_parallel_size}\n'
            'num_layers_per_virtual_pipeline_stage'
            f' = {args.num_layers_per_virtual_pipeline_stage}',
        )
    if len(expert_scores) != num_layers_per_pipeline_stage:
        raise ValueError(
            f'Expected {num_layers_per_pipeline_stage} expert_scores '
            f'but found {len(expert_scores)}.\nnum_layers = '
            f'{args.num_layers}\npipeline_model_parallel_size = '
            f'{args.pipeline_model_parallel_size}\n'
            'num_layers_per_virtual_pipeline_stage'
            f' = {args.num_layers_per_virtual_pipeline_stage}',
        )

    # Verify the shape of the tokens_per_expert and expert_scores tensors.
    assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))

    tokens = expert_scores[0].shape[0]
    assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))

    # Concatenate the contributions of each layer and convert to
    # the correct types and formats for the dot product.
    expert_scores = torch.cat(expert_scores, dim=1)
    if args.moe_lbl_in_fp32:
        expert_scores = expert_scores.float()
    if tokens != 0:
        expert_scores = expert_scores.mean(dim=0)
    else:
        expert_scores = expert_scores.sum(dim=0)
    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)

    expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
    assert tokens_per_expert.numel() == expected_values
    assert expert_scores.numel() == expected_values

    # Calculate the total scale across all factors.
    #
    # loss_weight * num_experts / (num_layers * tokens * top_k)
    scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
    scale_denominator = (args.num_layers * tokens * args.moe_top_k)
    scale = scale_numerator / scale_denominator
    return scale * torch.dot(tokens_per_expert, expert_scores)

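# Worked example of the scale above (assumed values): with moe_num_experts=8,
# moe_loss_weight=0.01, num_layers=12, tokens=1024 and moe_top_k=2,
#   scale = (8 * 0.01) / (12 * 1024 * 2) ≈ 3.26e-6
# which is then applied to the dot product of per-expert token counts and the
# mean router scores accumulated across the layers of this pipeline stage.
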
# NOTE: This class defines MoE expert computation, including expert model parallel
# communication. When using FSDP on top of MegaBlocks this is the module that should
# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
# parallel all2all.
class ParallelMLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super(ParallelMLP, self).__init__()
        self.args = args

        # Calculate the number of experts in total and the number of experts
        # owned by this rank.
        # world_size = mpu.get_expert_parallel_world_size(args)
        self.num_experts = args.moe_num_experts
        self.top_k = self.args.moe_top_k

        # Calculate the number of bits needed to represent the expert indices
        # so that we can pass it to radix sort.
        self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)

        # Expert MLP.
        self.mlp = mlp.MLP(args)

        self.bias: Optional[torch.Tensor]
        if self.args.bias:
            # Note that the output bias is not parallelized with expert
            # model parallelism.
            self.bias = torch.nn.Parameter(
                torch.empty(
                    args.hidden_size,
                    device=args.device,
                    dtype=common.dtype(args),
                ),
            )
            torch.nn.init.zeros_(self.bias)
        else:
            self.register_parameter('bias', None)

        # Select the forward function for the operating mode.
        self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)

    def expert_capacity(self, tokens: int) -> int:
        world_size = mpu.get_expert_parallel_world_size(self.args)
        tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
        return int(self.args.moe_capacity_factor * tokens_per_expert)

    def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
        """Calculate the load balancing loss contribution."""
        assert len(expert_scores.size()) == 2
        tokens, num_experts = expert_scores.size()
        assert num_experts == self.num_experts
        assert len(tokens_per_expert.size()) == 1
        num_experts, = tokens_per_expert.size()
        assert num_experts == self.num_experts
        scale = self.num_experts / (tokens * self.top_k)
        return scale * torch.dot(
            tokens_per_expert.to(expert_scores.dtype),
            expert_scores.mean(dim=0),
        )

| 169 | 
            +
                def indices_and_bins(self,
         | 
| 170 | 
            +
                                     top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         | 
| 171 | 
            +
                    # Sort the expert ids to produce the scatter/gather
         | 
| 172 | 
            +
                    # indices for the permutation.
         | 
| 173 | 
            +
                    #
         | 
| 174 | 
            +
                    # TODO(tgale): Is it worth doing this conversion to 32-bit
         | 
| 175 | 
            +
                    # prior? Could we place the `torch.max` operation to return
         | 
| 176 | 
            +
                    # 32-bit expert indices?
         | 
| 177 | 
            +
                    top_expert = top_expert.int()
         | 
| 178 | 
            +
                    # output = ops.sort(top_expert, self.sort_end_bit)
         | 
| 179 | 
            +
                    output = sort(top_expert, self.sort_end_bit)
         | 
| 180 | 
            +
                    assert output is not None
         | 
| 181 | 
            +
                    bin_ids, indices = output
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    # Histogram the expert ids to identify the number of
         | 
| 184 | 
            +
                    # tokens routed to each expert.
         | 
| 185 | 
            +
                    #
         | 
| 186 | 
            +
                    # TODO(tgale): Does the sorted data produce a more favorable
         | 
| 187 | 
            +
                    # data distribution for histogram? Or is the op parallelism
         | 
| 188 | 
            +
                    # worth more?
         | 
| 189 | 
            +
                    # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
         | 
| 190 | 
            +
                    tokens_per_expert = histogram(top_expert, self.num_experts)
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    # Calculate the bin bounds for the sorted tokens.
         | 
| 193 | 
            +
                    # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
         | 
| 194 | 
            +
                    bins = inclusive_cumsum(tokens_per_expert, 0)
         | 
| 195 | 
            +
                    assert bins is not None
         | 
| 196 | 
            +
                    bins = bins.view(1) if not len(bins.size()) else bins
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                    assert isinstance(indices, torch.Tensor)
         | 
| 199 | 
            +
                    assert isinstance(bin_ids, torch.Tensor)
         | 
| 200 | 
            +
                    assert isinstance(bins, torch.Tensor)
         | 
| 201 | 
            +
                    assert isinstance(tokens_per_expert, torch.Tensor)
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    return indices, bin_ids, bins, tokens_per_expert
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                def permute_and_compute(
         | 
| 206 | 
            +
                    self,
         | 
| 207 | 
            +
                    x: torch.Tensor,
         | 
| 208 | 
            +
                    tokens_per_expert: int,  # unused
         | 
| 209 | 
            +
                    indices: torch.Tensor,
         | 
| 210 | 
            +
                    bin_ids: torch.Tensor,  # unused
         | 
| 211 | 
            +
                    expert_weights: torch.Tensor,
         | 
| 212 | 
            +
                    bins: torch.Tensor,
         | 
| 213 | 
            +
                    expert_capacity: int,
         | 
| 214 | 
            +
                    top_k: int,
         | 
| 215 | 
            +
                ):
         | 
| 216 | 
            +
                    # Route the tokens for MoE computation.
         | 
| 217 | 
            +
                    x = x.view(-1, x.shape[-1])
         | 
| 218 | 
            +
                    # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
         | 
| 219 | 
            +
                    output = binned_gather(x, indices, bins, expert_capacity, top_k)
         | 
| 220 | 
            +
                    assert output is not None
         | 
| 221 | 
            +
                    x = output
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    # Perform the expert computation. Note that we don't
         | 
| 224 | 
            +
                    # use biases for these linear operations.
         | 
| 225 | 
            +
                    x = self.mlp(x)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # Un-route the data for the MoE output.
         | 
| 228 | 
            +
                    # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
         | 
| 229 | 
            +
                    return binned_scatter(x, indices, expert_weights, bins, top_k)
         | 
| 230 | 
            +
                
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
         | 
| 233 | 
            +
                    # x: [sl, bs, hs]
         | 
| 234 | 
            +
                    # expert_weights: [sl * bs, top-k]
         | 
| 235 | 
            +
                    # top_experts: [sl * bs, top-k]
         | 
| 236 | 
            +
                    expert_weights = expert_weights.flatten()
         | 
| 237 | 
            +
                    top_experts = top_experts.flatten()
         | 
| 238 | 
            +
                    with torch.no_grad():
         | 
| 239 | 
            +
                        indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                        # If expert_capacity is set to zero, set the number of tokens
         | 
| 242 | 
            +
                        # per expert to the maximum we need to avoid dropping tokens.
         | 
| 243 | 
            +
                        sl, bs, _ = x.size()
         | 
| 244 | 
            +
                        expert_capacity = self.expert_capacity(sl * bs)
         | 
| 245 | 
            +
                        if expert_capacity == 0:
         | 
| 246 | 
            +
                            expert_capacity = torch.max(tokens_per_expert).item()
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    x = self.permute_and_compute(
         | 
| 249 | 
            +
                        x,
         | 
| 250 | 
            +
                        tokens_per_expert,
         | 
| 251 | 
            +
                        indices,
         | 
| 252 | 
            +
                        bin_ids,
         | 
| 253 | 
            +
                        expert_weights,
         | 
| 254 | 
            +
                        bins,
         | 
| 255 | 
            +
                        expert_capacity,
         | 
| 256 | 
            +
                        self.top_k,
         | 
| 257 | 
            +
                    )
         | 
| 258 | 
            +
                    return x, tokens_per_expert
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
         | 
| 261 | 
            +
                    # NOTE: This function implements the same computation as forward_once
         | 
| 262 | 
            +
                    # but with expert model parallelism.
         | 
| 263 | 
            +
                    #
         | 
| 264 | 
            +
                    # 1. Permute the tokens locally so that they are grouped by their
         | 
| 265 | 
            +
                    # expert assignments. This allows us to transfer all of the tokens
         | 
| 266 | 
            +
                    # for a remote device in one communication primitive.
         | 
| 267 | 
            +
                    #
         | 
| 268 | 
            +
                    # 2. Permute the tokens across the expert parallel devices. After
         | 
| 269 | 
            +
                    # this is completed each device has all of the tokens assigned to
         | 
| 270 | 
            +
                    # its set of experts in its local HBM.
         | 
| 271 | 
            +
                    #
         | 
| 272 | 
            +
                    # 3. Permute the tokens locally so that they are grouped by their
         | 
| 273 | 
            +
                    # expert assignement. After the distributed permutation the tokens
         | 
| 274 | 
            +
                    # are grouped by which device they came from. We re-order them
         | 
| 275 | 
            +
                    # locally to allow for efficient computation.
         | 
| 276 | 
            +
                    #
         | 
| 277 | 
            +
                    # After this series of permutations we compute the linear layers
         | 
| 278 | 
            +
                    # and then repeat these three steps in reverse to produce the final
         | 
| 279 | 
            +
                    # output.
         | 
| 280 | 
            +
                    #
         | 
| 281 | 
            +
                    # Compute the mapping of local tokens to experts.
         | 
| 282 | 
            +
                    expert_weights = expert_weights.flatten()
         | 
| 283 | 
            +
                    top_experts = top_experts.flatten()
         | 
| 284 | 
            +
                    with torch.no_grad():
         | 
| 285 | 
            +
                        indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                        # If we're sharding the experts along the hidden dimension
         | 
| 288 | 
            +
                        # multiple devices own parts of the same sets of experts.
         | 
| 289 | 
            +
                        # Replicate the token counts so every device gets the counts.
         | 
| 290 | 
            +
                        # repeated_tokens_per_expert = ops.repeat(
         | 
| 291 | 
            +
                        repeated_tokens_per_expert = repeat(
         | 
| 292 | 
            +
                            tokens_per_expert,
         | 
| 293 | 
            +
                            (mpu.hidden_sharding_degree(self.args),),
         | 
| 294 | 
            +
                        )
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                        # Pass token count information to the device on which the
         | 
| 297 | 
            +
                        # target expert resides.
         | 
| 298 | 
            +
                        parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
         | 
| 299 | 
            +
                        tpe_handle = dist.all_to_all_single(
         | 
| 300 | 
            +
                            parallel_tokens_per_expert,
         | 
| 301 | 
            +
                            repeated_tokens_per_expert,
         | 
| 302 | 
            +
                            group=self.args.expert_parallel_group,
         | 
| 303 | 
            +
                            async_op=True,
         | 
| 304 | 
            +
                        )
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    # Permute locally and without any padding so that tokens for each
         | 
| 307 | 
            +
                    # parallel device are stored contiguously.
         | 
| 308 | 
            +
                    #
         | 
| 309 | 
            +
                    # This view updates the shape of the tensor from [sl, bs, hs] to
         | 
| 310 | 
            +
                    # [sl * bs, hs] prior to the permutation.
         | 
| 311 | 
            +
                    x = x.view(-1, x.shape[-1])
         | 
| 312 | 
            +
                    # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
         | 
| 313 | 
            +
                    output = gather(x, indices, bin_ids, bins, self.top_k)
         | 
| 314 | 
            +
                    assert output is not None
         | 
| 315 | 
            +
                    x = output
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    # Compute the number of tokens that will be received from each
         | 
| 318 | 
            +
                    # device and permute the input data across the devices.
         | 
| 319 | 
            +
                    with torch.no_grad():
         | 
| 320 | 
            +
                        tpe_handle.wait()
         | 
| 321 | 
            +
                        experts_per_rank = mpu.experts_per_rank(self.args)
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                        # Reshape to [world_size, num_experts_per_rank].
         | 
| 324 | 
            +
                        world_size = mpu.get_expert_parallel_world_size(self.args)
         | 
| 325 | 
            +
                        repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
         | 
| 326 | 
            +
                        parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                        # TODO(tgale): It might be faster to do this on the GPU and
         | 
| 329 | 
            +
                        # then communicate the results back to the host.
         | 
| 330 | 
            +
                        send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
         | 
| 331 | 
            +
                        parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
         | 
| 332 | 
            +
                        recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                        # Convert the send/recv counts to lists.
         | 
| 335 | 
            +
                        send_counts = send_counts.tolist()
         | 
| 336 | 
            +
                        recv_counts = recv_counts.tolist()
         | 
| 337 | 
            +
                        tokens_received = sum(recv_counts)
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                    # If we're sharding the experts along the hidden dimension
         | 
| 340 | 
            +
                    # multiple devices own parts of the same sets of experts.
         | 
| 341 | 
            +
                    # Replicate the token counts so devices that share experts
         | 
| 342 | 
            +
                    # get all of the tokens assigned to them.
         | 
| 343 | 
            +
                    #
         | 
| 344 | 
            +
                    # TODO(tgale): Fuse this into the prior, local permutation.
         | 
| 345 | 
            +
                    # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
         | 
| 346 | 
            +
                    x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    # Start the cross-device permutation asynchronously so we can
         | 
| 349 | 
            +
                    # overlap communication with computation.
         | 
| 350 | 
            +
                    parallel_x, parallel_x_handle = all_to_all(
         | 
| 351 | 
            +
                        x,
         | 
| 352 | 
            +
                        recv_counts,
         | 
| 353 | 
            +
                        send_counts,
         | 
| 354 | 
            +
                        self.args.expert_parallel_group,
         | 
| 355 | 
            +
                        async_op=True,
         | 
| 356 | 
            +
                    )
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    with torch.no_grad():
         | 
| 359 | 
            +
                        # After we do the cross-device permutation we have the tokens on the
         | 
| 360 | 
            +
                        # correct device but not yet grouped by expert because we received
         | 
| 361 | 
            +
                        # tokens from each device as contiguous chunks. To group the tokens
         | 
| 362 | 
            +
                        # for expert computation we'll do one more local permutation. The
         | 
| 363 | 
            +
                        # rest of this torch.no_grad() scope sets up the indices and bins
         | 
| 364 | 
            +
                        # for this permutation.
         | 
| 365 | 
            +
                        # replicate_bins = ops.inclusive_cumsum(
         | 
| 366 | 
            +
                        replicate_bins = inclusive_cumsum(
         | 
| 367 | 
            +
                            parallel_tokens_per_expert.flatten(),
         | 
| 368 | 
            +
                            0,
         | 
| 369 | 
            +
                        )
         | 
| 370 | 
            +
                        replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                        # Construct the expert indices for the permuted tokens.
         | 
| 373 | 
            +
                        parallel_top_expert = torch.remainder(
         | 
| 374 | 
            +
                            torch.arange(
         | 
| 375 | 
            +
                                self.num_experts * mpu.hidden_sharding_degree(self.args),
         | 
| 376 | 
            +
                                dtype=torch.int32,
         | 
| 377 | 
            +
                                device=indices.device,
         | 
| 378 | 
            +
                            ),
         | 
| 379 | 
            +
                            mpu.experts_per_rank(self.args),
         | 
| 380 | 
            +
                        )
         | 
| 381 | 
            +
                        # parallel_top_expert = ops.replicate(
         | 
| 382 | 
            +
                        parallel_top_expert = replicate(
         | 
| 383 | 
            +
                            parallel_top_expert.unsqueeze(dim=0),
         | 
| 384 | 
            +
                            replicate_bins,
         | 
| 385 | 
            +
                            tokens_received,
         | 
| 386 | 
            +
                        ).flatten()
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                        # TODO(tgale): The sort_end_bit here can be reduced.
         | 
| 389 | 
            +
                        # parallel_bin_ids, parallel_indices = ops.sort(
         | 
| 390 | 
            +
                        parallel_bin_ids, parallel_indices = sort(
         | 
| 391 | 
            +
                            parallel_top_expert,
         | 
| 392 | 
            +
                            self.sort_end_bit,
         | 
| 393 | 
            +
                        )
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                        # Calculate the bins boundaries from the token counts.
         | 
| 396 | 
            +
                        parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
         | 
| 397 | 
            +
                            dim=0,
         | 
| 398 | 
            +
                            dtype=torch.int,
         | 
| 399 | 
            +
                        )
         | 
| 400 | 
            +
                        # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
         | 
| 401 | 
            +
                        parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
         | 
| 402 | 
            +
                        parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                        # If expert_capacity is set to zero, set the number of tokens
         | 
| 405 | 
            +
                        # per expert to the maximum we need to avoid dropping tokens.
         | 
| 406 | 
            +
                        tokens, _ = x.size()
         | 
| 407 | 
            +
                        expert_capacity = self.expert_capacity(tokens)
         | 
| 408 | 
            +
                        if expert_capacity == 0:
         | 
| 409 | 
            +
                            expert_capacity = torch.max(parallel_tokens_per_expert).item()
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    # Locally permute the tokens and perform the expert computation.
         | 
| 412 | 
            +
                    # Block to make sure that the cross-device permutation is complete.
         | 
| 413 | 
            +
                    if self.args.mlp_impl == 'grouped':
         | 
| 414 | 
            +
                        # GroupedMLP requires counts on CPU. We can use the tensor already
         | 
| 415 | 
            +
                        # moved to CPU for the prior all_to_all, which avoids an extra
         | 
| 416 | 
            +
                        # device synchronization.
         | 
| 417 | 
            +
                        parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
         | 
| 418 | 
            +
                            dim=0,
         | 
| 419 | 
            +
                            dtype=torch.int,
         | 
| 420 | 
            +
                        )
         | 
| 421 | 
            +
                    parallel_x_handle.wait()
         | 
| 422 | 
            +
                    parallel_x = self.permute_and_compute(
         | 
| 423 | 
            +
                        parallel_x,
         | 
| 424 | 
            +
                        parallel_tokens_per_expert,
         | 
| 425 | 
            +
                        parallel_indices,
         | 
| 426 | 
            +
                        parallel_bin_ids,
         | 
| 427 | 
            +
                        None,  # expert_weights
         | 
| 428 | 
            +
                        parallel_bins,
         | 
| 429 | 
            +
                        expert_capacity,
         | 
| 430 | 
            +
                        top_k=1,
         | 
| 431 | 
            +
                    )
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    # Un-permute the tokens across the devices.
         | 
| 434 | 
            +
                    x, _ = all_to_all(
         | 
| 435 | 
            +
                        parallel_x,
         | 
| 436 | 
            +
                        send_counts,
         | 
| 437 | 
            +
                        recv_counts,
         | 
| 438 | 
            +
                        self.args.expert_parallel_group,
         | 
| 439 | 
            +
                    )
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    # Reduce along the hidden sharding to get the final outputs.
         | 
| 442 | 
            +
                    #
         | 
| 443 | 
            +
                    # TODO(tgale): Fuse this into the following local permutation.
         | 
| 444 | 
            +
                    shape = (
         | 
| 445 | 
            +
                        mpu.hidden_sharding_degree(self.args),
         | 
| 446 | 
            +
                        -1,
         | 
| 447 | 
            +
                        self.args.hidden_size,
         | 
| 448 | 
            +
                    )
         | 
| 449 | 
            +
                    # x = ops.sum(x.view(shape), dim=0)
         | 
| 450 | 
            +
                    x = x.view(shape).sum(dim=0)
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                    # Un-permute locally to setup for the next series of operations.
         | 
| 453 | 
            +
                    # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
         | 
| 454 | 
            +
                    x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
         | 
| 455 | 
            +
                    return x, tokens_per_expert.flatten()
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
         | 
| 458 | 
            +
                    in_shape = x.size()
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                    # Compute the experts.
         | 
| 461 | 
            +
                    x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
         | 
| 462 | 
            +
                    if self.training and self.args.moe_loss_weight > 0:
         | 
| 463 | 
            +
                        save_load_balancing_loss((tokens_per_expert, scores))
         | 
| 464 | 
            +
                    x = x.view(in_shape)
         | 
| 465 | 
            +
                    if self.bias is not None:
         | 
| 466 | 
            +
                        if self.args.return_bias:
         | 
| 467 | 
            +
                            return x, self.bias
         | 
| 468 | 
            +
                        return x + self.bias
         | 
| 469 | 
            +
                    return x
         | 
| 470 | 
            +
             | 
| 471 | 
            +
             | 
| 472 | 
            +
            class MoE(torch.nn.Module):
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                def __init__(self, args: Arguments):
         | 
| 475 | 
            +
                    super(MoE, self).__init__()
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                    # Token router.
         | 
| 478 | 
            +
                    self.router = router.LearnedRouter(args)
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                    # Expert computation helper.
         | 
| 481 | 
            +
                    self.experts = self._init_experts_mlp(args)
         | 
| 482 | 
            +
             | 
| 483 | 
            +
                    self.shared_expert = None
         | 
| 484 | 
            +
                    if args.shared_expert:
         | 
| 485 | 
            +
                        # SharedExpert computation helper.
         | 
| 486 | 
            +
                        self.shared_expert = sharedexpert_registry.get(args)
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                def _init_experts_mlp(self, args: Arguments):
         | 
| 489 | 
            +
                    return ParallelMLP(args)
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                def forward(self, x: torch.Tensor):
         | 
| 492 | 
            +
                    # NOTE: If we're going to cast the activations to lower precision
         | 
| 493 | 
            +
                    # do it before we permute the tokens to save bandwidth.
         | 
| 494 | 
            +
                    x = common.cast_if_autocast_enabled(x)
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                    # Compute the expert scores and assignments.
         | 
| 497 | 
            +
                    scores, expert_weights, top_experts = self.router(x)
         | 
| 498 | 
            +
             | 
| 499 | 
            +
                    # Compute the experts.
         | 
| 500 | 
            +
                    out = self.experts(x, scores, expert_weights, top_experts)
         | 
| 501 | 
            +
                    if self.shared_expert is not None:
         | 
| 502 | 
            +
                        shared_expert_out = self.shared_expert(x)
         | 
| 503 | 
            +
                        out = self.shared_expert.add_experts_sharedexpert(
         | 
| 504 | 
            +
                            shared_expert_out,
         | 
| 505 | 
            +
                            out,
         | 
| 506 | 
            +
                        )
         | 
| 507 | 
            +
                    return out
         | 
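For reference, a minimal usage sketch of the MoE layer above (not part of this commit): it assumes the wheel imports as megablocks, that the Arguments dataclass added in _layers/arguments.py elsewhere in this commit provides defaults for any fields not listed, and that the sizes below are purely illustrative.

import torch

from megablocks._layers import common
from megablocks._layers.arguments import Arguments
from megablocks._layers.moe import MoE

# Illustrative single-process configuration (no expert model parallelism);
# all remaining Arguments fields are assumed to keep their defaults.
args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
)

layer = MoE(args)

# Inputs follow the [sl, bs, hs] layout noted in forward_once.
x = torch.randn(16, 2, args.hidden_size, device=args.device, dtype=common.dtype(args))
out = layer(x)
# When args.bias and args.return_bias are both enabled, ParallelMLP.forward
# returns an (output, bias) pair; otherwise it returns a single tensor.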
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/mpu.py
    ADDED
    
@@ -0,0 +1,94 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.distributed as dist

# from megablocks.layers.arguments import Arguments
from .arguments import Arguments


class MoeParam(torch.Tensor):

    def __init__(self):
        super().__init__(self)
        self.expert_model_parallel: bool


def is_moe_param(tensor: torch.Tensor) -> bool:
    return hasattr(tensor, 'expert_model_parallel')


def get_expert_parallel_world_size(args: Arguments) -> int:
    return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)


def get_expert_parallel_rank(args: Arguments) -> int:
    return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)


def set_expert_model_parallel_attributes(
    tensor: torch.Tensor,
    is_parallel: bool,
):
    assert not hasattr(tensor, 'expert_model_parallel')
    setattr(tensor, 'expert_model_parallel', is_parallel)


def param_is_expert_model_parallel(param: MoeParam) -> bool:
    return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)


def copy_expert_model_parallel_attributes(
    destination_tensor: torch.Tensor,
    source_tensor: torch.Tensor,
):
    if hasattr(source_tensor, 'expert_model_parallel'):
        setattr(
            destination_tensor,
            'expert_model_parallel',
            getattr(source_tensor, 'expert_model_parallel'),
        )


def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    for i in range(world_size):
        dist.barrier(group)
        if i == rank:
            print(f'rank = {rank}', *x)


# Helpers for expert/tensor sharding.
def expert_sharding_degree(args: Arguments) -> int:
    world_size = get_expert_parallel_world_size(args)
    esd = min(world_size, args.moe_num_experts)

    if (args.moe_num_experts % esd) != 0:
        raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
    return esd


def hidden_sharding_degree(args: Arguments) -> int:
    world_size = get_expert_parallel_world_size(args)
    esd = expert_sharding_degree(args)
    hsd = world_size // esd

    if (args.ffn_hidden_size % hsd) != 0:
        raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
    if (esd * hsd) != world_size:
        raise ValueError(
            f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
        )
    return hsd


def experts_per_rank(args: Arguments) -> int:
    return args.moe_num_experts // expert_sharding_degree(args)


def features_per_rank(args: Arguments) -> int:
    return args.ffn_hidden_size // hidden_sharding_degree(args)
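To make the sharding arithmetic above concrete, here is a small standalone sketch of the same calculation on plain integers (so it runs without a process group); it is not code from the commit and the sizes are illustrative.

def sharding_example(world_size: int, num_experts: int, ffn_hidden_size: int):
    # Mirrors expert_sharding_degree / hidden_sharding_degree above, but on
    # plain integers instead of an Arguments instance and torch.distributed.
    esd = min(world_size, num_experts)          # expert_sharding_degree
    assert num_experts % esd == 0
    hsd = world_size // esd                     # hidden_sharding_degree
    assert ffn_hidden_size % hsd == 0 and esd * hsd == world_size
    return esd, hsd, num_experts // esd, ffn_hidden_size // hsd

# 8 ranks, 4 experts, ffn_hidden_size=4096 -> (4, 2, 1, 2048): each expert is
# replicated on 2 ranks and each of those ranks holds 2048 of its FFN features.
print(sharding_example(8, 4, 4096))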
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/router.py
    ADDED
    
@@ -0,0 +1,116 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch

# from megablocks.layers import common
# from megablocks.layers.arguments import Arguments
from . import common
from .arguments import Arguments

_ROUTER_LOGITS = []


def _save_router_logits(logits: torch.Tensor, args: Arguments):
    if args.moe_zloss_weight == 0:
        return
    global _ROUTER_LOGITS
    _ROUTER_LOGITS.append(logits)


def clear_router_zloss():
    global _ROUTER_LOGITS
    _ROUTER_LOGITS.clear()


def batched_router_zloss(args: Arguments):
    global _ROUTER_LOGITS

    if args.moe_zloss_weight == 0:
        import warnings
        warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
        return 0

    logits_per_router = _ROUTER_LOGITS

    if args.moe_zloss_in_fp32:
        logits_per_router = [logits.float() for logits in logits_per_router]

    unscaled_zloss_per_router = torch.stack([
        torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
    ])

    return args.moe_zloss_weight * unscaled_zloss_per_router


# NOTE: To enable end-to-end benchmarking without convergence we
# support a flag to force the router to assign tokens uniformly
# across the experts. We do this with a custom autograd operation
# so that PyTorch still executes the full set of router operations.
class _UniformExpertAssignment(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, num_experts: int):
        out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
        out = torch.remainder(out, num_experts)
        return out.view(x.shape)


_uniform_expert_assignment = _UniformExpertAssignment.apply


class LearnedRouter(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args

        # Learned router parameters.
        #
        # NOTE: This weight matrix is not parallelized with expert model
        # parallelism. Each device needs the entire router weight matrix
        # so that it can route its batch of data correctly.
        self.layer = torch.nn.Linear(
            args.hidden_size,
            args.moe_num_experts,
            bias=False,
            dtype=common.dtype(args),
            device=args.device,
        )
        args.init_method(self.layer.weight)

    def jitter(self, x: torch.Tensor):
        low: float = 1.0 - self.args.moe_jitter_eps
        high: float = 1.0 + self.args.moe_jitter_eps
        noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
        return low + noise * (high - low)

    def _top_k(self, scores: torch.Tensor):
        if self.args.moe_top_k == 1:
            return scores.max(dim=-1, keepdim=True)
        return torch.topk(scores, self.args.moe_top_k, dim=-1)

    def forward(self, x: torch.Tensor):
        if self.training and self.args.moe_jitter_eps is not None:
            x = x * self.jitter(x)

        logits = self.layer(x.view(-1, x.shape[-1]))
        _save_router_logits(logits, self.args)
        scores = logits.softmax(dim=-1)
        expert_weights, expert_indices = self._top_k(scores)
        if self.args.moe_normalize_expert_weights:
            expert_weights = expert_weights / torch.norm(
                expert_weights,
                p=self.args.moe_normalize_expert_weights,
                dim=-1,
                keepdim=True,
            )

        expert_indices = (
            _uniform_expert_assignment(
                expert_indices,
                self.args.moe_num_experts,
            ) if self.args.uniform_expert_assignment else expert_indices
        )
        return scores, expert_weights, expert_indices
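A standalone sketch of the routing math in LearnedRouter.forward, with random logits standing in for the learned linear layer; the token count, expert count, and top-k value below are illustrative and not taken from the commit.

import torch

tokens, num_experts, top_k = 4, 8, 2
logits = torch.randn(tokens, num_experts)          # stands in for self.layer(x)
scores = logits.softmax(dim=-1)                    # [tokens, num_experts]
expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)
# expert_weights: [tokens, top_k] probabilities of the selected experts.
# expert_indices: [tokens, top_k] expert ids, later flattened and sorted by
# ParallelMLP.indices_and_bins to build the token permutation.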
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_layers/sharedexpert_registry.py
    ADDED
    
@@ -0,0 +1,32 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Union

# from megablocks.layers import glu, mlp
# from megablocks.layers.arguments import Arguments
from . import glu, mlp
from .arguments import Arguments

_REGISTRY = {
    'mlp': mlp.SharedMLP,
    'glu': glu.SharedGLU,
}


def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
    """Returns a SharedMLP for use in a dMoE instance.

    Uses the provided arguments to instantiate the appropriate
    SharedMLP instance.

    Args:
        args: propagated Arguments dataclass.

    Returns:
        An instantiated SharedMLP constructed using the input args.
    """
    if args.mlp_type not in _REGISTRY:
        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')

    return _REGISTRY[args.mlp_type](args)
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_3bdb4b8_dirty.abi3.so
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f00f02cb159ccecc961af4ceab76fbebd06b61569f8b109a1c63cbcf9cf4a02
size 10513752
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py
    ADDED
    
@@ -0,0 +1,9 @@
import torch
from . import _megablocks_3bdb4b8_dirty
ops = torch.ops._megablocks_3bdb4b8_dirty

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_megablocks_3bdb4b8_dirty::{op_name}"
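add_op_namespace_prefix only qualifies an op name with the extension's torch.ops namespace, per the f-string above. A tiny illustrative check (the op name "my_op" is hypothetical, and the import assumes the bundled .so is loadable):

from megablocks._ops import add_op_namespace_prefix

# Qualify an arbitrary op name with the extension's namespace.
assert add_op_namespace_prefix("my_op") == "_megablocks_3bdb4b8_dirty::my_op"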
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/_version.py
    ADDED
    
@@ -0,0 +1,6 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

"""The MegaBlocks Version."""

__version__ = '0.11.0.dev0'
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/__init__.py
    ADDED
    
@@ -0,0 +1,2 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/backend/kernels.py
    ADDED
    
@@ -0,0 +1,557 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import triton
import triton.language as tl

# Stub triton autotune when testing in an env that does not have CUDA.
# This approach preserves the original code but enables testing without a GPU.
if torch.cuda.is_available() is False:
    import warnings

    warnings.warn("CUDA is not available. Triton autotuning is disabled.")

    def _no_autotune(*args, **kwargs):
        def deco(fn):
            return fn
        return deco

    triton.autotune = _no_autotune


def assert_is_tensor(x, ndim):
    if x.ndim != ndim:
        raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


def assert_is_matrix(x):
    assert_is_tensor(x, 2)


def assert_is_vector(x):
    if x.ndim != 1:
        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


def assert_equal(a, b):
    if a != b:
        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)


# a: (tokens, hidden_size), real.
# indices: (tokens * top_k), integer.
# bin_ids: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
# padded_bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy(
    a,
    b,
    indices,
    bin_ids,
    weights,
    bins,
    padded_bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # Our index into array 'a'.
    index_a = tl.load(indices + tl.program_id(0))

    # One threadblock per row in 'a'. Array 'b' has greater or equal
    # number of rows since they could be padded.
    bin_idx = tl.load(bin_ids + tl.program_id(0))

    # Now we know what bin we're assigned to, but we need to know how
    # many threadblocks were assigned to earlier bins so we can offset
    # in our bin properly.
    offset_in_bin = tl.program_id(0)
    if bin_idx > 0:
        offset_in_bin -= tl.load(bins + bin_idx - 1)

    # Load the starting index of our bin in array 'b'.
    index_b = offset_in_bin
    if bin_idx > 0:
        index_b += tl.load(padded_bins + bin_idx - 1)

    # Offset the input and output pointers.
    #
    # If we're going from A to B, divide the input index to copy
    # the same input repeatedly. If we're going from B to A we
    # need to reduce the result. Using atomics is slow, so we
    # do the reduce step in a second kernel.
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the direction.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)

        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)

        offsets += BLOCK_X


def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
    assert_equal(bins.size(), padded_bins.size())

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: Because of the padding, the output size is dynamic.
    # We load the final padded bin bound to get the output rows.
    output_rows = padded_bins[-1].cpu().item()
    out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        x,
        out,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out
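
# Not part of kernels.py: a minimal sketch of how the routing metadata consumed by
# padded_gather is typically built from the router's expert assignments. The helper
# name, the 'blocking' default, and the use of torch.sort/bincount here are
# illustrative assumptions, not megablocks API.
def _routing_metadata_sketch(top_expert, num_experts, blocking=128):
    # top_expert: (tokens * top_k,) int tensor, the expert id for each (token, k) pair.
    bin_ids, indices = torch.sort(top_expert)                       # expert-contiguous order
    tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0).int()             # inclusive end offset per expert
    # Round each expert's row count up to 'blocking' so every expert's slice of the
    # permuted activations starts on a block boundary; these padded offsets are what
    # 'padded_bins' encodes.
    padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking
    padded_bins = torch.cumsum(padded, dim=0).int()
    return indices.int(), bin_ids.int(), bins, padded_bins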


def gather(x, indices, bin_ids, weights, bins, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: There is no padding so the output rows equal the
    # input rows multiplied by top_k.
    output_rows = x.shape[0] * top_k
    out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        x,
        out,
        indices,
        bin_ids,
        weights,
        bins,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out


def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], bin_ids.shape[0])
    assert_equal(bins.size(), padded_bins.size())

    if weights is not None:
        assert_equal(indices.shape[0], weights.shape[0])

    tokens = indices.shape[0] // top_k
    out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        out,
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=False,
        TOP_K=top_k,
        SCALE=weights is not None,
    )

    # Reduce along the top-k dimension, if needed.
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


def scatter(x, indices, bin_ids, weights, bins, top_k):
    return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
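
# Not part of kernels.py: a pure-PyTorch reference for the unpadded gather/scatter
# pair above (padded_bins == bins), handy for sanity-checking the Triton kernels.
# The *_ref names are illustrative; bin_ids/bins only matter for the padded layout,
# so the references ignore them.
def _gather_ref(x, indices, bin_ids, weights, bins, top_k):
    # Output row i is x[indices[i] // top_k], optionally scaled by weights[indices[i]].
    out = x[indices.long() // top_k]
    if weights is not None:
        out = out * weights[indices.long()].unsqueeze(1)
    return out

def _scatter_ref(x, indices, bin_ids, weights, bins, top_k):
    # Inverse permutation of _gather_ref: route expert-ordered rows back to their
    # (token, k) positions and sum the top_k contributions per token.
    tokens = indices.shape[0] // top_k
    flat = torch.zeros(tokens * top_k, x.shape[1], dtype=x.dtype, device=x.device)
    scaled = x if weights is None else x * weights[indices.long()].unsqueeze(1)
    flat[indices.long()] = scaled
    return flat.view(tokens, top_k, x.shape[1]).sum(dim=1)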


# x: (tokens, top_k, hidden_size), real
# grad: (tokens, hidden_size), real.
# wgrad: (tokens, top_k), real.
# indices: (tokens * top_k), integer.
# bin_ids: (tokens * top_k), integer.
# bins: (num_experts), integer.
# padded_bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy_wgrad(
    x,
    grad,
    wgrad,
    indices,
    bin_ids,
    bins,
    padded_bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
):
    # Our index into 'tokens * top_k'.
    index_out = tl.load(indices + tl.program_id(0))

    # One threadblock per row in 'a'. Array 'b' has greater or equal
    # number of rows since they could be padded.
    bin_idx = tl.load(bin_ids + tl.program_id(0))

    # Now we know what bin we're assigned to, but we need to know how
    # many threadblocks were assigned to earlier bins so we can offset
    # in our bin properly.
    offset_in_bin = tl.program_id(0)
    if bin_idx > 0:
        offset_in_bin -= tl.load(bins + bin_idx - 1)

    # Load the starting index of our bin in array 'x'.
    index_x = offset_in_bin
    if bin_idx > 0:
        index_x += tl.load(padded_bins + bin_idx - 1)

    # Offset the input and output pointers.
    wgrad += index_out
    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        data = tl.load(x + offsets, mask=mask).to(tl.float32)
        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
        acc += data * scale
        offsets += BLOCK_X

    # Reduce to get the final result and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
    tl.store(wgrad, out)


def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_matrix(grad)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], bin_ids.shape[0])
    assert_equal(bins.size(), padded_bins.size())

    tokens = indices.shape[0] // top_k
    out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
    _padded_copy_wgrad[(indices.shape[0],)](
        x,
        grad,
        out,
        indices,
        bin_ids,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        TOP_K=top_k,
    )
    return out


def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
    return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
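
# Not part of kernels.py: what the wgrad kernels compute, written out in plain
# PyTorch for the unpadded case (scatter_wgrad). Each routing weight scales one
# expert-output row, so its gradient is the dot product of that row with the
# incoming gradient of its token. The _ref name is illustrative.
def _scatter_wgrad_ref(x, grad, indices, bin_ids, bins, top_k):
    # x:    (tokens * top_k, hidden) expert outputs in expert-contiguous order.
    # grad: (tokens, hidden) gradient w.r.t. the combined output.
    wgrad = torch.empty(indices.shape[0], dtype=x.dtype, device=x.device)
    wgrad[indices.long()] = (x * grad[indices.long() // top_k]).sum(dim=1)
    return wgrad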


# a: (tokens, hidden_size), real.
# b: (num_experts, expert_capacity, num_columns), real.
# indices: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy(
    a,
    b,
    num_experts,
    expert_capacity,
    indices,
    weights,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # Load our indices into the output.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Calculate our offset into the output.
    index_b = expert_idx * expert_capacity + entry_idx

    # Load the index bounds for our bin and calculate
    # the number of tokens assigned to our expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Calculate our offset into the input. If we don't
    # have an input exit early.
    if entry_idx >= num_tokens:
        return
    index_a = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers.
    #
    # If we're going from A to B, divide the input index to copy
    # the same input repeatedly. If we're going from B to A we
    # need to reduce the result. Using atomics is slow, so we
    # do the reduce step in a second kernel.
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the direction.
    #
    # NOTE: We need to zero the output in both directions.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)

        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)

        offsets += BLOCK_X


def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
    # Validate the input shapes.
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    num_experts = bins.shape[0]
    out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

    _binned_copy[(num_experts, expert_capacity)](
        x,
        out,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out


def binned_scatter(x, indices, weights, bins, top_k):
    # Validate the input shapes.
    assert_is_tensor(x, 3)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(bins.shape[0], x.shape[0])

    if weights is not None:
        assert_equal(indices.shape[0], weights.shape[0])

    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
    _binned_copy[(num_experts, expert_capacity)](
        out,
        x,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=hidden_size,
        A_TO_B=False,
        TOP_K=top_k,
        SCALE=weights is not None,
    )

    # Reduce along the top-k dimension, if needed.
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
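
# Not part of kernels.py: the capacity-based ("binned") path in plain PyTorch.
# Unlike the padded path, tokens routed past an expert's capacity are dropped
# (their rows simply stay zero in the output). The _ref name is illustrative.
def _binned_gather_ref(x, indices, weights, bins, expert_capacity, top_k):
    num_experts = bins.shape[0]
    out = torch.zeros(num_experts, expert_capacity, x.shape[1], dtype=x.dtype, device=x.device)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        take = min(end - start, expert_capacity)            # drop overflow tokens
        idx = indices[start:start + take].long()
        rows = x[idx // top_k]
        if weights is not None:
            rows = rows * weights[idx].unsqueeze(1)
        out[e, :take] = rows
        start = end
    return out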


# a: (tokens, hidden_size), real.
# b: (num_experts, expert_capacity, num_columns), real.
# indices: (tokens * top_k), integer.
# weights: (tokens * top_k), real.
# bins: (num_experts), integer.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy_wgrad(
    x,
    grad,
    wgrad,
    num_experts,
    expert_capacity,
    indices,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
):
    # Load our indices into the output.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Calculate our offset into the output.
    index_x = expert_idx * expert_capacity + entry_idx

    # Load the index bounds for our bin and calculate
    # the number of tokens assigned to our expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Calculate our offset into the input. If we don't
    # have an input exit early.
    if entry_idx >= num_tokens:
        return
    index_out = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers.
    wgrad += index_out
    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        data = tl.load(x + offsets, mask=mask).to(tl.float32)
        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
        acc += data * scale
        offsets += BLOCK_X

    # Reduce to get the final result and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
    tl.store(wgrad, out)


def binned_scatter_wgrad(x, grad, indices, bins, top_k):
    # Validate the input shapes.
    assert_is_tensor(x, 3)
    assert_is_matrix(grad)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(bins.shape[0], x.shape[0])

    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
    _binned_copy_wgrad[(num_experts, expert_capacity)](
        x,
        grad,
        out,
        num_experts,
        expert_capacity,
        indices,
        bins,
        NUM_COLUMNS=hidden_size,
        TOP_K=top_k,
    )
    return out
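
Taken together, the padded path implements the dropless-MoE permutation: gather tokens into padded, expert-contiguous order, run each expert over its slice, then scatter the results back and combine the top_k copies of each token. A minimal sketch of that flow, assuming a CUDA device (the Triton kernels cannot run on CPU), metadata built as in the earlier routing-metadata sketch, and a placeholder in place of the real expert MLPs:

import torch
from megablocks.backend import kernels  # assumed import path for this module

def moe_mix_sketch(x, indices, bin_ids, expert_weights, bins, padded_bins, top_k):
    # Permute tokens into padded, expert-contiguous order.
    x_expert = kernels.padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k)
    y_expert = 2.0 * x_expert  # stand-in for the per-expert computation
    # Un-permute, scale each copy by its router weight, and sum over top_k.
    return kernels.padded_scatter(y_expert, indices, bin_ids, expert_weights, bins, padded_bins, top_k)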
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/bak.__init__.py
    ADDED
    
@@ -0,0 +1,23 @@
from megablocks_moe.megablocks import (
    MoE,
    dMoE,
    get_load_balancing_loss,
    ParallelMLP,
    ParallelDroplessMLP,
    SparseMLP,
    MLP,
    SparseGLU,
    Arguments,
)

__all__ = [
    "MoE",
    "dMoE",
    "get_load_balancing_loss",
    "ParallelMLP",
    "ParallelDroplessMLP",
    "SparseMLP",
    "MLP",
    "SparseGLU",
    "Arguments",
]
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/benchmark_util.py
    ADDED
    
@@ -0,0 +1,35 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch


def log_benchmark(name, arguments, time, std):
    print('=' * 60)
    print(f'{name} Benchmark')
    print('Benchmark Parameters:')
    for (key, value) in arguments.items():
        print(f'{key} = {value}')
    print('Results:')
    print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
    print('=' * 60)


def benchmark_function(fn, iterations=100, warmup=10):
    # Warmup iterations.
    for _ in range(warmup):
        fn()

    times = []
    for i in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        fn()
        end.record()

        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return np.mean(times), np.std(times)
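
For reference, a minimal way to drive these helpers; the matmul workload and sizes are arbitrary, the import path assumes the built package is importable as megablocks, and a CUDA device is required since the timing uses CUDA events:

import torch
from megablocks import benchmark_util  # assumed import path

a = torch.randn(4096, 4096, device='cuda')
b = torch.randn(4096, 4096, device='cuda')
mean_ms, std_ms = benchmark_util.benchmark_function(lambda: a @ b)
benchmark_util.log_benchmark('Dense GEMM', {'m': 4096, 'n': 4096, 'k': 4096}, mean_ms, std_ms)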
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/__init__.py
    ADDED
    
@@ -0,0 +1,2 @@
from . import ops
from . import backend
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/backend.py
    ADDED
    
@@ -0,0 +1,33 @@
# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# # TODO(tgale): Wrap this in a try-block with better
# # error message and instructions for building the
# # c++ operations.
# import grouped_gemm_backend as backend

# We import the backend operations from the megablocks package as
# grouped_gemm is vendored in megablocks in this repository.
# from ... import _ops as backend
# from megablocks._ops import ops as backend  # type: ignore
from .._ops import ops as backend  # type: ignore

def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
    assert not (trans_a and trans_b)
    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
    assert a.ndim == 2, "Expected 2d tensor for 'a'"
    assert b.ndim == (2 if trans_a else 3)

    shape = (
        (batch_sizes.shape[0], a.shape[1], b.shape[1])
        if trans_a else
        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
    )
    return torch.empty(*shape, device=a.device, dtype=a.dtype)

def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
    if c is None:
        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
    return c
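
The shape logic in _allocate_output mirrors the grouped-GEMM semantics: in the default layout, group i multiplies its batch_sizes[i]-row slice of 'a' by b[i], so the results stack to (sum(batch_sizes), n); with trans_a=True the per-group products a_i^T @ b_i come back as a (num_groups, k, n) tensor. A shape-only illustration of the default case in plain PyTorch (numbers are arbitrary; it does not call the compiled backend):

import torch

batch_sizes = torch.tensor([3, 5, 2])      # rows of 'a' owned by each group
a = torch.randn(10, 16)                    # 10 == batch_sizes.sum()
b = torch.randn(3, 16, 32)                 # one (k, n) operand per group
# Equivalent dense computation for gmm(a, b, batch_sizes) with trans_a=trans_b=False:
chunks = torch.split(a, batch_sizes.tolist())
out = torch.cat([a_i @ b_i for a_i, b_i in zip(chunks, b)])
assert out.shape == (10, 32)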
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm/ops.py
    ADDED
    
@@ -0,0 +1,33 @@
from . import backend
import torch


class GroupedGemm(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, b, batch_sizes, trans_b):
        ctx.save_for_backward(a, b, batch_sizes)
        ctx.trans_b = trans_b
        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)

    @staticmethod
    def backward(ctx, grad):
        grad = grad.contiguous()
        a, b, batch_sizes = ctx.saved_tensors
        trans_b = ctx.trans_b

        agrad = None
        if ctx.needs_input_grad[0]:
            agrad = backend.gmm(
                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)

        bgrad = None
        if ctx.needs_input_grad[1]:
            lhs, rhs = (grad, a) if trans_b else (a, grad)
            bgrad = backend.gmm(
                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
        return agrad, bgrad, None, None


def gmm(a, b, batch_sizes, trans_b=False):
    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
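
A small autograd-level sketch of how this wrapper is used; it assumes a CUDA device, bf16 inputs, and that the compiled _ops extension backing backend.gmm is available (the import path is an assumption based on where the module is vendored):

import torch
from megablocks.grouped_gemm import ops as gg_ops  # assumed import path

batch_sizes = torch.tensor([3, 5, 2])  # CPU int64 group sizes
a = torch.randn(10, 16, device='cuda', dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(3, 16, 32, device='cuda', dtype=torch.bfloat16, requires_grad=True)
out = gg_ops.gmm(a, b, batch_sizes)    # (10, 32): one GEMM per group
out.sum().backward()                   # backward() reuses gmm for both input grads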
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/grouped_gemm_util.py
    ADDED
    
@@ -0,0 +1,31 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import warnings

_grouped_gemm_is_available: bool = False
try:
    # import grouped_gemm
    pass
    _grouped_gemm_is_available = True
except ImportError as error:
    warnings.warn('Grouped GEMM not available.')


def grouped_gemm_is_available():
    return _grouped_gemm_is_available


def assert_grouped_gemm_is_available():
    msg = (
        'Grouped GEMM not available. Please run '
        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
    )
    assert _grouped_gemm_is_available, msg


# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
# ops = grouped_gemm.ops if grouped_gemm_is_available() else None


from .grouped_gemm import backend as ops
from .grouped_gemm import ops as backend
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py
    ADDED
    
@@ -0,0 +1,1001 @@
import torch
import torch.distributed as dist

from typing import Optional, Any

from . import _layers
from . import ops


# Set the expert model parallel attributes on a tensor
def set_expert_model_parallel_attributes(
    tensor: torch.Tensor,
    is_parallel: bool,
):
    assert not hasattr(tensor, "expert_model_parallel")
    setattr(tensor, "expert_model_parallel", is_parallel)


# Calculate the expert sharding degree from the world size and number of experts
def expert_sharding_degree(
    world_size: int,
    moe_num_experts: int,
) -> int:
    esd = min(world_size, moe_num_experts)
    if (moe_num_experts % esd) != 0:
        raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
    return esd


# Calculate the hidden sharding degree based on world size and expert sharding degree
def hidden_sharding_degree(
    world_size: int,
    moe_num_experts: int,
    ffn_hidden_size: int,
) -> int:
    esd = expert_sharding_degree(world_size, moe_num_experts)
    hsd = world_size // esd
    if (ffn_hidden_size % hsd) != 0:
        raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
    if (esd * hsd) != world_size:
        raise ValueError(
            f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
        )
    return hsd


# Calculate the number of experts per rank based on world size and expert sharding degree
def experts_per_rank(
    moe_num_experts: int,
    world_size: int,
) -> int:
    return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)


# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
def features_per_rank(
    ffn_hidden_size: int, world_size: int, moe_num_experts: int
) -> int:
    return ffn_hidden_size // hidden_sharding_degree(
        world_size, moe_num_experts, ffn_hidden_size
    )
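

# Worked example of the sharding arithmetic above (illustrative numbers only):
# with world_size=8, moe_num_experts=4 and ffn_hidden_size=1024,
#   expert_sharding_degree(8, 4)        -> min(8, 4) = 4
#   hidden_sharding_degree(8, 4, 1024)  -> 8 // 4 = 2   (and 4 * 2 == 8)
#   experts_per_rank(4, 8)              -> 4 // 4 = 1
#   features_per_rank(1024, 8, 4)       -> 1024 // 2 = 512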


# Apply jitter to the input tensor
def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
    low = 1.0 - moe_jitter_eps
    high = 1.0 + moe_jitter_eps
    noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
    return x * (low + noise * (high - low))


# Compute the top-k scores from the logits
def compute_top_k(scores: torch.Tensor, moe_top_k: int):
    if moe_top_k == 1:
        return scores.max(dim=-1, keepdim=True)
    return torch.topk(scores, moe_top_k, dim=-1)


# Route tokens to experts and compute expert weights and indices
def route_tokens(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: float = None,
    moe_normalize_expert_weights: int = None,
    uniform_expert_assignment: bool = False,
    training: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if training and moe_jitter_eps is not None:
        x = apply_jitter(x, moe_jitter_eps)

    x_flat = x.view(-1, x.shape[-1])
    logits = torch.nn.functional.linear(x_flat, router_weight)
    expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
    expert_weights = expert_weights.softmax(dim=-1)
    if moe_normalize_expert_weights is not None:
        expert_weights = expert_weights / torch.norm(
            expert_weights,
            p=moe_normalize_expert_weights,
            dim=-1,
            keepdim=True,
        )
    if uniform_expert_assignment:
        expert_indices = _layers.router._uniform_expert_assignment(
            expert_indices,
            moe_num_experts,
        )

    return logits, expert_weights, expert_indices
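

# A minimal sketch (hypothetical helper, toy sizes) of how route_tokens is typically
# called; the random tensors below are placeholders, not real model state.
def _example_route_tokens() -> None:
    hidden_size, num_experts, top_k = 16, 4, 2
    x = torch.randn(3, 5, hidden_size)  # [sl, bs, hs]
    router_weight = torch.randn(num_experts, hidden_size)
    logits, weights, indices = route_tokens(
        x,
        router_weight,
        moe_top_k=top_k,
        moe_num_experts=num_experts,
    )
    # Routing flattens tokens, so the leading dimension is sl * bs.
    assert logits.shape == (3 * 5, num_experts)
    assert weights.shape == indices.shape == (3 * 5, top_k)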


# Scale the gradient of the weights
def scale_grad(
    w: torch.Tensor,
    gradient_scale: Optional[float] = None,
) -> torch.Tensor:
    if gradient_scale is None:
        return w
    return _layers.mlp.scale_gradient(w, gradient_scale)


# Forward pass for the MLP layer
def mlp_forward(
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: torch.Tensor,
    w2_bias: torch.Tensor,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
):
    # Scale weights
    w1 = scale_grad(w1, gradient_scale)
    w2 = scale_grad(w2, gradient_scale)
    w1_bias = scale_grad(w1_bias, gradient_scale)
    w2_bias = scale_grad(w2_bias, gradient_scale)

    # Resolve dtensors
    w1 = _layers.mlp.resolve_dtensor(w1)
    w2 = _layers.mlp.resolve_dtensor(w2)
    w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
    w2_bias = _layers.mlp.resolve_dtensor(w2_bias)

    # Forward pass
    gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
    gate, up = gate_up.chunk(2, dim=-1)

    glu = gate * torch.sigmoid(gate * alpha)
    x = (up + 1) * glu

    return torch.bmm(x, w2) + w2_bias[..., None, :]
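

# Shape walkthrough for mlp_forward (illustrative; inferred from the bmm/broadcast
# pattern above rather than stated anywhere in this file): with E experts, expert
# capacity C, hidden size H and FFN size F,
#   x        [E, C, H],   w1 [E, H, 2*F],  w1_bias [E, 2*F]
#   gate_up  [E, C, 2*F]  -> gate, up each [E, C, F]
#   output   [E, C, H]    with w2 [E, F, H], w2_bias [E, H]
# The gating `gate * sigmoid(alpha * gate)` with alpha=1.702 is the sigmoid
# approximation of GELU applied to the gate branch.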


# Shared expert MLP forward pass
def shared_mlp_forward(
    x: torch.Tensor,
    up_proj_weight: torch.Tensor,
    down_proj_weight: torch.Tensor,
    up_proj_bias: Optional[torch.Tensor] = None,
    down_proj_bias: Optional[torch.Tensor] = None,
    activation_fn: Optional[Any] = None,
    gradient_scale: Optional[float] = None,
) -> torch.Tensor:
    # Default activation function
    if activation_fn is None:
        activation_fn = torch.nn.functional.gelu

    # Scale weights
    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
    if up_proj_bias is not None:
        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
    if down_proj_bias is not None:
        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)

    # Resolve dtensors
    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
    if up_proj_bias is not None:
        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
    if down_proj_bias is not None:
        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)

    # Up projection
    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)

    # Activation
    x = activation_fn(x)

    # Down projection
    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)

    return x


# Combine outputs from shared expert and regular experts
def combine_expert_shared_outputs(
    shared_expert_out: torch.Tensor,
    expert_out: torch.Tensor,
    shared_expert_weighted_sum: bool = False,
    moe_top_k: int = 1,
) -> torch.Tensor:
    if shared_expert_weighted_sum:
        # Weighted sum based on number of experts used
        total_experts = moe_top_k + 1
        shared_weight = 1.0 / total_experts
        expert_weight = moe_top_k / total_experts
        return shared_expert_out * shared_weight + expert_out * expert_weight
    else:
        # Simple addition
        return shared_expert_out + expert_out
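

# A quick check (toy tensors, hypothetical helper) of the shared/routed combination:
# with moe_top_k=3 the weighted sum uses 1/4 for the shared expert and 3/4 for the
# routed experts; otherwise the two outputs are simply added.
def _example_combine_outputs() -> None:
    shared = torch.ones(2, 4)
    routed = torch.full((2, 4), 2.0)
    weighted = combine_expert_shared_outputs(
        shared, routed, shared_expert_weighted_sum=True, moe_top_k=3
    )
    assert torch.allclose(weighted, torch.full((2, 4), 1.0 * 0.25 + 2.0 * 0.75))
    assert torch.equal(
        combine_expert_shared_outputs(shared, routed), torch.full((2, 4), 3.0)
    )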


# Global variable to store load balancing loss
_LOAD_BALANCING_LOSS = []


def save_load_balancing_loss(loss):
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.append(loss)


def get_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    return _LOAD_BALANCING_LOSS


def clear_load_balancing_loss():
    global _LOAD_BALANCING_LOSS
    _LOAD_BALANCING_LOSS.clear()


def batched_load_balancing_loss(args):
    if args.moe_loss_weight == 0:
        return 0.0

    tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
    num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
    if args.num_layers_per_virtual_pipeline_stage is not None:
        num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage

    if len(tokens_per_expert) != num_layers_per_pipeline_stage:
        raise ValueError(
            f"Expected {num_layers_per_pipeline_stage} tokens_per_expert "
            f"but found {len(tokens_per_expert)}.\nnum_layers = "
            f"{args.num_layers}\npipeline_model_parallel_size = "
            f"{args.pipeline_model_parallel_size}\n"
            "num_layers_per_virtual_pipeline_stage"
            f" = {args.num_layers_per_virtual_pipeline_stage}",
        )
    if len(expert_scores) != num_layers_per_pipeline_stage:
        raise ValueError(
            f"Expected {num_layers_per_pipeline_stage} expert_scores "
            f"but found {len(expert_scores)}.\nnum_layers = "
            f"{args.num_layers}\npipeline_model_parallel_size = "
            f"{args.pipeline_model_parallel_size}\n"
            "num_layers_per_virtual_pipeline_stage"
            f" = {args.num_layers_per_virtual_pipeline_stage}",
        )

    # Verify the shape of the tokens_per_expert and expert_scores tensors.
    assert all(
        (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
    )

    tokens = expert_scores[0].shape[0]
    assert all(
        (
            (
                x.ndim == 2
                and x.shape[1] == args.moe_num_experts
                and x.shape[0] == tokens
            )
            for x in expert_scores
        )
    )

    # Concatenate the contributions of each layer and convert to
    # the correct types and formats for the dot product.
    expert_scores = torch.cat(expert_scores, dim=1)
    if args.moe_lbl_in_fp32:
        expert_scores = expert_scores.float()
    if tokens != 0:
        expert_scores = expert_scores.mean(dim=0)
    else:
        expert_scores = expert_scores.sum(dim=0)
    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)

    expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
    assert tokens_per_expert.numel() == expected_values
    assert expert_scores.numel() == expected_values

    # Calculate the total scale across all factors.
    #
    # loss_weight * num_experts / (num_layers * tokens * top_k)
    scale_numerator = args.moe_num_experts * args.moe_loss_weight
    scale_denominator = args.num_layers * tokens * args.moe_top_k
    scale = scale_numerator / scale_denominator
    return scale * torch.dot(tokens_per_expert, expert_scores)


# Calculate the expert capacity based on tokens, top_k, number of experts,
# expert parallel group, capacity factor, and whether expert model parallelism is used.
def expert_capacity(
    tokens: int,
    top_k: int,
    num_experts: int,
    expert_parallel_group: int,
    moe_capacity_factor: float,
    moe_expert_model_parallelism: bool,
) -> int:
    world_size = (
        dist.get_world_size(expert_parallel_group)
        if moe_expert_model_parallelism
        else 1
    )

    tokens_per_expert = top_k * tokens * world_size / num_experts
    return int(moe_capacity_factor * tokens_per_expert)


def load_balancing_loss(
    tokens_per_expert: torch.Tensor,
    expert_scores: torch.Tensor,
    top_k: int,
    num_experts: int,
):
    assert len(expert_scores.size()) == 2
    tokens, scores_num_experts = expert_scores.size()
    assert scores_num_experts == num_experts
    assert len(tokens_per_expert.size()) == 1
    (counts_num_experts,) = tokens_per_expert.size()
    assert counts_num_experts == num_experts
    scale = num_experts / (tokens * top_k)
    return scale * torch.dot(
        tokens_per_expert.to(expert_scores.dtype),
        expert_scores.mean(dim=0),
    )


def indices_and_bins(
    top_expert: torch.Tensor,
    sort_end_bit: int,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    top_expert = top_expert.int()

    # Ensure contiguous memory layout
    top_expert = top_expert.contiguous()

    # Ensure CUB knows which device to use
    with torch.cuda.device(top_expert.device):
        output = ops.sort(top_expert, sort_end_bit)
        bin_ids, indices = output
        tokens_per_expert = ops.histogram(top_expert, num_experts)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    bins = bins.view(1) if not len(bins.size()) else bins
    return indices, bin_ids, bins, tokens_per_expert
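

# A minimal sketch (toy values, hypothetical helper) of the auxiliary loss above:
# perfectly uniform routing makes num_experts / (tokens * top_k) multiplied by
# dot(tokens_per_expert, mean_scores) come out to exactly 1.0.
def _example_load_balancing_loss() -> None:
    num_experts, tokens, top_k = 4, 8, 2
    expert_scores = torch.full((tokens, num_experts), 1.0 / num_experts)
    tokens_per_expert = torch.full((num_experts,), tokens * top_k / num_experts)
    loss = load_balancing_loss(tokens_per_expert, expert_scores, top_k, num_experts)
    assert torch.isclose(loss, torch.tensor(1.0))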


def expert_capacity_fn(
    tokens: int,
    top_k: int,
    num_experts: int,
    expert_parallel_group: torch.distributed.ProcessGroup,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
) -> int:
    world_size = (
        dist.get_world_size(expert_parallel_group)
        if moe_expert_model_parallelism
        else 1
    )
    tokens_per_expert = top_k * tokens * world_size / num_experts
    return int(moe_capacity_factor * tokens_per_expert)
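

# Worked example of the capacity formula above (illustrative numbers only): with
# tokens=1024, top_k=2, num_experts=8, moe_capacity_factor=1.25 and a single rank
# (world_size=1), tokens_per_expert = 2 * 1024 * 1 / 8 = 256 and the returned
# capacity is int(1.25 * 256) = 320 tokens per expert.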


def permute_and_compute(
    x,
    tokens_per_expert,
    indices,
    bin_ids,
    expert_weights,
    bins,
    expert_capacity,
    top_k,
    w1,
    w2,
    w1_bias,
    w2_bias,
    gradient_scale,
    alpha,
):
    # Flatten the tokens before gathering them into per-expert bins
    x = x.view(-1, x.shape[-1])

    # Ensure CUB knows which device to use
    with torch.cuda.device(x.device):
        x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)

    # Expert computation
    x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)

    # Ensure CUB knows which device to use
    with torch.cuda.device(x.device):
        # Route tokens back
        out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
    return out


def forward_once(
    x: torch.Tensor,
    expert_weights: torch.Tensor,
    top_experts: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: torch.Tensor,
    w2_bias: torch.Tensor,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    top_k: int = 4,
    num_experts: int = 128,
    expert_parallel_group: int = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
    mlp_impl: Optional[str] = None,
):
    # x: [sl, bs, hs]
    # expert_weights: [sl * bs, top-k]
    # top_experts: [sl * bs, top-k]
    expert_weights = expert_weights.flatten()
    top_experts = top_experts.flatten()

    with torch.no_grad():
        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
            top_experts, sort_end_bit, num_experts
        )

        # Calculate expert capacity
        sl, bs, _ = x.size()

        expert_capacity = expert_capacity_fn(
            sl * bs,
            top_k,
            num_experts,
            expert_parallel_group,
            moe_capacity_factor,
            moe_expert_model_parallelism,
        )

        if expert_capacity == 0:
            expert_capacity = torch.max(tokens_per_expert).item()

    x = permute_and_compute(
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capacity,
        top_k,
        w1,
        w2,
        w1_bias,
        w2_bias,
        gradient_scale,
        alpha,
    )
    return x, tokens_per_expert


def parallel_forward_once(
    x: torch.Tensor,
    expert_weights: torch.Tensor,
    top_experts: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_bias: torch.Tensor,
    w2_bias: torch.Tensor,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    top_k: int = 4,
    num_experts: int = 128,
    expert_parallel_group: torch.distributed.ProcessGroup = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = True,
    hidden_size: int = 1152,
    mlp_impl: Optional[str] = "grouped",
):
    # Flatten inputs
    expert_weights = expert_weights.flatten()
    top_experts = top_experts.flatten()

    # TODO: remove debugging var
    # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0

    with torch.no_grad():
        # Step 1: Local permutation setup
        indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
            top_experts, sort_end_bit, num_experts
        )

        # Calculate sharding parameters
        world_size = dist.get_world_size(expert_parallel_group)
        hidden_sharding_deg = hidden_sharding_degree(
            world_size, num_experts, hidden_size
        )
        experts_per_rank_val = experts_per_rank(num_experts, world_size)

        # Replicate token counts for hidden sharding
        repeated_tokens_per_expert = ops.repeat(
            tokens_per_expert, (hidden_sharding_deg,)
        )

        # Exchange token counts across devices
        parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)

        # Launch the token-count exchange asynchronously
        tpe_handle = dist.all_to_all_single(
            parallel_tokens_per_expert,
            repeated_tokens_per_expert,
            group=expert_parallel_group,
            async_op=True,
        )

    # Step 2: Local permutation - group tokens by target device
    x = x.view(-1, x.shape[-1])  # [sl * bs, hs]
    x = ops.gather(x, indices, bin_ids, bins, top_k)

    # Step 3: Compute communication counts and exchange tokens
    with torch.no_grad():
        tpe_handle.wait()

        # Reshape for per-device calculations
        repeated_tokens_per_expert = repeated_tokens_per_expert.view(
            world_size, experts_per_rank_val
        )
        parallel_tokens_per_expert = parallel_tokens_per_expert.view(
            world_size, experts_per_rank_val
        )

        # Calculate send/recv counts
        send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
        # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
        parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
        recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
        tokens_received = sum(recv_counts)

    # Replicate for hidden sharding
    x = ops.repeat(x, (hidden_sharding_deg, 1))

    # Cross-device token exchange
    parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
        x, recv_counts, send_counts, expert_parallel_group, async_op=True
    )

    with torch.no_grad():
        # Step 4: Setup for local expert computation
        replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
        replicate_bins = (
            replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
        )

        # Create expert indices for received tokens
        parallel_top_expert = torch.remainder(
            torch.arange(
                num_experts * hidden_sharding_deg,
                dtype=torch.int32,
                device=indices.device,
            ),
            experts_per_rank_val,
        )
        parallel_top_expert = ops.replicate(
            parallel_top_expert.unsqueeze(dim=0),
            replicate_bins,
            tokens_received,
        ).flatten()

        # Sort tokens by expert assignment
        parallel_bin_ids, parallel_indices = ops.sort(
            parallel_top_expert,
            sort_end_bit,
        )

        # Calculate bins for local experts
        parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
            dim=0, dtype=torch.int
        )
        parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
        parallel_bins = (
            parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
        )

        # Calculate expert capacity
        expert_capacity = expert_capacity_fn(
            tokens_received,
            top_k,
            experts_per_rank_val,
            expert_parallel_group,
            moe_capacity_factor,
            moe_expert_model_parallelism,
        )
        if expert_capacity == 0:
            expert_capacity = torch.max(parallel_tokens_per_expert).item()

    # Locally permute the tokens and perform the expert computation.
    # Block to make sure that the cross-device permutation is complete.
    if mlp_impl == "grouped":
        # GroupedMLP requires counts on CPU. We can use the tensor already
        # moved to CPU for the prior all_to_all, which avoids an extra
        # device synchronization.
        parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
            dim=0,
            dtype=torch.int,
        )

    # Step 5: Expert computation
    parallel_x_handle.wait()

    parallel_x = permute_and_compute(
        parallel_x,
        parallel_tokens_per_expert,
        parallel_indices,
        parallel_bin_ids,
        None,  # expert_weights
        parallel_bins,
        expert_capacity,
        top_k=1,
        w1=w1,
        w2=w2,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        gradient_scale=gradient_scale,
        alpha=alpha,
    )

    # Step 6: Reverse communication - send results back
    x, _ = _layers.all_to_all.all_to_all(
        parallel_x, send_counts, recv_counts, expert_parallel_group
    )

    # Step 7: Reduce across hidden sharding dimension
    shape = (hidden_sharding_deg, -1, hidden_size)
    x = x.view(shape).sum(dim=0)

    # Step 8: Final local unpermutation
    x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)

    return x, tokens_per_expert.flatten()
| 655 | 
            +
             | 
| 656 | 
            +
             | 
| 657 | 
            +
def moe_forward(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: float = None,
    moe_normalize_expert_weights: int = None,
    uniform_expert_assignment: bool = False,
    training: bool = False,
    w1: torch.Tensor = None,
    w2: torch.Tensor = None,
    w1_bias: torch.Tensor = None,
    w2_bias: torch.Tensor = None,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    expert_parallel_group: torch.distributed.ProcessGroup = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
    forward_fn: Any = None,
    hidden_size: int = None,
    mlp_impl: str = "grouped",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    # Route tokens to experts
    logits, expert_weights, expert_indices = route_tokens(
        x,
        router_weight,
        moe_top_k,
        moe_num_experts,
        moe_jitter_eps,
        moe_normalize_expert_weights,
        uniform_expert_assignment,
        training,
    )

    # Create router scores for output
    router_scores = (
        torch.zeros_like(logits)
        .scatter_(1, expert_indices, expert_weights)
        .transpose(0, 1)
    )

    in_shape = x.size()

    # Prepare forward function arguments
    forward_args = {
        "x": x,
        "expert_weights": expert_weights,
        "top_experts": expert_indices,
        "w1": w1,
        "w2": w2,
        "w1_bias": w1_bias,
        "w2_bias": w2_bias,
        "gradient_scale": gradient_scale,
        "alpha": alpha,
        "sort_end_bit": sort_end_bit,
        "top_k": moe_top_k,
        "num_experts": moe_num_experts,
        "expert_parallel_group": expert_parallel_group,
        "moe_capacity_factor": moe_capacity_factor,
        "moe_expert_model_parallelism": moe_expert_model_parallelism,
        "mlp_impl": mlp_impl,
    }

    # Add hidden_size for parallel forward
    if moe_expert_model_parallelism and hidden_size is not None:
        forward_args["hidden_size"] = hidden_size
    elif moe_expert_model_parallelism and hidden_size is None:
        # Infer hidden_size from input shape
        forward_args["hidden_size"] = x.shape[-1]

    # Compute expert outputs
    x, tokens_per_expert = forward_fn(**forward_args)

    # Save load balancing loss if needed
    moe_loss_weight = 0.0  # Can be made configurable
    if training and moe_loss_weight > 0:
        save_load_balancing_loss((tokens_per_expert, logits))

    # Restore original shape
    x = x.view(in_shape)

    return x, expert_weights, router_scores

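The router_scores expression above densifies the sparse top-k routing weights: each token's weights are scattered back onto the full expert axis, and the result is transposed to (num_experts, num_tokens). A small self-contained sketch with invented shapes:

import torch

num_tokens, num_experts, top_k = 3, 4, 2
logits = torch.zeros(num_tokens, num_experts)
expert_indices = torch.tensor([[0, 2], [1, 3], [2, 0]])              # top-k expert ids per token
expert_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])  # matching routing weights

router_scores = (
    torch.zeros_like(logits)
    .scatter_(1, expert_indices, expert_weights)
    .transpose(0, 1)
)
# Shape (num_experts, num_tokens): column t holds token t's weight for each expert,
# zero for experts the token was not routed to.
print(router_scores)
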
def moe_forward_with_shared_expert(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: float = None,
    moe_normalize_expert_weights: int = None,
    uniform_expert_assignment: bool = False,
    training: bool = False,
    w1: torch.Tensor = None,
    w2: torch.Tensor = None,
    w1_bias: torch.Tensor = None,
    w2_bias: torch.Tensor = None,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    expert_parallel_group: torch.distributed.ProcessGroup = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
    forward_fn: Any = None,
    hidden_size: int = None,
    mlp_impl: str = "grouped",
    # Shared expert parameters
    shared_up_proj_weight: Optional[torch.Tensor] = None,
    shared_down_proj_weight: Optional[torch.Tensor] = None,
    shared_up_proj_bias: Optional[torch.Tensor] = None,
    shared_down_proj_bias: Optional[torch.Tensor] = None,
    shared_expert_weighted_sum: bool = False,
    shared_activation_fn: Optional[Any] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    # First, compute regular MoE forward pass
    expert_out, expert_weights, router_scores = moe_forward(
        x=x,
        router_weight=router_weight,
        moe_top_k=moe_top_k,
        moe_num_experts=moe_num_experts,
        moe_jitter_eps=moe_jitter_eps,
        moe_normalize_expert_weights=moe_normalize_expert_weights,
        uniform_expert_assignment=uniform_expert_assignment,
        training=training,
        w1=w1,
        w2=w2,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        gradient_scale=gradient_scale,
        alpha=alpha,
        sort_end_bit=sort_end_bit,
        expert_parallel_group=expert_parallel_group,
        moe_capacity_factor=moe_capacity_factor,
        moe_expert_model_parallelism=moe_expert_model_parallelism,
        forward_fn=forward_fn,
        hidden_size=hidden_size,
        mlp_impl=mlp_impl,
    )

    # If shared expert weights provided, compute shared expert output
    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
        shared_expert_out = shared_mlp_forward(
            x=x,
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            up_proj_bias=shared_up_proj_bias,
            down_proj_bias=shared_down_proj_bias,
            activation_fn=shared_activation_fn,
            gradient_scale=gradient_scale,
        )

        # Combine expert outputs
        combined_out = combine_expert_shared_outputs(
            shared_expert_out=shared_expert_out,
            expert_out=expert_out,
            shared_expert_weighted_sum=shared_expert_weighted_sum,
            moe_top_k=moe_top_k,
        )

        return combined_out, expert_weights, router_scores

    # Return regular MoE output if no shared expert
    return expert_out, expert_weights, router_scores

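shared_mlp_forward and combine_expert_shared_outputs are defined earlier in this file and are not shown in this hunk. As a rough illustration of what such a combine step typically does, here is a hypothetical stand-in; the function below and its weighted-sum rule are assumptions for illustration, not the library's implementation:

import torch

def combine_shared_and_routed(shared_expert_out, expert_out, weighted_sum, moe_top_k):
    # Hypothetical combine rule, for illustration only.
    if weighted_sum:
        # Treat the shared expert as one extra slot alongside the top-k routed experts.
        scale = 1.0 / (moe_top_k + 1)
        return scale * shared_expert_out + (1.0 - scale) * expert_out
    return shared_expert_out + expert_out

x = torch.randn(5, 8)
print(combine_shared_and_routed(x, x, weighted_sum=True, moe_top_k=4).shape)
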
def create_shared_expert_weights(
    hidden_size: int,
    shared_expert_hidden_size: int,
    device: torch.device,
    dtype: torch.dtype,
    init_method: Any,
    output_layer_init_method: Any = None,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:

    if output_layer_init_method is None:
        output_layer_init_method = init_method

    # Create weight tensors
    up_proj_weight = torch.empty(
        shared_expert_hidden_size,
        hidden_size,
        device=device,
        dtype=dtype,
    )
    down_proj_weight = torch.empty(
        hidden_size,
        shared_expert_hidden_size,
        device=device,
        dtype=dtype,
    )

    # Initialize weights
    init_method(up_proj_weight)
    output_layer_init_method(down_proj_weight)

    # No bias by default
    return up_proj_weight, down_proj_weight, None, None

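A quick shape check for the helper above; the sizes and init functions are arbitrary choices for illustration:

import torch

up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=1024,
    shared_expert_hidden_size=4096,
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.xavier_uniform_,
    output_layer_init_method=torch.nn.init.zeros_,
)
assert up_w.shape == (4096, 1024) and down_w.shape == (1024, 4096)
assert up_b is None and down_b is None  # no bias by default
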
# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
# This exists because device_mesh is trapped in hook closures with no model attribute
# Fragile - breaks if hook structure changes or Python internals change
# TODO: Replace with a more robust solution when available
def get_device_mesh(model):
    # Extract device_mesh from child's unused pre_hook closure
    try:
        # Find the pre-hook that contains 'device_mesh' in its closure
        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
        # Extract the device_mesh from the closure
        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
    except Exception:
        return None

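The closure lookup above relies only on standard CPython function internals: __code__.co_freevars and __closure__ are aligned, so a captured variable can be read back by name. A self-contained toy that mimics the situation with a stand-in hook (the hook and the mesh value below are fabricated for the demo, not the actual transformers hook):

import torch

def make_hook(device_mesh):
    def hook(module, args):
        _ = device_mesh  # captured in the closure
        return args
    return hook

experts = torch.nn.Linear(4, 4)
experts.register_forward_pre_hook(make_hook(device_mesh="fake-mesh"))

hook = next(
    h for h in experts._forward_pre_hooks.values()
    if 'device_mesh' in h.__code__.co_freevars
)
print(hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents)
# -> 'fake-mesh'
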
class MegaBlocksMoeMLP(torch.nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        gradient_scale = getattr(self.experts, "gradient_scale", None)
        alpha = getattr(self.experts, "alpha", 1.0)
        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

        expert_parallel_group = getattr(self, "expert_parallel_group", None)
        if expert_parallel_group is None:
            device_mesh = get_device_mesh(self)
            expert_parallel_group = device_mesh.get_group() if device_mesh else None

        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
        forward_fn = parallel_forward_once if has_parallel else forward_once

        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
        mlp_impl = getattr(self, "mlp_impl", "grouped")

        output, expert_weights_out, *_ = moe_forward(
            x=x,
            router_weight=self.router.weight,
            moe_top_k=moe_top_k,
            moe_num_experts=moe_num_experts,
            moe_jitter_eps=moe_jitter_eps,
            moe_normalize_expert_weights=moe_normalize_expert_weights,
            uniform_expert_assignment=uniform_expert_assignment,
            training=self.training,
            w1=self.experts.gate_up_proj,
            w2=self.experts.down_proj,
            w1_bias=self.experts.gate_up_proj_bias,
            w2_bias=self.experts.down_proj_bias,
            gradient_scale=gradient_scale,
            alpha=alpha,
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
            moe_capacity_factor=moe_capacity_factor,
            moe_expert_model_parallelism=has_parallel,
            forward_fn=forward_fn,
            hidden_size=self.experts.hidden_size,
            mlp_impl=mlp_impl,
        )
        return output, expert_weights_out

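The sort_end_bit computed in forward above is just the number of bits needed to radix-sort expert ids, i.e. ceil(log2(num_experts)) with a floor of 1:

import torch

for moe_num_experts in (2, 8, 128, 130):
    sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
    print(moe_num_experts, sort_end_bit)
# 2 -> 1, 8 -> 3, 128 -> 7, 130 -> 8
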
class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):

    def __init__(self):
        super().__init__()
        # Shared expert weights will be set by the user
        self.shared_up_proj_weight = None
        self.shared_down_proj_weight = None
        self.shared_up_proj_bias = None
        self.shared_down_proj_bias = None
        self.shared_expert_weighted_sum = False
        self.shared_activation_fn = None

    def set_shared_expert_weights(
        self,
        up_proj_weight: torch.Tensor,
        down_proj_weight: torch.Tensor,
        up_proj_bias: Optional[torch.Tensor] = None,
        down_proj_bias: Optional[torch.Tensor] = None,
        weighted_sum: bool = False,
        activation_fn: Optional[Any] = None,
    ):
        self.shared_up_proj_weight = up_proj_weight
        self.shared_down_proj_weight = down_proj_weight
        self.shared_up_proj_bias = up_proj_bias
        self.shared_down_proj_bias = down_proj_bias
        self.shared_expert_weighted_sum = weighted_sum
        self.shared_activation_fn = activation_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        gradient_scale = getattr(self.experts, "gradient_scale", None)
        alpha = getattr(self.experts, "alpha", 1.0)
        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

        expert_parallel_group = getattr(self, "expert_parallel_group", None)
        if expert_parallel_group is None:
            device_mesh = get_device_mesh(self)
            expert_parallel_group = device_mesh.get_group() if device_mesh else None

        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
        forward_fn = parallel_forward_once if has_parallel else forward_once

        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
        mlp_impl = getattr(self, "mlp_impl", "grouped")

        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
            x=x,
            router_weight=self.router.weight,
            moe_top_k=moe_top_k,
            moe_num_experts=moe_num_experts,
            moe_jitter_eps=moe_jitter_eps,
            moe_normalize_expert_weights=moe_normalize_expert_weights,
            uniform_expert_assignment=uniform_expert_assignment,
            training=self.training,
            w1=self.experts.gate_up_proj,
            w2=self.experts.down_proj,
            w1_bias=self.experts.gate_up_proj_bias,
            w2_bias=self.experts.down_proj_bias,
            gradient_scale=gradient_scale,
            alpha=alpha,
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
            moe_capacity_factor=moe_capacity_factor,
            moe_expert_model_parallelism=has_parallel,
            forward_fn=forward_fn,
            hidden_size=self.experts.hidden_size,
            mlp_impl=mlp_impl,
            # Shared expert parameters
            shared_up_proj_weight=self.shared_up_proj_weight,
            shared_down_proj_weight=self.shared_down_proj_weight,
            shared_up_proj_bias=self.shared_up_proj_bias,
            shared_down_proj_bias=self.shared_down_proj_bias,
            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
            shared_activation_fn=self.shared_activation_fn,
        )
        return output, expert_weights_out
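A sketch of how the shared-expert weights might be attached, using the helpers defined above; the sizes are arbitrary, and actually running forward() still requires the router, the routed-expert parameters, and a CUDA build of the kernels:

import torch

layer = MegaBlocksMoeMLPWithSharedExpert()
up_w, down_w, _, _ = create_shared_expert_weights(
    hidden_size=1024,
    shared_expert_hidden_size=4096,
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.kaiming_uniform_,
)
layer.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    weighted_sum=True,
    activation_fn=torch.nn.functional.gelu,
)
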
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/__init__.py
    ADDED
    
@@ -0,0 +1,35 @@

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from .binned_gather import binned_gather
from .binned_scatter import binned_scatter
from .cumsum import exclusive_cumsum, inclusive_cumsum
from .gather import gather
from .histogram import histogram
from .padded_gather import padded_gather
from .padded_scatter import padded_scatter
from .repeat import repeat
from .replicate import replicate
from .round_up import round_up
from .scatter import scatter
from .sort import sort
from .sum import sum
from .topology import topology

__all__ = [
    'binned_gather',
    'binned_scatter',
    'exclusive_cumsum',
    'inclusive_cumsum',
    'gather',
    'histogram',
    'padded_gather',
    'padded_scatter',
    'repeat',
    'replicate',
    'round_up',
    'scatter',
    'sort',
    'sum',
    'topology',
]
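These re-exported ops are the building blocks of the token-binning pattern used by the layers: sort the expert assignment of every token, histogram the counts, and cumsum the counts into bin offsets. A pure-PyTorch illustration of that pattern (the compiled ops provide GPU implementations of these primitives; the values below are made up):

import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 2, 1, 0, 3])            # expert id per token
bin_ids, indices = torch.sort(top_expert)                 # group tokens by expert
tokens_per_expert = torch.histc(
    top_expert.float(), bins=num_experts, min=0, max=num_experts - 1,
)
bins = torch.cumsum(tokens_per_expert, dim=0)              # end offset of each expert's bin
print(bin_ids.tolist(), indices.tolist(), tokens_per_expert.tolist(), bins.tolist())
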
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
    ADDED
    
@@ -0,0 +1,63 @@

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist

# from megablocks import benchmark_util
# from megablocks.layers.all_to_all import all_to_all

from .. import benchmark_util
from .._layers.all_to_all import all_to_all

_ALL_TO_ALL_BENCHMARK = (
    (8, 1024),
    (16, 1024),
    (32, 1024),
    (64, 1024),
    (128, 1024),
    (256, 1024),
    (512, 1024),
    (1024, 1024),
    (2 * 1024, 1024),
    (4 * 1024, 1024),
    (8 * 1024, 1024),
    (16 * 1024, 1024),
    (32 * 1024, 1024),
    (64 * 1024, 1024),
    (128 * 1024, 1024),
    (256 * 1024, 1024),
    (512 * 1024, 1024),
    (1024 * 1024, 1024),
)


def benchmark_all_to_all(group, sl, hs):
    world_size = dist.get_world_size(group)
    assert (sl % world_size) == 0
    send_recv_sizes = [sl // world_size] * world_size

    x = torch.randn((sl, hs)).cuda().half()

    details = {
        'world_size': world_size,
        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2B elements.
    }

    def benchmark():
        return all_to_all(x, send_recv_sizes, send_recv_sizes, group)

    time, std = benchmark_util.benchmark_function(benchmark)

    if dist.get_rank(group) == 0:
        benchmark_util.log_benchmark('All-To-All', details, time, std)


if __name__ == '__main__':
    assert dist.is_available()
    group = dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank(group)
    torch.cuda.set_device(local_rank)

    for args in _ALL_TO_ALL_BENCHMARK:
        benchmark_all_to_all(group, *args)
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_gather.py
    ADDED
    
@@ -0,0 +1,37 @@

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for binned_gather kernel.
class BinnedGatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bins: torch.Tensor,
        bin_size: int,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bins)
        ctx.top_k = top_k
        return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        indices, bins = ctx.saved_tensors
        out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
        return out, None, None, None, None


binned_gather = BinnedGatherOp.apply
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/binned_scatter.py
    ADDED
    
@@ -0,0 +1,59 @@

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for binned_scatter kernel.
class BinnedScatterOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        assert len(x.size()) == 3
        ctx.bin_size = x.size(1)
        ctx.top_k = top_k

        # TODO(tgale): Don't save 'x' for backwards if we don't need to
        # calculate the gradient w.r.t. 'weights'.
        ctx.save_for_backward(x, indices, weights, bins)
        return kernels.binned_scatter(x, indices, weights, bins, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        x, indices, weights, bins = ctx.saved_tensors
        out = kernels.binned_gather(
            grad,
            indices,
            weights,
            bins,
            ctx.bin_size,
            ctx.top_k,
        )

        wgrad = None
        if ctx.needs_input_grad[2]:
            wgrad = kernels.binned_scatter_wgrad(
                x,
                grad,
                indices,
                bins,
                ctx.top_k,
            )
        return out, None, wgrad, None, None


binned_scatter = BinnedScatterOp.apply
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/cumsum.py
    ADDED
    
@@ -0,0 +1,52 @@

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# Wrap this in a try-block with better error message and
# instructions for building the c++ operations.
try:
    # import megablocks_ops as ops  # type: ignore
    from .._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrappers for cumsum kernels.
# NOTE: Does not support gradients.
class ExclusiveCumsumOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, dim: int):
        if len(x.size()) == 1:
            x = x.view([1, -1])
            out = torch.empty_like(x)
            ops.exclusive_cumsum(x, 1, out)
            return out.squeeze()
        out = torch.empty_like(x)
        ops.exclusive_cumsum(x, dim, out)
        return out


exclusive_cumsum = ExclusiveCumsumOp.apply


class InclusiveCumsumOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
        if len(x.size()) == 1:
            x = x.view([1, -1])
            out = torch.empty_like(x)
            ops.inclusive_cumsum(x, 1, out)
            return out.squeeze()
        out = torch.empty_like(x)
        ops.inclusive_cumsum(x, dim, out)
        return out


inclusive_cumsum = InclusiveCumsumOp.apply
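Neither wrapper supports gradients; they differ only in whether the running sum includes the current element. In the MoE layers these turn per-expert token counts into bin offsets, which is easy to see with plain torch.cumsum (illustrative values):

import torch

tokens_per_expert = torch.tensor([3, 1, 0, 2])

inclusive = torch.cumsum(tokens_per_expert, dim=0)   # [3, 4, 4, 6]  bin end offsets
exclusive = inclusive - tokens_per_expert            # [0, 3, 4, 4]  bin start offsets
print(inclusive.tolist(), exclusive.tolist())
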
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/gather.py
    ADDED
    
@@ -0,0 +1,38 @@

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for gather kernel.
class GatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bin_ids, bins)
        ctx.top_k = top_k
        return kernels.gather(x, indices, bin_ids, None, bins, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()

        indices, bin_ids, bins = ctx.saved_tensors
        out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
        return out, None, None, None, None, None


gather = GatherOp.apply
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram.py
    ADDED
    
@@ -0,0 +1,27 @@

# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# Wrap this in a try-block with better error message and
# instructions for building the c++ operations.
try:
    from .._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrapper for histogram kernel.
# NOTE: Does not support gradients.
class HistogramOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, max_val: float):
        return ops.histogram(x, max_val)


histogram = HistogramOp.apply
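In the MoE layers this op counts how many routed tokens land on each expert; the benchmark file that follows times it against torch.histc with the same arguments, which shows the intended semantics (illustrative values):

import torch

max_val = 4  # number of experts
expert_ids = torch.tensor([2, 0, 2, 1, 0, 3], dtype=torch.float32)
counts = torch.histc(expert_ids, bins=max_val, min=0, max=max_val - 1)
print(counts.tolist())  # [2.0, 1.0, 2.0, 1.0]
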
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/histogram_benchmark.py
    ADDED
    
    | @@ -0,0 +1,78 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest

import numpy as np
import torch
from absl.testing import parameterized

from .. import ops

_HISTOGRAM_TESTS = (
    (16384, torch.int32, 2),
    (16384, torch.int32, 4),
    (16384, torch.int32, 8),
    (16384, torch.int32, 16),
    (16384, torch.int32, 32),
    (16384, torch.int32, 64),
    (16384, torch.int32, 128),
    (16384, torch.int32, 256),
)


def benchmark_function(fn, iterations=10):
    # Run once to get rid of startup overhead.
    fn()
    times = []
    for _ in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    times = np.array(times)
    return times.mean(), times.std(), times.max(), times.min()


def log_benchmark(arguments, mean_t, std_t):
    print('=' * 60)
    print('Benchmark Parameters:')
    for (key, value) in arguments.items():
        print(f'{key} = {value}')
    print('Results:')
    print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
    print('=' * 60)


class HistogramBenchmark(parameterized.TestCase):

    @parameterized.parameters(*_HISTOGRAM_TESTS)
    def testHistogram(self, n, dtype, max_val):
        x = torch.randint(0, max_val, (n,)).cuda().to(dtype)

        mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
        arguments = {
            'n': n,
            'dtype': dtype,
            'max_val': max_val,
        }
        log_benchmark(arguments, mean_t, std_t)

    @parameterized.parameters(*_HISTOGRAM_TESTS)
    def testTorchHistogram(self, n, dtype, max_val):
        x = torch.randint(0, 128, (n,)).cuda().to(dtype)

        mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
        arguments = {
            'n': n,
            'dtype': dtype,
            'max_val': max_val,
        }
        log_benchmark(arguments, mean_t, std_t)


if __name__ == '__main__':
    unittest.main()
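
A quick way to sanity-check the fused kernel against the torch baseline outside the absl harness is sketched below. It assumes a CUDA device and that the built package is importable as `megablocks` (so `megablocks.ops` resolves to the same `ops` module used above); `ops.histogram(x, max_val)` is taken to return one count per value in `[0, max_val)`, matching how `tokens_per_expert` is computed elsewhere in this build.

import torch
from megablocks import ops  # assumption: the built package is importable as `megablocks`

# Compare the fused histogram kernel against torch.histc on the same input.
# torch.histc only supports floating point, so the reference input is cast.
x = torch.randint(0, 128, (16384,), device='cuda', dtype=torch.int32)
counts = ops.histogram(x, 128)
reference = torch.histc(x.float(), bins=128, min=0, max=127)
print(torch.equal(counts.long(), reference.long()))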
    	
build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/matmul_benchmark.py
ADDED
@@ -0,0 +1,415 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest


# import stk

# try:
#     import stk
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
#     )

from .. import stk

import torch
from absl.testing import parameterized

from .. import benchmark_util, ops


# Calling tensor.t() calls tensor.transpose(0, 1) which calls
# torch.as_strided(...). Circumvent this chain to avoid an overhead
# this adds.
def transpose_view(x):
    return torch.as_strided(
        x,
        (x.shape[1], x.shape[0]),
        (x.stride()[1], x.stride()[0]),
    )


_MATMUL_TESTS = (
    (64 * 1024, 512, 2048, 64),
    (32 * 1024, 768, 3072, 64),
    (8 * 1024, 1024, 4096, 64),
    (4 * 2048, 4096, 4 * 4096, 4),
)


def log_benchmark(name, arguments, time, std, flops):
    benchmark_util.log_benchmark(name, arguments, time, std)
    print('flops = {:.2f}B'.format(flops / 1e9))
    print('throughput = {:.2f}T'.format(flops / 1e9 / time))
    print('=' * 60)


class MatmulBenchmark(parameterized.TestCase):

    def build_sparse_matrix(self, x, padded_bins, fhs, ne):
        blocking = 128
        padded_tokens, _ = x.size()
        assert padded_tokens % blocking == 0
        assert fhs % blocking == 0

        # Offsets for the sparse matrix. All rows have the
        # same number of nonzero blocks dictated by the
        # dimensionality of a single expert.
        block_rows = padded_tokens // blocking
        blocks_per_row = fhs // blocking
        offsets = torch.arange(
            0,
            block_rows * blocks_per_row + 1,
            blocks_per_row,
            dtype=torch.int32,
            device=x.device,
        )

        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
        column_indices = ops.topology(
            padded_bins,
            blocking,
            block_rows,
            blocks_per_row,
        )
        data = torch.empty(
            column_indices.numel(),
            blocking,
            blocking,
            dtype=torch.float16,
            device=x.device,
        )
        shape = (padded_tokens, fhs * ne)
        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
        return stk.Matrix(shape, data, row_indices, column_indices, offsets)

    def build_input_matrix(self, sl, hs, ne):
        x = torch.randn((sl, hs)).cuda().half()

        # Assign tokens to experts uniformly.
        top_expert = torch.arange(0, sl).cuda().int() % ne

        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(top_expert, ne)
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
        return out, padded_bins

    def build_weight_matrix(self, ne, hs, fhs):
        return torch.randn((hs, ne * fhs)).cuda().half()

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        w = transpose_view(w)

        def benchmark():
            return stk.ops.sdd(x, w, topo)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::Fwd::SDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)

        def benchmark():
            return stk.ops.dsd(topo, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::GradX::DSD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        topo = topo.t()

        def benchmark():
            return stk.ops.dsd(topo, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::GradW::DSD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)

        def benchmark():
            return stk.ops.dsd(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::Fwd::DSD::NN',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        out = stk.ops.dsd(x, w)
        w = transpose_view(w)

        def benchmark():
            return stk.ops.sdd(out, w, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradX::SDD::NT',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        out = stk.ops.dsd(x, w)
        x = x.t()

        def benchmark():
            return stk.ops.dsd(x, out)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradW::DSD::TN',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()

        w = w.transpose(1, 2).contiguous()
        w = w.transpose(1, 2)

        def benchmark():
            return torch.bmm(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::Fwd:DDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()
        out = torch.bmm(x, w)
        w = w.transpose(1, 2).contiguous()

        def benchmark():
            return torch.bmm(out, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0:GradX:DDD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()
        out = torch.bmm(x, w)
        out = out.transpose(1, 2)

        def benchmark():
            return torch.bmm(out, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0:GradW:DDD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()

        def benchmark():
            return torch.bmm(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::Fwd::DDD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()
        out = torch.bmm(x, w)
        w = torch.transpose(w, 1, 2)

        def benchmark():
            return torch.bmm(out, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradX::DDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()
        out = torch.bmm(x, w)
        x = torch.transpose(x, 1, 2)

        def benchmark():
            return torch.bmm(x, out)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradW::DDD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )


if __name__ == '__main__':
    unittest.main()
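
For reference, the FLOP figure passed to `log_benchmark` in the Linear0 cases is `x.numel() * fhs * 2`, i.e. two flops per multiply-accumulate over a (tokens × hs) by (hs × fhs) product, and the throughput line divides GFLOP by the mean time in milliseconds, which comes out in TFLOP/s. A minimal sketch of that arithmetic for the first `_MATMUL_TESTS` entry (the mean time below is a hypothetical placeholder, not a measured value):

# First _MATMUL_TESTS entry: (sl, hs, fhs, ne) = (64 * 1024, 512, 2048, 64).
# With uniform routing, sl // ne = 1024 tokens per expert, already a multiple
# of the 128-row blocking, so no padding is added and x.numel() == sl * hs.
sl, hs, fhs = 64 * 1024, 512, 2048
flops = sl * hs * fhs * 2                              # two flops per multiply-accumulate
print('flops = {:.2f}B'.format(flops / 1e9))           # ~137.44B, as log_benchmark prints
mean_t_ms = 1.0                                        # hypothetical mean time in milliseconds
print('throughput = {:.2f}T'.format(flops / 1e9 / mean_t_ms))  # GFLOP per ms == TFLOP/s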
    	
build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_gather.py
ADDED
@@ -0,0 +1,55 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for padded_gather kernel.
class PaddedGatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
        ctx.top_k = top_k
        return kernels.padded_gather(
            x,
            indices,
            bin_ids,
            None,
            bins,
            padded_bins,
            top_k,
        )

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()

        indices, bin_ids, bins, padded_bins = ctx.saved_tensors
        out = kernels.padded_scatter(
            grad,
            indices,
            bin_ids,
            None,
            bins,
            padded_bins,
            ctx.top_k,
        )
        return out, None, None, None, None, None


padded_gather = PaddedGatherOp.apply
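
Since `PaddedGatherOp.backward` is itself a `padded_scatter` of the incoming gradient, gradients land back on the original token rows. A minimal round-trip sketch, assuming a CUDA device and that the built package is importable as `megablocks`:

import torch
from megablocks import ops  # assumption: the built package is importable as `megablocks`

sl, hs, ne = 1024, 256, 4
x = torch.randn((sl, hs), device='cuda', dtype=torch.float16, requires_grad=True)

# Route each token to a random expert and build the padded bin boundaries.
top_expert = torch.randint(0, ne, (sl,), device='cuda', dtype=torch.int32)
bin_ids, indices = ops.sort(top_expert)
tokens_per_expert = ops.histogram(top_expert, ne)
padded_bins = ops.inclusive_cumsum(ops.round_up(tokens_per_expert, 128), 0)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)

# Gather with top_k == 1, then backprop; the gradient is scattered back so
# x.grad has exactly one row per original token.
out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
out.sum().backward()
print(x.grad.shape)  # torch.Size([1024, 256])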
    	
build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter.py
ADDED
@@ -0,0 +1,98 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for padded_scatter kernel.
class PaddedScatterOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        maybe_x = [x] if ctx.needs_input_grad[3] else []
        ctx.save_for_backward(
            indices,
            bin_ids,
            weights,
            bins,
            padded_bins,
            *maybe_x,
        )
        ctx.top_k = top_k
        ctx.x_shape = x.shape
        return kernels.padded_scatter(
            x,
            indices,
            bin_ids,
            weights,
            bins,
            padded_bins,
            top_k,
        )

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        saved_tensors = ctx.saved_tensors

        indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
        dgrad = None
        if ctx.needs_input_grad[0]:
            dgrad = kernels.padded_gather(
                grad,
                indices,
                bin_ids,
                weights,
                bins,
                padded_bins,
                ctx.top_k,
            )

        wgrad = None
        if ctx.needs_input_grad[3]:  # need wgrad
            x = saved_tensors[-1]
            wgrad = kernels.padded_scatter_wgrad(
                x,
                grad,
                indices,
                bin_ids,
                bins,
                padded_bins,
                ctx.top_k,
            )
        return dgrad, None, None, wgrad, None, None, None, None


def padded_scatter(
    x: torch.Tensor,
    indices: torch.Tensor,
    bin_ids: torch.Tensor,
    weights: torch.Tensor,
    bins: torch.Tensor,
    padded_bins: torch.Tensor,
    top_k: int,
):
    return PaddedScatterOp.apply(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )
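
The `weights` argument also receives a gradient: when `weights` requires grad, `backward` saves the gathered activations and calls the `padded_scatter_wgrad` kernel. A sketch of that path, under the same assumptions as above (CUDA device, package importable as `megablocks`):

import torch
from megablocks import ops  # assumption: the built package is importable as `megablocks`

sl, hs, ne, top_k = 2048, 512, 8, 2
x = torch.randn((sl, hs), device='cuda', dtype=torch.float16)

# Each token picks top_k experts; the routing weights require grad.
top_expert = torch.randint(0, ne, (sl * top_k,), device='cuda', dtype=torch.int32)
bin_ids, indices = ops.sort(top_expert)
tokens_per_expert = ops.histogram(top_expert, ne)
padded_bins = ops.inclusive_cumsum(ops.round_up(tokens_per_expert, 128), 0)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
weights = torch.rand((sl * top_k,), device='cuda', dtype=torch.float16, requires_grad=True)

gathered = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
out = ops.padded_scatter(gathered, indices, bin_ids, weights, bins, padded_bins, top_k)
out.sum().backward()
print(weights.grad.shape)  # one gradient per routed (token, expert) pair: torch.Size([4096])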
    	
build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
ADDED
@@ -0,0 +1,66 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest

import torch
from absl.testing import parameterized

from .. import benchmark_util, ops

_PADDED_SCATTER_BENCHMARK = (
    # dMoE-Medium, 8-way EMP.
    (1024 * 16, 1024, 8, 4),
    # dMoE-Medium, post-all-to-all.
    (1024 * 16 * 4, 1024, 8, 1),
)


class PaddedScatterTest(parameterized.TestCase):

    @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
    def testPaddedScatter(self, sl, hs, ne, top_k):
        # Create the data and indices.
        x = torch.randn((sl, hs)).cuda().half()

        # Randomly assign tokens to experts.
        top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(top_expert, ne)
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)

        # Sample weights for the scatter reduce.
        weights = torch.rand((sl * top_k,)).cuda().half()

        # Gather the data to prepare for backwards.
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

        def benchmark():
            return ops.padded_scatter(
                x,
                indices,
                bin_ids,
                weights,
                bins,
                padded_bins,
                top_k,
            )

        time, std = benchmark_util.benchmark_function(benchmark)
        benchmark_util.log_benchmark(
            'Padded Scatter',
            {
                'sequence_length': sl,
                'hidden_size': hs,
                'num_experts': ne,
                'top_k': top_k,
            },
            time,
            std,
        )


if __name__ == '__main__':
    unittest.main()
    	
build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/permute_benchmark.py
ADDED
@@ -0,0 +1,149 @@
| 1 | 
            +
            # Copyright 2024 Databricks
         | 
| 2 | 
            +
            # SPDX-License-Identifier: Apache-2.0
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import unittest
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            from absl.testing import parameterized
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from .. import benchmark_util, ops
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            _PERMUTE_TESTS = (
         | 
| 12 | 
            +
                (16384, 768, 2),
         | 
| 13 | 
            +
                (16384, 768, 4),
         | 
| 14 | 
            +
                (16384, 768, 8),
         | 
| 15 | 
            +
                (16384, 768, 16),
         | 
| 16 | 
            +
                (16384, 768, 32),
         | 
| 17 | 
            +
                (16384, 768, 64),
         | 
| 18 | 
            +
                (16384, 768, 128),
         | 
| 19 | 
            +
                (16384 * 8, 768, 2),
         | 
| 20 | 
            +
                (16384 * 8, 768, 4),
         | 
| 21 | 
            +
                (16384 * 8, 768, 8),
         | 
| 22 | 
            +
                (16384 * 8, 768, 16),
         | 
| 23 | 
            +
                (16384 * 8, 768, 32),
         | 
| 24 | 
            +
                (16384 * 8, 768, 64),
         | 
| 25 | 
            +
                (16384 * 8, 768, 128),
         | 
| 26 | 
            +
            )
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            class PermuteBenchmark(parameterized.TestCase):
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                @parameterized.parameters(*_PERMUTE_TESTS)
         | 
| 32 | 
            +
                def testBinnedGather(self, sl, hs, ne):
         | 
| 33 | 
            +
                    # NOTE: Capacity factor == 1.
         | 
| 34 | 
            +
                    ec = sl // ne
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    # Create the data and indices.
         | 
| 37 | 
            +
                    x = torch.randn((sl, hs)).cuda().half()
         | 
| 38 | 
            +
                    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
         | 
| 39 | 
            +
                    bin_ids, indices = ops.sort(top_expert)
         | 
| 40 | 
            +
                    tokens_per_expert = ops.histogram(indices, ne)
         | 
| 41 | 
            +
                    bins = ops.inclusive_cumsum(tokens_per_expert, 0)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    def benchmark():
         | 
| 44 | 
            +
                        return ops.binned_gather(x, indices, bins, ec)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    mean_t, std_t = benchmark_util.benchmark_function(benchmark)
         | 
| 47 | 
            +
                    arguments = {
         | 
| 48 | 
            +
                        'sequence_length': sl,
         | 
| 49 | 
            +
                        'hidden_size': hs,
         | 
| 50 | 
            +
                        'num_experts': ne,
         | 
| 51 | 
            +
                    }
         | 
| 52 | 
            +
                    benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                @parameterized.parameters(*_PERMUTE_TESTS)
         | 
| 55 | 
            +
                def testBinnedScatter(self, sl, hs, ne):
         | 
| 56 | 
            +
                    # NOTE: Capacity factor == 1.
         | 
| 57 | 
            +
                    ec = sl // ne
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    # Create the data and indices.
         | 
| 60 | 
            +
                    x = torch.randn((sl, hs)).cuda().half()
         | 
| 61 | 
            +
                    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
         | 
| 62 | 
            +
                    bin_ids, indices = ops.sort(top_expert)
         | 
| 63 | 
            +
                    tokens_per_expert = ops.histogram(indices, ne)
         | 
| 64 | 
            +
                    bins = ops.inclusive_cumsum(tokens_per_expert, 0)
         | 
| 65 | 
            +
                    x = ops.binned_gather(x, indices, bins, ec)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    def benchmark():
         | 
| 68 | 
            +
                        return ops.binned_scatter(x, indices, bins)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    mean_t, std_t = benchmark_util.benchmark_function(benchmark)
         | 
| 71 | 
            +
                    arguments = {
         | 
| 72 | 
            +
                        'sequence_length': sl,
         | 
| 73 | 
            +
                        'hidden_size': hs,
         | 
| 74 | 
            +
                        'num_experts': ne,
         | 
| 75 | 
            +
                    }
         | 
| 76 | 
            +
                    benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                @parameterized.parameters(*_PERMUTE_TESTS)
         | 
| 79 | 
            +
                def testPaddedGather(self, sl, hs, ne):
         | 
| 80 | 
            +
                    # Create the data and indices.
         | 
| 81 | 
            +
                    x = torch.randn((sl, hs)).cuda().half()
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    # Randomly assign tokens to experts.
         | 
| 84 | 
            +
                    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
         | 
| 85 | 
            +
                    bin_ids, indices = ops.sort(top_expert)
         | 
| 86 | 
            +
                    tokens_per_expert = ops.histogram(top_expert, ne)
         | 
| 87 | 
            +
                    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
         | 
| 88 | 
            +
                    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
         | 
| 89 | 
            +
                    bins = ops.inclusive_cumsum(tokens_per_expert, 0)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    def benchmark():
         | 
| 92 | 
            +
                        return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    mean_t, std_t = benchmark_util.benchmark_function(benchmark)
         | 
| 95 | 
            +
                    arguments = {
         | 
| 96 | 
            +
                        'sequence_length': sl,
         | 
| 97 | 
            +
                        'hidden_size': hs,
         | 
| 98 | 
            +
                        'num_experts': ne,
         | 
| 99 | 
            +
                    }
         | 
| 100 | 
            +
                    benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                @parameterized.parameters(*_PERMUTE_TESTS)
         | 
| 103 | 
            +
                def testPaddedScatter(self, sl, hs, ne):
         | 
| 104 | 
            +
                    # Create the data and indices.
         | 
| 105 | 
            +
                    x = torch.randn((sl, hs)).cuda().half()
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    # Randomly assign tokens to experts.
         | 
| 108 | 
            +
                    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
         | 
| 109 | 
            +
                    bin_ids, indices = ops.sort(top_expert)
         | 
| 110 | 
            +
                    tokens_per_expert = ops.histogram(top_expert, ne)
         | 
| 111 | 
            +
                    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
         | 
| 112 | 
            +
                    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
         | 
| 113 | 
            +
                    bins = ops.inclusive_cumsum(tokens_per_expert, 0)
         | 
| 114 | 
            +
                    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    def benchmark():
         | 
| 117 | 
            +
                        return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    mean_t, std_t = benchmark_util.benchmark_function(benchmark)
         | 
| 120 | 
            +
                    arguments = {
         | 
| 121 | 
            +
                        'sequence_length': sl,
         | 
| 122 | 
            +
                        'hidden_size': hs,
         | 
| 123 | 
            +
                        'num_experts': ne,
         | 
| 124 | 
            +
                    }
         | 
| 125 | 
            +
                    benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                @parameterized.parameters(*_PERMUTE_TESTS)
         | 
| 128 | 
            +
                def testCopy(self, sl, hs, ne):
         | 
| 129 | 
            +
                    # NOTE: Capacity factor == 1.
         | 
| 130 | 
            +
                    # ec = sl // ne
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    # Create the data and indices.
         | 
| 133 | 
            +
                    x = torch.randn((sl, hs)).cuda().half()
         | 
| 134 | 
            +
                    y = x.clone()
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    def benchmark():
         | 
| 137 | 
            +
                        return y.copy_(x)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    mean_t, std_t = benchmark_util.benchmark_function(benchmark)
         | 
| 140 | 
            +
                    arguments = {
         | 
| 141 | 
            +
                        'sequence_length': sl,
         | 
| 142 | 
            +
                        'hidden_size': hs,
         | 
| 143 | 
            +
                        'num_experts': ne,
         | 
| 144 | 
            +
                    }
         | 
| 145 | 
            +
                    benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            if __name__ == '__main__':
         | 
| 149 | 
            +
                unittest.main()
         | 
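For the padded gather/scatter cases above, the bin bookkeeping is easiest to see on a tiny example. A minimal sketch, assuming only the import path from this build layout (pure PyTorch, so it runs on CPU; the block size of 4 is illustrative, the benchmark uses 128):

import torch
from megablocks.ops.round_up import round_up  # import path assumed from this build layout

tokens_per_expert = torch.tensor([3, 0, 5], dtype=torch.int32)
padded = round_up(tokens_per_expert, 4)     # tensor([4, 0, 8]): each count rounded up to a multiple of 4
padded_bins = torch.cumsum(padded, 0)       # tensor([4, 4, 12]); ops.inclusive_cumsum is the CUDA-side equivalent
bins = torch.cumsum(tokens_per_expert, 0)   # tensor([3, 3, 8])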
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/repeat.py
    ADDED
    
@@ -0,0 +1,10 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def repeat(x: torch.Tensor, tiling: torch.Size):
+    if all((t == 1 for t in tiling)):
+        return x
+    return x.repeat(*tiling)
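A quick illustration of the no-op shortcut in repeat (pure PyTorch; only the import path is assumed from this build layout):

import torch
from megablocks.ops.repeat import repeat  # import path assumed

x = torch.randn(4, 8)
assert repeat(x, torch.Size((1, 1))) is x               # all-ones tiling returns the input unchanged
assert repeat(x, torch.Size((2, 3))).shape == (8, 24)   # otherwise defers to torch.Tensor.repeat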
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/replicate.py
    ADDED
    
@@ -0,0 +1,36 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+    from .._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+
+# Autograd wrapper for replicate kernel.
+class ReplicateOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
+        ctx.save_for_backward(bins)
+        out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
+        ops.replicate_forward(x, bins, out)
+        return out
+
+    @staticmethod
+    def backward(ctx: Any, grad: torch.Tensor):
+        bins, = ctx.saved_tensors
+        out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
+        ops.replicate_backward(grad, bins, out)
+        return out, None, None
+
+
+replicate = ReplicateOp.apply
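A hedged sketch of how this op is typically driven: one value per expert bin is broadcast out to every output slot in that bin. It requires a CUDA device and the compiled extension, and the shapes, values, and package layout below are illustrative assumptions rather than a verified recipe:

import torch
from megablocks import ops  # package layout assumed from this build

tokens_per_expert = torch.tensor([3, 1, 2], dtype=torch.int32, device='cuda')
bins = ops.inclusive_cumsum(tokens_per_expert, 0)   # bin boundaries, here [3, 4, 6]
expert_values = torch.randn(1, 3, device='cuda')    # one row of per-expert values (columns == number of bins)
out = ops.replicate(expert_values, bins, 6)         # assumed result: shape (1, 6), each value repeated across its bin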
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/round_up.py
    ADDED
    
@@ -0,0 +1,14 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+
+def round_up(x: torch.Tensor, value: int):
+    assert isinstance(value, int)
+    assert x.dtype == torch.int32
+
+    # TODO(tgale): If this becomes an issue
+    # do this in a custom kernel. We only expect
+    # to use this on arrays of less than 1k elements.
+    return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
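A quick check of the rounding behaviour (pure PyTorch, so it runs on CPU; only the import path is assumed from this build layout):

import torch
from megablocks.ops.round_up import round_up  # import path assumed

tokens_per_expert = torch.tensor([1, 129, 256], dtype=torch.int32)
print(round_up(tokens_per_expert, 128))  # tensor([128, 256, 256], dtype=torch.int32)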
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/scatter.py
    ADDED
    
@@ -0,0 +1,72 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional
+
+import torch
+from .stk_autocast import custom_bwd, custom_fwd
+
+from ..backend import kernels
+
+
+# Autograd wrapper for scatter kernel.
+class ScatterOp(torch.autograd.Function):
+
+    @staticmethod
+    @custom_fwd
+    def forward(
+        ctx: Any,
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,
+        weights: torch.Tensor,
+        bins: torch.Tensor,
+        top_k: int,
+    ) -> torch.Tensor:
+        maybe_x = [x] if ctx.needs_input_grad[3] else []
+        ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
+        ctx.top_k = top_k
+        ctx.x_shape = x.shape
+        return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx: Any, grad: torch.Tensor):
+        grad = grad.contiguous()
+        saved_tensors = ctx.saved_tensors
+
+        indices, bin_ids, weights, bins = saved_tensors[:4]
+        dgrad = None
+        if ctx.needs_input_grad[0]:
+            dgrad = kernels.gather(
+                grad,
+                indices,
+                bin_ids,
+                weights,
+                bins,
+                ctx.top_k,
+            )
+
+        wgrad = None
+        if ctx.needs_input_grad[3]:  # need wgrad
+            x = saved_tensors[-1]
+            wgrad = kernels.scatter_wgrad(
+                x,
+                grad,
+                indices,
+                bin_ids,
+                bins,
+                ctx.top_k,
+            )
+        return dgrad, None, None, wgrad, None, None, None
+
+
+def scatter(
+    x: torch.Tensor,
+    indices: torch.Tensor,
+    bin_ids: torch.Tensor,
+    weights: torch.Tensor,
+    bins: torch.Tensor,
+    top_k: int,
+) -> Optional[torch.Tensor]:
+    return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
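A hedged sketch of where this entry point sits in the MoE token flow: gather groups tokens by their assigned expert, and scatter restores the original token order while applying the routing weights. It needs a CUDA device plus the compiled kernels; the package layout, shapes, and dtypes are illustrative assumptions:

import torch
from megablocks import ops  # package layout assumed from this build

sl, hs, ne, top_k = 1024, 768, 8, 1
x = torch.randn(sl, hs, device='cuda', dtype=torch.float16)
top_expert = torch.randint(0, ne, (sl,), device='cuda', dtype=torch.int32)
bin_ids, indices = ops.sort(top_expert)
bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)
weights = torch.rand(sl, device='cuda', dtype=torch.float16)

permuted = ops.gather(x, indices, bin_ids, bins, top_k)                   # tokens grouped by expert
restored = ops.scatter(permuted, indices, bin_ids, weights, bins, top_k)  # back to token order, scaled by weights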
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort.py
    ADDED
    
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+    from .._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+    torch.int16: 16,
+    torch.int32: 32,
+    torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if end_bit is None:
+            end_bit = _BITS_FOR_DTYPE[x.dtype]
+        x_out = torch.empty_like(x)
+        iota_out = torch.empty_like(x)
+        ops.sort(x, end_bit, x_out, iota_out)
+        return (x_out, iota_out)
+
+
+sort = SortOp.apply
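The benchmarks earlier in this diff already show the intended call pattern; restated as a minimal sketch (CUDA device and compiled extension required; the package layout and shapes are assumptions):

import torch
from megablocks import ops  # package layout assumed from this build

top_expert = torch.randint(0, 8, (16,), device='cuda', dtype=torch.int32)
bin_ids, indices = ops.sort(top_expert)  # sorted expert ids plus the permutation that produced them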
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sort_benchmark.py
    ADDED
    
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+    (16384, torch.int32, None),
+    (16384, torch.int32, 2),
+    (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+    types = {
+        torch.int16: np.int16,
+        torch.int32: np.int32,
+        torch.int64: np.int64,
+    }
+    return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+    # Run once to get rid of startup overhead.
+    fn()
+    times = []
+    for _ in range(iterations):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        fn()
+        end.record()
+        torch.cuda.synchronize()
+        times.append(start.elapsed_time(end))
+    times = np.array(times)
+    return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+    print('=' * 60)
+    print('Benchmark Parameters:')
+    for (key, value) in arguments.items():
+        print(f'{key} = {value}')
+    print('Results:')
+    print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+    print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+    @parameterized.parameters(*_SORT_TESTS)
+    def testSort(self, n, dtype, max_val):
+        if max_val is None:
+            max_val = np.iinfo(numpy_dtype(dtype)).max
+        end_bit = int(np.ceil(np.log2(max_val)))
+        x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+        mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+        arguments = {
+            'n': n,
+            'dtype': dtype,
+            'max_val': max_val,
+        }
+        log_benchmark(arguments, mean_t, std_t)
+
+    @parameterized.parameters(*_BASELINE_SORT_TESTS)
+    def testTorchSort(self, n):
+        x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+        mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+        arguments = {
+            'n': n,
+        }
+        log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+    unittest.main()
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/stk_autocast.py
    ADDED
    
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, dict):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
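A minimal sketch of how these decorators wrap a custom autograd Function, mirroring the pattern used by scatter.py above; the Function here is a toy example and the import path is assumed from this build layout:

import torch
from megablocks.ops.stk_autocast import custom_bwd, custom_fwd  # import path assumed

class Scale(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        # When autocast is active, inputs arrive already cast to the autocast dtype.
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        return grad * 2

x = torch.randn(4, requires_grad=True)
Scale.apply(x).sum().backward()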
    	
        build/torch27-cxx11-cu118-x86_64-linux/megablocks/ops/sum.py
    ADDED
    
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+    if x.shape[dim] == 1:
+        return x.squeeze(dim=dim)
+    return x.sum(dim=dim)
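The helper simply skips a redundant reduction when the target dimension is already singleton; a quick check (pure PyTorch; only the import path is assumed from this build layout):

import torch
from megablocks.ops.sum import sum as mb_sum  # import path assumed

a = torch.randn(1, 5)
b = torch.randn(3, 5)
print(mb_sum(a, dim=0).shape)  # torch.Size([5]) -- squeeze only, no reduction
print(mb_sum(b, dim=0).shape)  # torch.Size([5]) -- real sum over dim 0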
