drbh committed · Commit 9c4ca75 · Parent(s): a4f6452

feat: validate build with original test suite
- README.md +36 -0
- build.toml +24 -2
- tests/conftest.py +110 -0
- tests/fixtures/autouse.py +107 -0
- tests/fixtures/fixtures.py +13 -0
- tests/layers/architectures.py +53 -0
- tests/layers/moe_test.py +199 -0
- tests/ops/binned_gather_test.py +71 -0
- tests/ops/binned_scatter_test.py +87 -0
- tests/ops/cumsum_test.py +44 -0
- tests/ops/histogram_test.py +82 -0
- tests/ops/padded_gather_test.py +94 -0
- tests/ops/padded_scatter_test.py +155 -0
- tests/ops/replicate_test.py +108 -0
- tests/ops/sort_test.py +65 -0
- tests/ops/topology_test.py +81 -0
- tests/test_mb_moe.py +42 -0
- torch-ext/megablocks/__init__.py +6 -2
- torch-ext/megablocks/ops/cumsum.py +1 -1
- torch-ext/megablocks/ops/histogram.py +1 -1
- torch-ext/megablocks/ops/replicate.py +1 -2
- torch-ext/megablocks/ops/sort.py +1 -2
- torch-ext/megablocks/ops/topology.py +1 -2
- torch-ext/torch_binding.cpp +13 -13
 
    	
README.md CHANGED

````diff
@@ -4,3 +4,39 @@ tags:
   - kernel
 ---
 
+
+
+```bash
+nix develop --show-trace -i -L .#test --command python -m pytest -s tests
+```
+
+expected output:
+
+```
+============== test session starts ===============
+platform linux -- Python 3.12.10, pytest-8.3.5, pluggy-1.5.0
+rootdir: /home/ubuntu/Projects/megablocks-moe
+plugins: hypothesis-6.130.12
+collecting 43 items                              world_size=1
+collected 387 items
+
+tests/layers/moe_test.py ...........................................
+tests/ops/binned_gather_test.py .....................
+tests/ops/binned_scatter_test.py .....................
+tests/ops/cumsum_test.py ................................
+tests/ops/histogram_test.py ......................................................
+tests/ops/padded_gather_test.py ......................................
+tests/ops/padded_scatter_test.py ......................................................
+tests/ops/replicate_test.py ..................................................................................
+tests/ops/sort_test.py ..................
+tests/ops/topology_test.py ....................
+tests/test_mb_moe.py megablocks_moe module imported successfully.
+Available functions: ['Arguments', 'MLP', 'MoE', 'ParallelDroplessMLP', 'ParallelMLP', 'SparseGLU', 'SparseMLP', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_megablocks_a4f6452_dirty', '_ops', 'argsort', 'backend', 'cumsum', 'dMoE', 'exclusive_cumsum', 'get_load_balancing_loss', 'grouped_gemm_util', 'histogram', 'inclusive_cumsum', 'indices', 'layers', 'ops', 'replicate_backward', 'replicate_forward', 'sort', 'torch']
+.cumsum output: tensor([0, 1, 3, 6], device='cuda:0', dtype=torch.int16)
+...
+
+================ warnings summary ================
+...
+-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
+======= 387 passed, 18 warnings in 54.63s ========
+```
````
    	
build.toml CHANGED

```diff
@@ -10,6 +10,20 @@ src = [
 
 [kernel.megablocks]
 backend = "cuda"
+cuda-capabilities = [
+    "7.0",
+    "7.2",
+    "7.5",
+    "8.0",
+    "8.6",
+    "8.7",
+    "8.9",
+    "9.0",
+    "10.0",
+    "10.1",
+    "12.0",
+]
+depends = ["torch", "cutlass_3_8"]
 src = [
     "csrc/new_cumsum.h",
     "csrc/new_cumsum.cu",
@@ -22,9 +36,17 @@ src = [
     "csrc/new_sort.h",
     "csrc/new_sort.cu",
 ]
-depends = [ "torch", "cutlass_3_8" ]
 
 [test]
 python-git-packages = [
-    { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" }
+    { url = "https://github.com/stanford-futuredata/stk.git", rev = "7363137", sha256 = "0m6g5l9nlwaiwybg5j8dhnz159wdpabdnkzapnn3dsifxrsb59vz" },
+    { url = "https://github.com/mosaicml/composer.git", rev = "v0.9.0", sha256 = "ekJ5nE6JwYY6Ld9kIk72R/a3iI943Gd5yvAkBHQs5aI=" },
+    # { url = "https://github.com/tgale96/grouped_gemm.git", rev = "v0.3.0", sha256 = "sha256-fS6MuDj6yQ00CSzFrmAmM20/ccvtLJ1MFjfeqdwuPl8=" }
+]
+python-packages = [
+    "tqdm",
+    "py-cpuinfo",
+    "importlib-metadata",
+    "torchmetrics",
+    # "yahp"
 ]
```
    	
tests/conftest.py ADDED

```diff
@@ -0,0 +1,110 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from typing import List, Optional
+
+import pytest
+# from composer.utils import reproducibility
+
+# Allowed options for pytest.mark.world_size()
+WORLD_SIZE_OPTIONS = (1, 2)
+
+# Enforce deterministic mode before any tests start.
+# reproducibility.configure_deterministic_mode()
+
+# TODO: allow plugins when deps resolved
+
+# Add the path of any pytest fixture files you want to make global
+pytest_plugins = [
+    # 'tests.fixtures.autouse',
+    'tests.fixtures.fixtures',
+]
+
+
+def _get_world_size(item: pytest.Item):
+    """Returns the world_size of a test, defaults to 1."""
+    _default = pytest.mark.world_size(1).mark
+    return item.get_closest_marker('world_size', default=_default).args[0]
+
+
+def _get_option(
+    config: pytest.Config,
+    name: str,
+    default: Optional[str] = None,
+) -> str:  # type: ignore
+    val = config.getoption(name)
+    if val is not None:
+        assert isinstance(val, str)
+        return val
+    val = config.getini(name)
+    if val == []:
+        val = None
+    if val is None:
+        if default is None:
+            pytest.fail(f'Config option {name} is not specified but is required',)
+        val = default
+    assert isinstance(val, str)
+    return val
+
+
+def _add_option(
+    parser: pytest.Parser,
+    name: str,
+    help: str,
+    choices: Optional[list[str]] = None,
+):
+    parser.addoption(
+        f'--{name}',
+        default=None,
+        type=str,
+        choices=choices,
+        help=help,
+    )
+    parser.addini(
+        name=name,
+        help=help,
+        type='string',
+        default=None,
+    )
+
+
+def pytest_collection_modifyitems(
+    config: pytest.Config,
+    items: List[pytest.Item],
+) -> None:
+    """Filter tests by world_size (for multi-GPU tests)"""
+    world_size = int(os.environ.get('WORLD_SIZE', '1'))
+    print(f'world_size={world_size}')
+
+    conditions = [
+        lambda item: _get_world_size(item) == world_size,
+    ]
+
+    # keep items that satisfy all conditions
+    remaining = []
+    deselected = []
+    for item in items:
+        if all(condition(item) for condition in conditions):
+            remaining.append(item)
+        else:
+            deselected.append(item)
+
+    if deselected:
+        config.hook.pytest_deselected(items=deselected)
+        items[:] = remaining
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+    _add_option(
+        parser,
+        'seed',
+        help="""\
+        Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
+        before each test.""",
+    )
+
+
+def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
+    if exitstatus == 5:
+        session.exitstatus = 0  # Ignore no-test-ran errors
```
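The collection hook above means each pytest invocation only collects tests whose `world_size` marker matches the `WORLD_SIZE` environment variable, defaulting to 1 on both sides. A minimal sketch of how a test would opt into the two-rank pass; the marker name, default, and env var come from the conftest above, while the test name and body are hypothetical:

```python
import pytest
import torch


# Collected only when the launcher sets WORLD_SIZE=2; under the default
# WORLD_SIZE=1, pytest_collection_modifyitems deselects it.
@pytest.mark.world_size(2)
@pytest.mark.gpu
def test_two_rank_sketch():
    # Hypothetical body: anything that assumes two cooperating ranks.
    assert torch.cuda.is_available()
```

Tests without the marker fall back to `pytest.mark.world_size(1)`, so a plain `python -m pytest -s tests` run (as in the README) executes the single-process suite.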
    	
tests/fixtures/autouse.py ADDED

```diff
@@ -0,0 +1,107 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import gc
+import logging
+import os
+
+import composer
+import pytest
+import torch
+from composer.devices import DeviceCPU, DeviceGPU
+from composer.utils import dist, reproducibility
+
+
+@pytest.fixture(autouse=True)
+def clear_cuda_cache(request: pytest.FixtureRequest):
+    """Clear memory between GPU tests."""
+    marker = request.node.get_closest_marker('gpu')
+    if marker is not None and torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        gc.collect()  # Only gc on GPU tests as it 2x slows down CPU tests
+
+
+@pytest.fixture(autouse=True)
+def reset_mlflow_tracking_dir():
+    """Reset MLFlow tracking dir so it doesn't persist across tests."""
+    try:
+        import mlflow
+        mlflow.set_tracking_uri(None)  # type: ignore
+    except ModuleNotFoundError:
+        # MLFlow not installed
+        pass
+
+
+@pytest.fixture(scope='session')
+def cleanup_dist():
+    """Ensure all dist tests clean up resources properly."""
+    yield
+    # Avoid race condition where a test is still writing to a file on one rank
+    # while the file system is being torn down on another rank.
+    dist.barrier()
+
+
+@pytest.fixture(autouse=True, scope='session')
+def configure_dist(request: pytest.FixtureRequest):
+    # Configure dist globally when the world size is greater than 1,
+    # so individual tests that do not use the trainer
+    # do not need to worry about manually configuring dist.
+
+    if dist.get_world_size() == 1:
+        return
+
+    device = None
+
+    for item in request.session.items:
+        device = DeviceCPU() if item.get_closest_marker('gpu') is None else DeviceGPU()
+        break
+
+    assert device is not None
+
+    if not dist.is_initialized():
+        dist.initialize_dist(device, timeout=300.0)
+    # Hold PyTest until all ranks have reached this barrier. Ensure that no rank starts
+    # any test before other ranks are ready to start it, which could be a cause of random timeouts
+    # (e.g. rank 1 starts the next test while rank 0 is finishing up the previous test).
+    dist.barrier()
+
+
+@pytest.fixture(autouse=True)
+def set_log_levels():
+    """Ensures all log levels are set to DEBUG."""
+    logging.basicConfig()
+    logging.getLogger(composer.__name__).setLevel(logging.DEBUG)
+
+
+@pytest.fixture(autouse=True)
+def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
+    """Monkeypatch reproducibility.
+
+    Make get_random_seed always return the rank zero seed, and set the random seed before each test to the rank-local
+    seed.
+    """
+    monkeypatch.setattr(
+        reproducibility,
+        'get_random_seed',
+        lambda: rank_zero_seed,
+    )
+    reproducibility.seed_all(rank_zero_seed + dist.get_global_rank())
+
+
+@pytest.fixture(autouse=True)
+def remove_run_name_env_var():
+    # Remove environment variables for run names in unit tests
+    composer_run_name = os.environ.get('COMPOSER_RUN_NAME')
+    run_name = os.environ.get('RUN_NAME')
+
+    if 'COMPOSER_RUN_NAME' in os.environ:
+        del os.environ['COMPOSER_RUN_NAME']
+    if 'RUN_NAME' in os.environ:
+        del os.environ['RUN_NAME']
+
+    yield
+
+    if composer_run_name is not None:
+        os.environ['COMPOSER_RUN_NAME'] = composer_run_name
+    if run_name is not None:
+        os.environ['RUN_NAME'] = run_name
```
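Note that `tests.fixtures.autouse` is commented out of `pytest_plugins` in conftest.py above, so these fixtures are inert until the composer dependency resolves. If enabled, autouse fixtures apply to every test without being requested by name; a hypothetical test illustrating the effect, assuming the plugin is re-enabled:

```python
import pytest
import torch


@pytest.mark.gpu
def test_reproducible_randomness():
    # With the autouse fixtures active, seed_all has already invoked
    # reproducibility.seed_all(rank_zero_seed + dist.get_global_rank())
    # before this body runs, so the tensor below is identical across
    # reruns with the same --seed value.
    x = torch.randn(4, 4)
    assert x.shape == (4, 4)
    # After the test, clear_cuda_cache empties the CUDA cache because
    # of the gpu marker.
```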
    	
tests/fixtures/fixtures.py ADDED

```diff
@@ -0,0 +1,13 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from tests.conftest import _get_option
+
+
+@pytest.fixture
+def rank_zero_seed(pytestconfig: pytest.Config) -> int:
+    """Read the rank_zero_seed from the CLI option."""
+    seed = _get_option(pytestconfig, 'seed', default='0')
+    return int(seed)
```
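The `seed` option this fixture reads is registered by `_add_option` in conftest.py, so it can be driven from the command line. A minimal sketch of a consuming test (the test name is hypothetical; the fixture and option names come from the files above):

```python
def test_consumes_seed(rank_zero_seed: int):
    # rank_zero_seed resolves to int(--seed) when the option is passed,
    # e.g. `python -m pytest tests --seed 42`, and falls back to the
    # default '0' wired through _get_option above.
    assert isinstance(rank_zero_seed, int)
```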
    	
tests/layers/architectures.py ADDED

```diff
@@ -0,0 +1,53 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torch.nn.functional as F
+
+from megablocks.layers.arguments import Arguments
+
+
+class FFN(torch.nn.Module):
+
+    def __init__(self, args: Arguments):
+        super().__init__()
+        self.w1 = torch.nn.Parameter(
+            torch.empty(
+                args.hidden_size,
+                args.ffn_hidden_size,
+                device=args.device,
+                dtype=torch.float16 if args.fp16 else torch.float32,
+            ),
+        )
+        self.w2 = torch.nn.Parameter(
+            torch.empty(
+                args.ffn_hidden_size,
+                args.hidden_size,
+                device=args.device,
+                dtype=torch.float16 if args.fp16 else torch.float32,
+            ),
+        )
+
+    def forward(self, x):
+        return torch.matmul(
+            F.gelu(torch.matmul(x, self.w1), approximate='tanh'),
+            self.w2,
+        )
+
+
+class GLU(FFN):
+
+    def __init__(self, args: Arguments):
+        super().__init__(args)
+        self.v1 = torch.nn.Parameter(
+            torch.empty(
+                args.hidden_size,
+                args.ffn_hidden_size,
+                device=args.device,
+                dtype=torch.float16 if args.fp16 else torch.float32,
+            ),
+        )
+
+    def forward(self, x):
+        x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1)
+        return torch.matmul(x1, self.w2)
```
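These baseline modules mirror the shapes of the MoE expert weights, which is what lets the test below copy parameters across and compare outputs. A small sketch of standalone usage, assuming `Arguments` defaults to a CPU device and fp32 (`fp16=False`), as its keyword use in `construct_moe` below suggests:

```python
import torch

from megablocks.layers.arguments import Arguments
from tests.layers.architectures import FFN

args = Arguments(hidden_size=512, ffn_hidden_size=1024)
ffn = FFN(args)

# The weights come from torch.empty, so initialize them before use,
# mirroring the init_method used in moe_test.py.
for p in ffn.parameters():
    torch.nn.init.normal_(p, mean=0.0, std=0.1)

x = torch.randn(16, 512)  # (tokens, hidden_size)
y = ffn(x)                # gelu(x @ w1, tanh approx) @ w2
assert y.shape == x.shape
```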
    	
tests/layers/moe_test.py ADDED

```diff
@@ -0,0 +1,199 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from functools import partial
+
+import pytest
+import torch
+
+from megablocks.layers.arguments import Arguments
+from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
+from megablocks.layers.router import batched_router_zloss, clear_router_zloss
+from tests.layers.architectures import FFN
+
+_FORWARD_TESTS = (
+    (16, 1024, 512, 1, 1),
+    (16, 1024, 512, 2, 1),
+    (16, 1024, 512, 4, 1),
+    (16, 1024, 512, 8, 1),
+    (8, 2048, 512, 1, 1),
+    (8, 2048, 512, 2, 1),
+    (8, 2048, 512, 4, 1),
+    (16, 1024, 512, 2, 2),
+    (16, 1024, 512, 4, 2),
+    (16, 1024, 512, 4, 4),
+    (16, 1024, 512, 8, 2),
+    (16, 1024, 512, 8, 4),
+    (16, 1024, 512, 8, 8),
+)
+
+_DENSE_TESTS = (
+    (16, 1024, 512),
+    (8, 2048, 512),
+)
+
+
+def construct_moe(
+    hidden_size: int,
+    ffn_hidden_size: int,
+    moe_num_experts: int = 1,
+    moe_capacity_factor: int = 1,
+    moe_top_k: int = 1,
+    moe_zloss_weight: float = 0,
+):
+    # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
+    # TODO: Remove this once sparse is supported with triton >=3.2.0
+    try:
+        import triton
+        if triton.__version__ >= '3.2.0':
+            pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+    except ImportError:
+        pass
+
+    init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
+    args = Arguments(
+        hidden_size=hidden_size,
+        ffn_hidden_size=ffn_hidden_size,
+        moe_num_experts=moe_num_experts,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_top_k=moe_top_k,
+        init_method=init_method,
+        moe_zloss_weight=moe_zloss_weight,
+    )
+
+    mlp = FFN(args)
+    moe_mlp = MoE(args)
+
+    mlp.cuda(torch.cuda.current_device()).half()
+    moe_mlp.cuda(torch.cuda.current_device()).half()
+
+    # Set the baseline parameters to match exactly.
+    if moe_num_experts == 1:
+        with torch.no_grad():
+            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
+            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
+    return args, mlp, moe_mlp
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
+def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int):
+    x = torch.randn(sl, bs, hs).half().cuda()
+
+    _, _, layer = construct_moe(
+        hidden_size=hs,
+        ffn_hidden_size=hs * 2,
+        moe_num_experts=num_experts,
+        moe_top_k=top_k,
+    )
+
+    out, _ = layer(x)
+    assert out.shape == x.shape
+    clear_load_balancing_loss()
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
+def test_moe_forward_backward(
+    bs: int,
+    sl: int,
+    hs: int,
+    num_experts: int,
+    top_k: int,
+):
+    x = torch.randn(sl, bs, hs).half().cuda()
+    x.requires_grad_(True)
+
+    args, _, layer = construct_moe(
+        hidden_size=hs,
+        ffn_hidden_size=hs * 2,
+        moe_num_experts=num_experts,
+        moe_top_k=top_k,
+    )
+
+    out, _ = layer(x)
+    assert out.shape == x.shape
+
+    loss = out.sum() + batched_load_balancing_loss(args)
+    loss.backward()
+    layer.zero_grad(set_to_none=True)
+    x.grad = None
+    clear_load_balancing_loss()
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
+def test_moe_forward_backward_with_zloss(
+    bs: int,
+    sl: int,
+    hs: int,
+    num_experts: int,
+    top_k: int,
+):
+    x = torch.randn(sl, bs, hs).half().cuda()
+    x.requires_grad_(True)
+
+    args, _, layer = construct_moe(
+        hidden_size=hs,
+        ffn_hidden_size=hs * 2,
+        moe_num_experts=num_experts,
+        moe_top_k=top_k,
+        moe_zloss_weight=1e-3,
+    )
+
+    out, _ = layer(x)
+    assert out.shape == x.shape
+
+    loss = out.sum() + batched_load_balancing_loss(args)
+    loss.backward()
+    layer.zero_grad(set_to_none=True)
+    x.grad = None
+    clear_load_balancing_loss()
+    clear_router_zloss()
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
+def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
+    x = torch.randn(sl, bs, hs).half().cuda()
+
+    _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)
+
+    expected_out = mlp(x)
+    out, _ = moe_mlp(x)
+    assert out.shape == x.shape == expected_out.shape
+    assert torch.allclose(out, expected_out)
+    clear_load_balancing_loss()
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
+def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int):
+    x = torch.randn(sl, bs, hs).half().cuda()
+    x.requires_grad_(True)
+
+    _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)
+
+    out, _ = moe_mlp(x)
+    loss = out.sum()
+    loss.backward()
+    w1_grad = moe_mlp.experts.mlp.w1.grad.detach().squeeze()
```
         
     | 
| 181 | 
         
            +
                w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze()
         
     | 
| 182 | 
         
            +
                moe_mlp.zero_grad(set_to_none=True)
         
     | 
| 183 | 
         
            +
                x.grad = None
         
     | 
| 184 | 
         
            +
                clear_load_balancing_loss()
         
     | 
| 185 | 
         
            +
             
     | 
| 186 | 
         
            +
                expected_out = mlp(x)
         
     | 
| 187 | 
         
            +
                expected_loss = expected_out.sum()
         
     | 
| 188 | 
         
            +
                expected_loss.backward()
         
     | 
| 189 | 
         
            +
                expected_w1_grad = mlp.w1.grad.detach()
         
     | 
| 190 | 
         
            +
                expected_w2_grad = mlp.w2.grad.detach()
         
     | 
| 191 | 
         
            +
                mlp.zero_grad(set_to_none=True)
         
     | 
| 192 | 
         
            +
                x.grad = None
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                # Verify the gradients match.
         
     | 
| 195 | 
         
            +
                assert w1_grad.shape == expected_w1_grad.shape
         
     | 
| 196 | 
         
            +
                assert w2_grad.shape == expected_w2_grad.shape
         
     | 
| 197 | 
         
            +
                assert torch.allclose(w1_grad, expected_w1_grad)
         
     | 
| 198 | 
         
            +
                assert torch.allclose(w2_grad, expected_w2_grad)
         
     | 
| 199 | 
         
            +
                clear_load_balancing_loss()
         
     | 
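
These tests fold the auxiliary load-balancing loss into the scalar loss before `backward()` and clear the accumulated state between cases. For orientation only — this is a generic sketch of the standard Switch-Transformer-style auxiliary loss, not megablocks' exact implementation (`router_probs` and `expert_indices` are illustrative names):

```python
import torch

def load_balancing_loss(router_probs: torch.Tensor,
                        expert_indices: torch.Tensor,
                        num_experts: int) -> torch.Tensor:
    # f_e: fraction of tokens dispatched to expert e.
    one_hot = torch.nn.functional.one_hot(expert_indices, num_experts).float()
    tokens_frac = one_hot.mean(dim=0)
    # p_e: mean router probability assigned to expert e.
    prob_frac = router_probs.mean(dim=0)
    # Minimized (== 1) by a perfectly uniform assignment.
    return num_experts * torch.sum(tokens_frac * prob_frac)

# Example: 8 tokens, 4 experts, top-1 routing.
probs = torch.softmax(torch.randn(8, 4), dim=-1)
aux = load_balancing_loss(probs, probs.argmax(dim=-1), num_experts=4)
```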
    	
tests/ops/binned_gather_test.py ADDED
@@ -0,0 +1,71 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from megablocks import ops
+
+BINNED_GATHER_TESTS = (
+    (4, 2, 2, 1),
+    (4, 2, 2, 2),
+    (4, 2, 2, 4),
+    (1024, 1536, 4, 1),
+    (1024, 1536, 4, 2),
+    (1024, 1536, 4, 4),
+    (1024, 1536, 64, 1),
+    (1024, 1536, 64, 2),
+    (1024, 1536, 64, 4),
+    (1024, 1536, 128, 1),
+    (1024, 1536, 128, 2),
+    (1024, 1536, 128, 4),
+    (16384, 768, 4, 1),
+    (16384, 768, 4, 2),
+    (16384, 768, 4, 4),
+    (16384, 768, 64, 1),
+    (16384, 768, 64, 2),
+    (16384, 768, 64, 4),
+    (16384, 768, 128, 1),
+    (16384, 768, 128, 2),
+    (16384, 768, 128, 4),
+)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), BINNED_GATHER_TESTS)
+def test_binned_gather(sl: int, hs: int, ne: int, top_k: int):
+    # NOTE: Capacity factor == 1.
+    ec = (sl * top_k) // ne
+
+    # Create the data and indices.
+    x = torch.randn((sl, hs)).cuda().half()
+
+    # Randomly assign tokens to experts.
+    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+    _, indices = ops.sort(top_expert)
+    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)
+
+    def binned_gather(
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bins: torch.Tensor,
+        ec: int,
+        top_k: int,
+    ):
+        x = x.cpu().numpy()
+        indices = indices.cpu().numpy()
+        bins = bins.cpu().numpy()
+        start = 0
+        out = np.zeros((ne, ec, hs))
+        for i in range(ne):
+            end = bins[i]
+            for j in range(min(ec, end - start)):
+                index = indices[start + j] // top_k
+                out[i, j, :] = x[index, :]
+            start = end
+        return torch.from_numpy(out).cuda().half()
+
+    out = ops.binned_gather(x, indices, bins, ec, top_k)
+    expected_out = binned_gather(x, indices, bins, ec, top_k)
+    assert torch.all(torch.eq(out, expected_out))
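
All of the `ops` tests share the same bookkeeping: sort the token slots by expert, histogram them, and take an inclusive cumsum so that `bins[e]` marks the end of expert `e`'s contiguous segment in sorted order. A CPU-only sketch, with plain torch standing in for the fused kernels:

```python
import torch

ne, sl, top_k = 4, 8, 1
top_expert = torch.randint(0, ne, (sl * top_k,))

# Sort token slots by expert id; `indices` maps sorted positions
# back to the original (token * top_k) slots.
bin_ids, indices = torch.sort(top_expert)

# bins[e] is the exclusive end of expert e's segment in sorted order.
tokens_per_expert = torch.bincount(top_expert, minlength=ne)
bins = torch.cumsum(tokens_per_expert, dim=0)

# Expert e owns sorted positions [bins[e-1], bins[e]).
start = 0
for e in range(ne):
    end = int(bins[e])
    print(f"expert {e}: slots {indices[start:end].tolist()}")
    start = end
```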
    	
tests/ops/binned_scatter_test.py ADDED
@@ -0,0 +1,87 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from megablocks import ops
+
+_BINNED_SCATTER_TESTS = (
+    (4, 2, 2, 1),
+    (4, 2, 2, 2),
+    (4, 2, 2, 4),
+    (1024, 1536, 4, 1),
+    (1024, 1536, 4, 2),
+    (1024, 1536, 4, 4),
+    (1024, 1536, 64, 1),
+    (1024, 1536, 64, 2),
+    (1024, 1536, 64, 4),
+    (1024, 1536, 128, 1),
+    (1024, 1536, 128, 2),
+    (1024, 1536, 128, 4),
+    (16384, 768, 4, 1),
+    (16384, 768, 4, 2),
+    (16384, 768, 4, 4),
+    (16384, 768, 64, 1),
+    (16384, 768, 64, 2),
+    (16384, 768, 64, 4),
+    (16384, 768, 128, 1),
+    (16384, 768, 128, 2),
+    (16384, 768, 128, 4),
+)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), _BINNED_SCATTER_TESTS)
+def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int):
+    # NOTE: Capacity factor == 1.
+    ec = (sl * top_k) // ne
+
+    # Create the data and indices.
+    x = torch.randn((sl, hs)).cuda().half()
+
+    # Randomly assign tokens to experts.
+    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+    _, indices = ops.sort(top_expert)
+    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)
+
+    # Sample weights for the scatter reduce.
+    weights = torch.rand((sl * top_k,)).cuda().half()
+
+    x = ops.binned_gather(x, indices, bins, ec, top_k)
+
+    def binned_scatter(
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        weights: torch.Tensor,
+        bins: torch.Tensor,
+        top_k: int,
+    ):
+        x = x.cpu().numpy()
+        indices = indices.cpu().numpy()
+        weights = weights.cpu().numpy()
+        bins = bins.cpu().numpy()
+        start = 0
+        out = np.zeros((sl, hs))
+        for i in range(ne):
+            end = bins[i]
+            for j in range(min(ec, end - start)):
+                index = indices[start + j]
+                scale = weights[index]
+                index //= top_k
+
+                out[index, :] += scale * x[i, j, :]
+            start = end
+        return torch.from_numpy(out).cuda().half()
+
+    out = ops.binned_scatter(x, indices, weights, bins, top_k)
+    expected_out = binned_scatter(x, indices, weights, bins, top_k)
+
+    # NOTE: We need to check approximate equality because the
+    # scatter reduce uses atomics. assert_allclose raises on
+    # mismatch, so no wrapping assert is needed.
+    np.testing.assert_allclose(
+        out.cpu(),
+        expected_out.cpu(),
+        rtol=5e-3,
+    )
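
The tolerance here is deliberate: the scatter reduce accumulates with atomics, so the summation order (and thus the low-order half-precision bits) can differ between the kernel and the numpy reference. A self-contained sketch of the same style of check using `torch.testing.assert_close` (the noise magnitude and `atol` are illustrative):

```python
import torch

out = torch.randn(16, 8)
# Simulate the tiny reordering noise an atomic reduction can introduce.
expected_out = out + 1e-4 * torch.randn_like(out)

# Raises AssertionError on mismatch, like np.testing.assert_allclose.
torch.testing.assert_close(out, expected_out, rtol=5e-3, atol=1e-3)
```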
    	
tests/ops/cumsum_test.py ADDED
@@ -0,0 +1,44 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+
+from megablocks import ops
+
+CUMSUM_TESTS = (
+    (1, 32),
+    (2, 32),
+    (2, 1024),
+    (4, 1024),
+    (8, 1024),
+    (16, 1024),
+    (32, 1024),
+    (64, 1024),
+    (128, 1024),
+    (2, 16384),
+    (4, 16384),
+    (8, 16384),
+    (16, 16384),
+    (32, 16384),
+    (64, 16384),
+    (128, 16384),
+)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
+def test_exclusive_cumsum(n: int, m: int):
+    x = torch.randint(0, 2, (n, m)).long().cuda()
+    out = ops.exclusive_cumsum(x, 1) * x
+    expected_out = (torch.cumsum(x, dim=1) - 1) * x
+    assert torch.all(torch.eq(out, expected_out))
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
+def test_inclusive_cumsum(n: int, m: int):
+    x = torch.randint(0, 2, (n, m)).long().cuda()
+    out = ops.inclusive_cumsum(x, 1)
+    expected_out = torch.cumsum(x, dim=1)
+    assert torch.all(torch.eq(out, expected_out))
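
The exclusive test relies on the identity `exclusive[i] = inclusive[i] - x[i]`: with entries drawn from `{0, 1}`, multiplying both sides by `x` compares the cumsums only at positions where `x == 1`, where they differ by exactly one. A plain-torch sanity check of that identity:

```python
import torch

x = torch.randint(0, 2, (4, 16))
inclusive = torch.cumsum(x, dim=1)
exclusive = inclusive - x  # exclusive[i] = sum of x[:i]

# Masking with x reproduces the reference used in test_exclusive_cumsum.
assert torch.equal(exclusive * x, (inclusive - 1) * x)
```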
    	
tests/ops/histogram_test.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+
+from megablocks import ops
+
+_HISTOGRAM_TESTS = (
+    (1, 32, torch.int16, 128),
+    (1, 1024, torch.int16, 128),
+    (1, 16384, torch.int16, 128),
+    (1, 32, torch.int32, 128),
+    (1, 1024, torch.int32, 128),
+    (1, 16384, torch.int32, 128),
+    (1, 32, torch.int64, 128),
+    (1, 1024, torch.int64, 128),
+    (1, 16384, torch.int64, 128),
+    (1, 32, torch.int16, 1024),
+    (1, 1024, torch.int16, 1024),
+    (1, 16384, torch.int16, 1024),
+    (1, 32, torch.int32, 1024),
+    (1, 1024, torch.int32, 1024),
+    (1, 16384, torch.int32, 1024),
+    (1, 32, torch.int64, 1024),
+    (1, 1024, torch.int64, 1024),
+    (1, 16384, torch.int64, 1024),
+    (2, 32, torch.int16, 128),
+    (2, 1024, torch.int16, 128),
+    (2, 16384, torch.int16, 128),
+    (2, 32, torch.int32, 128),
+    (2, 1024, torch.int32, 128),
+    (2, 16384, torch.int32, 128),
+    (2, 32, torch.int64, 128),
+    (2, 1024, torch.int64, 128),
+    (2, 16384, torch.int64, 128),
+    (2, 32, torch.int16, 1024),
+    (2, 1024, torch.int16, 1024),
+    (2, 16384, torch.int16, 1024),
+    (2, 32, torch.int32, 1024),
+    (2, 1024, torch.int32, 1024),
+    (2, 16384, torch.int32, 1024),
+    (2, 32, torch.int64, 1024),
+    (2, 1024, torch.int64, 1024),
+    (2, 16384, torch.int64, 1024),
+    (8, 32, torch.int16, 128),
+    (8, 1024, torch.int16, 128),
+    (8, 16384, torch.int16, 128),
+    (8, 32, torch.int32, 128),
+    (8, 1024, torch.int32, 128),
+    (8, 16384, torch.int32, 128),
+    (8, 32, torch.int64, 128),
+    (8, 1024, torch.int64, 128),
+    (8, 16384, torch.int64, 128),
+    (8, 32, torch.int16, 1024),
+    (8, 1024, torch.int16, 1024),
+    (8, 16384, torch.int16, 1024),
+    (8, 32, torch.int32, 1024),
+    (8, 1024, torch.int32, 1024),
+    (8, 16384, torch.int32, 1024),
+    (8, 32, torch.int64, 1024),
+    (8, 1024, torch.int64, 1024),
+    (8, 16384, torch.int64, 1024),
+)
+
+
+# Override the seed_all fixture in autouse.py because
+# _histc_cuda does not have a deterministic implementation.
+@pytest.fixture()
+def seed_all():
+    torch.use_deterministic_algorithms(False)
+    return
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('m', 'n', 'dtype', 'max_val'), _HISTOGRAM_TESTS)
+def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int):
+    x = torch.randint(0, max_val, (m, n)).cuda().to(dtype)
+
+    out = ops.histogram(x, max_val)
+    expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)])
+    assert torch.all(torch.eq(out, expected_out))
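
The fixture override is needed because `torch.histc` has no deterministic CUDA implementation. For integer inputs in `[0, max_val)`, a per-row `torch.bincount` produces the same counts and makes a convenient CPU cross-check; a small sketch:

```python
import torch

m, n, max_val = 2, 32, 8
x = torch.randint(0, max_val, (m, n))

# One histogram per row, counting occurrences of each value in [0, max_val).
hist = torch.stack([torch.bincount(row.long(), minlength=max_val) for row in x])

# histc with bins == max_val over [0, max_val - 1] places each integer
# value in its own bin, matching the reference in test_histogram.
expected = torch.stack([
    torch.histc(row.float(), bins=max_val, min=0, max=max_val - 1) for row in x
])
assert torch.equal(hist.float(), expected)
```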
    	
tests/ops/padded_gather_test.py ADDED
@@ -0,0 +1,94 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from megablocks import ops
+
+PADDED_GATHER_TESTS = (
+    (4, 2, 2, 1),
+    (4, 2, 2, 2),
+    (1024, 1, 4, 1),
+    (1024, 1, 4, 2),
+    (1024, 1, 4, 4),
+    (1024, 1, 64, 1),
+    (1024, 1, 64, 2),
+    (1024, 1, 64, 4),
+    (1024, 1, 128, 1),
+    (1024, 1, 128, 2),
+    (1024, 1, 128, 4),
+    (1024, 1536, 4, 1),
+    (1024, 1536, 4, 2),
+    (1024, 1536, 4, 4),
+    (1024, 1536, 64, 1),
+    (1024, 1536, 64, 2),
+    (1024, 1536, 64, 4),
+    (1024, 1536, 128, 1),
+    (1024, 1536, 128, 2),
+    (1024, 1536, 128, 4),
+    (16384, 768, 4, 1),
+    (16384, 768, 4, 2),
+    (16384, 768, 4, 4),
+    (16384, 768, 64, 1),
+    (16384, 768, 64, 2),
+    (16384, 768, 64, 4),
+    (16384, 768, 128, 1),
+    (16384, 768, 128, 2),
+    (16384, 768, 128, 4),
+    (16384, 1, 4, 1),
+    (16384, 1, 4, 2),
+    (16384, 1, 4, 4),
+    (16384, 1, 64, 1),
+    (16384, 1, 64, 2),
+    (16384, 1, 64, 4),
+    (16384, 1, 128, 1),
+    (16384, 1, 128, 2),
+    (16384, 1, 128, 4),
+)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), PADDED_GATHER_TESTS)
+def testPaddedGather(sl: int, hs: int, ne: int, top_k: int):
+    # Create the data and indices.
+    x = torch.randn((sl, hs)).cuda().half()
+
+    # Randomly assign tokens to experts.
+    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+    bin_ids, indices = ops.sort(top_expert)
+    tokens_per_expert = ops.histogram(top_expert, ne)
+    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+    bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+    def padded_gather(
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,
+        bins: torch.Tensor,
+        padded_bins: torch.Tensor,
+        top_k: int,
+    ):
+        x = x.cpu().numpy()
+        indices = indices.cpu().numpy()
+        bin_ids = bin_ids.cpu().numpy()
+        bins = bins.cpu().numpy()
+        padded_bins = padded_bins.cpu().numpy()
+
+        out = np.zeros((padded_bins[-1], hs))
+        in_idx = 0
+        for i in range(len(bins)):
+            out_idx = 0 if i == 0 else padded_bins[i - 1]
+            end = bins[i]
+            while in_idx < end:
+                load_idx = indices[in_idx] // top_k
+                out[out_idx, :] = x[load_idx, :]
+                in_idx += 1
+                out_idx += 1
+        return torch.from_numpy(out).cuda().half()
+
+    out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+    expected_out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+    assert torch.all(torch.eq(out, expected_out))
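
The padded variants round each expert's token count up to a multiple of 128 so every expert's block of rows is tile-aligned for the block-sparse matmuls; `padded_bins` then indexes into the padded buffer while `bins` indexes the compact one. A plain-torch sketch of the bookkeeping, with `((t + align - 1) // align) * align` standing in for `ops.round_up`:

```python
import torch

ne, sl, top_k, align = 4, 1000, 2, 128
top_expert = torch.randint(0, ne, (sl * top_k,))

tokens_per_expert = torch.bincount(top_expert, minlength=ne)
# Round each expert's count up to a multiple of `align` (cf. ops.round_up).
padded = ((tokens_per_expert + align - 1) // align) * align

bins = torch.cumsum(tokens_per_expert, dim=0)  # segment ends, compact layout
padded_bins = torch.cumsum(padded, dim=0)      # segment ends, padded layout

# Expert i's rows start at padded_bins[i - 1] (0 for i == 0); only the first
# tokens_per_expert[i] rows are real, the remainder is zero padding.
print(tokens_per_expert.tolist(), padded.tolist(), padded_bins.tolist())
```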
    	
tests/ops/padded_scatter_test.py ADDED
@@ -0,0 +1,155 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from megablocks import ops
+
+PADDED_SCATTER_TESTS = [
+    (4, 2, 2, 2),
+    (4, 2, 2, 1),
+    (4, 2, 2, 1),
+    (4, 2, 2, 1),
+    (4, 2, 2, 2),
+    (4, 2, 2, 2),
+    (1024, 1, 4, 1),
+    (1024, 1, 4, 2),
+    (1024, 1, 4, 4),
+    (1024, 1, 4, 1),
+    (1024, 1, 4, 2),
+    (1024, 1, 4, 4),
+    (1024, 1, 4, 1),
+    (1024, 1, 4, 2),
+    (1024, 1, 4, 4),
+    (1024, 1, 64, 1),
+    (1024, 1, 64, 2),
+    (1024, 1, 64, 4),
+    (1024, 1, 128, 1),
+    (1024, 1, 128, 2),
+    (1024, 1, 128, 4),
+    (1024, 1536, 4, 1),
+    (1024, 1536, 4, 2),
+    (1024, 1536, 4, 4),
+    (1024, 1536, 4, 4),
+    (1024, 1536, 4, 4),
+    (1024, 1536, 64, 1),
+    (1024, 1536, 64, 2),
+    (1024, 1536, 64, 4),
+    (1024, 1536, 128, 1),
+    (1024, 1536, 128, 2),
+    (1024, 1536, 128, 4),
+    (1024, 1536, 128, 1),
+    (1024, 1536, 128, 1),
+    (16384, 768, 4, 1),
+    (16384, 768, 4, 2),
+    (16384, 768, 4, 4),
+    (16384, 768, 64, 1),
+    (16384, 768, 64, 2),
+    (16384, 768, 64, 4),
+    (16384, 768, 128, 1),
+    (16384, 768, 128, 2),
+    (16384, 768, 128, 4),
+    (16384, 1, 4, 1),
+    (16384, 1, 4, 2),
+    (16384, 1, 4, 4),
+    (16384, 1, 64, 1),
+    (16384, 1, 64, 2),
+    (16384, 1, 64, 4),
+    (16384, 1, 128, 1),
+    (16384, 1, 128, 2),
+    (16384, 1, 128, 4),
+    (16384, 1, 128, 2),
+    (16384, 1, 128, 2),
+]
+
+
+def _to_numpy(x: torch.Tensor) -> np.ndarray:
+    return x.detach().cpu().numpy()
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize((
+    'sl',
+    'hs',
+    'ne',
+    'top_k',
+), PADDED_SCATTER_TESTS)
+def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int):
+    # Create the data and indices.
+    x = torch.randn((sl, hs), requires_grad=True).cuda().half()
+
+    # Randomly assign tokens to experts.
+    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
+    bin_ids, indices = ops.sort(top_expert)
+    tokens_per_expert = ops.histogram(top_expert, ne)
+    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
+    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+    bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+
+    # Sample weights for the scatter reduce.
+    weights = torch.rand((sl * top_k,), requires_grad=True).cuda().half()
+
+    # Gather the data to prepare for backwards.
+    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
+
+    def padded_scatter(
+        x: torch.Tensor,
+        indices: torch.Tensor,
+        bin_ids: torch.Tensor,
+        weights: torch.Tensor,
+        bins: torch.Tensor,
+        padded_bins: torch.Tensor,
+        top_k: int,
+    ):
+        x = _to_numpy(x)
+        indices: np.ndarray = _to_numpy(indices)
+        bin_ids: np.ndarray = _to_numpy(bin_ids)
+        weights: np.ndarray = _to_numpy(weights)
+        bins: np.ndarray = _to_numpy(bins)
+        padded_bins: np.ndarray = _to_numpy(padded_bins)
+
+        out = np.zeros((indices.shape[0] // top_k, hs))
+        out_idx = 0
+        for i in range(len(bins)):
+            in_idx = 0 if i == 0 else padded_bins[i - 1]
+            end = bins[i]
+            while out_idx < end:
+                store_idx = indices[out_idx]
+                scale = weights[store_idx]
+                store_idx //= top_k
+
+                out[store_idx, :] += scale * x[in_idx, :]
+                out_idx += 1
+                in_idx += 1
+        return torch.from_numpy(out).cuda().half()
+
+    out = ops.padded_scatter(
+        x,
+        indices,
+        bin_ids,
+        weights,
+        bins,
+        padded_bins,
         
     | 
| 135 | 
         
            +
                    top_k,
         
     | 
| 136 | 
         
            +
                )
         
     | 
| 137 | 
         
            +
                expected_out = padded_scatter(
         
     | 
| 138 | 
         
            +
                    x,
         
     | 
| 139 | 
         
            +
                    indices,
         
     | 
| 140 | 
         
            +
                    bin_ids,
         
     | 
| 141 | 
         
            +
                    weights,
         
     | 
| 142 | 
         
            +
                    bins,
         
     | 
| 143 | 
         
            +
                    padded_bins,
         
     | 
| 144 | 
         
            +
                    top_k,
         
     | 
| 145 | 
         
            +
                )
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                out.backward(torch.randn_like(out))  # sanity check backward pass
         
     | 
| 148 | 
         
            +
             
     | 
| 149 | 
         
            +
                # NOTE: We need to check approximate equality because the scatter reduce uses atomics.
         
     | 
| 150 | 
         
            +
                # np.testing.assert_allclose returns `None` if no error and raises an AssertionError if an error exists
         
     | 
| 151 | 
         
            +
                assert np.testing.assert_allclose(
         
     | 
| 152 | 
         
            +
                    _to_numpy(out),
         
     | 
| 153 | 
         
            +
                    _to_numpy(expected_out),
         
     | 
| 154 | 
         
            +
                    rtol=5e-3,
         
     | 
| 155 | 
         
            +
                ) is None
         
     | 
    	
        tests/ops/replicate_test.py
    ADDED
    
    | 
         @@ -0,0 +1,108 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            # Copyright 2024 Databricks
         
     | 
| 2 | 
         
            +
            # SPDX-License-Identifier: Apache-2.0
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import numpy as np
         
     | 
| 5 | 
         
            +
            import pytest
         
     | 
| 6 | 
         
            +
            import torch
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            try:
         
     | 
| 9 | 
         
            +
                from megablocks._ops import ops as backend  # type: ignore
         
     | 
| 10 | 
         
            +
            except ModuleNotFoundError as e:
         
     | 
| 11 | 
         
            +
                raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            from megablocks import ops
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            def promote_scalar(x: torch.Tensor) -> torch.Tensor:
         
     | 
| 17 | 
         
            +
                return x.view(1) if not len(x.size()) else x
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            REPLICATE_TESTS = [
         
     | 
| 21 | 
         
            +
                (8, 1, 1),
         
     | 
| 22 | 
         
            +
                (8, 2, 1),
         
     | 
| 23 | 
         
            +
                (8, 4, 1),
         
     | 
| 24 | 
         
            +
                (8, 8, 1),
         
     | 
| 25 | 
         
            +
                (8, 2, 2),
         
     | 
| 26 | 
         
            +
                (8, 4, 2),
         
     | 
| 27 | 
         
            +
                (8, 8, 2),
         
     | 
| 28 | 
         
            +
                (8, 2, 4),
         
     | 
| 29 | 
         
            +
                (8, 4, 4),
         
     | 
| 30 | 
         
            +
                (8, 8, 4),
         
     | 
| 31 | 
         
            +
                (8, 2, 8),
         
     | 
| 32 | 
         
            +
                (8, 4, 8),
         
     | 
| 33 | 
         
            +
                (8, 8, 8),
         
     | 
| 34 | 
         
            +
                (16384, 2, 1),
         
     | 
| 35 | 
         
            +
                (16384, 4, 1),
         
     | 
| 36 | 
         
            +
                (16384, 8, 1),
         
     | 
| 37 | 
         
            +
                (16384, 16, 1),
         
     | 
| 38 | 
         
            +
                (16384, 32, 1),
         
     | 
| 39 | 
         
            +
                (16384, 64, 1),
         
     | 
| 40 | 
         
            +
                (16384, 128, 1),
         
     | 
| 41 | 
         
            +
                (16384, 2, 2),
         
     | 
| 42 | 
         
            +
                (16384, 4, 2),
         
     | 
| 43 | 
         
            +
                (16384, 8, 2),
         
     | 
| 44 | 
         
            +
                (16384, 16, 2),
         
     | 
| 45 | 
         
            +
                (16384, 32, 2),
         
     | 
| 46 | 
         
            +
                (16384, 64, 2),
         
     | 
| 47 | 
         
            +
                (16384, 128, 2),
         
     | 
| 48 | 
         
            +
                (16384, 2, 4),
         
     | 
| 49 | 
         
            +
                (16384, 4, 4),
         
     | 
| 50 | 
         
            +
                (16384, 8, 4),
         
     | 
| 51 | 
         
            +
                (16384, 16, 4),
         
     | 
| 52 | 
         
            +
                (16384, 32, 4),
         
     | 
| 53 | 
         
            +
                (16384, 64, 4),
         
     | 
| 54 | 
         
            +
                (16384, 128, 4),
         
     | 
| 55 | 
         
            +
                (16384, 2, 8),
         
     | 
| 56 | 
         
            +
                (16384, 4, 8),
         
     | 
| 57 | 
         
            +
                (16384, 8, 8),
         
     | 
| 58 | 
         
            +
                (16384, 16, 8),
         
     | 
| 59 | 
         
            +
                (16384, 32, 8),
         
     | 
| 60 | 
         
            +
                (16384, 64, 8),
         
     | 
| 61 | 
         
            +
                (16384, 128, 8),
         
     | 
| 62 | 
         
            +
            ]
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
            @pytest.mark.gpu
         
     | 
| 66 | 
         
            +
            @pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
         
     | 
| 67 | 
         
            +
            def test_replicate(tokens: int, num_centers: int, top_k: int):
         
     | 
| 68 | 
         
            +
                tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
         
     | 
| 69 | 
         
            +
                tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
         
     | 
| 70 | 
         
            +
                bins = ops.inclusive_cumsum(tokens_per_center, 0)
         
     | 
| 71 | 
         
            +
                bins = promote_scalar(bins)
         
     | 
| 72 | 
         
            +
                center_weights = torch.randn(top_k, num_centers).cuda().half()
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
                def replicate(x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
         
     | 
| 75 | 
         
            +
                    x = x.cpu().numpy()
         
     | 
| 76 | 
         
            +
                    bins = bins.cpu().numpy()
         
     | 
| 77 | 
         
            +
                    out = np.zeros((x.shape[0], num_outputs))
         
     | 
| 78 | 
         
            +
                    for batch_idx in range(x.shape[0]):
         
     | 
| 79 | 
         
            +
                        start = 0
         
     | 
| 80 | 
         
            +
                        for i, end in enumerate(bins):
         
     | 
| 81 | 
         
            +
                            value = x[batch_idx, i]
         
     | 
| 82 | 
         
            +
                            while start < end:
         
     | 
| 83 | 
         
            +
                                out[batch_idx, start] = value
         
     | 
| 84 | 
         
            +
                                start += 1
         
     | 
| 85 | 
         
            +
                    return torch.from_numpy(out).cuda().half()
         
     | 
| 86 | 
         
            +
             
     | 
| 87 | 
         
            +
                out = ops.replicate(center_weights, bins, tokens)
         
     | 
| 88 | 
         
            +
                expected_out = replicate(center_weights, bins, tokens)
         
     | 
| 89 | 
         
            +
                assert torch.all(torch.eq(out, expected_out))
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
            @pytest.mark.gpu
         
     | 
| 93 | 
         
            +
            @pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
         
     | 
| 94 | 
         
            +
            def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
         
     | 
| 95 | 
         
            +
                tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
         
     | 
| 96 | 
         
            +
                tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
         
     | 
| 97 | 
         
            +
                bins = ops.inclusive_cumsum(tokens_per_center, 0)
         
     | 
| 98 | 
         
            +
                bins = promote_scalar(bins)
         
     | 
| 99 | 
         
            +
                center_weights = torch.randn(top_k, num_centers).cuda().half()
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                grad = ops.replicate(center_weights, bins, tokens)
         
     | 
| 102 | 
         
            +
             
     | 
| 103 | 
         
            +
                out = torch.empty_like(center_weights)
         
     | 
| 104 | 
         
            +
                backend.replicate_backward(grad, bins, out)
         
     | 
| 105 | 
         
            +
                expected_out = center_weights * tokens_per_center.view([1, num_centers])
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                # NOTE: This floating-point reduction could be a problem for training stability and accuracy.
         
     | 
| 108 | 
         
            +
                assert torch.allclose(out, expected_out, rtol=1e-2)
         
     | 
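The expected value in `test_replicate_backward` follows from replicate and the segmented sum being adjoint operations: replicating a weight across its bin and then reducing that bin back multiplies the weight by the bin size. A plain-PyTorch illustration of that identity; the shapes and names here are illustrative, not the library's API:

```python
import torch

num_centers, top_k = 4, 2
tokens_per_center = torch.tensor([3, 1, 0, 4])
center_weights = torch.randn(top_k, num_centers)

# Forward: tile each center's weight across its bin of output slots.
replicated = torch.repeat_interleave(center_weights, tokens_per_center, dim=1)

# The backward of replicate is a segmented sum; feeding the forward output
# back through it yields weight * bin_size for every center.
segments = torch.repeat_interleave(torch.arange(num_centers), tokens_per_center)
reduced = torch.zeros_like(center_weights).index_add_(1, segments, replicated)
assert torch.allclose(reduced, center_weights * tokens_per_center)
```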
    	
tests/ops/sort_test.py ADDED

```diff
@@ -0,0 +1,65 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Dict, Optional, Union
+
+import numpy as np
+import pytest
+import torch
+
+from megablocks import ops
+
+SORT_TESTS = [
+    (32, torch.int16, None),
+    (1024, torch.int16, None),
+    (16384, torch.int16, None),
+    (32, torch.int32, None),
+    (1024, torch.int32, None),
+    (16384, torch.int32, None),
+    (32, torch.int64, None),
+    (1024, torch.int64, None),
+    (16384, torch.int64, None),
+    (32, torch.int16, 128),
+    (1024, torch.int16, 128),
+    (16384, torch.int16, 128),
+    (32, torch.int32, 128),
+    (1024, torch.int32, 128),
+    (16384, torch.int32, 128),
+    (32, torch.int64, 128),
+    (1024, torch.int64, 128),
+    (16384, torch.int64, 128),
+]
+
+
+def torch_to_numpy_dtype(dtype: torch.dtype,) -> Union[np.int16, np.int32, np.int64]:
+    types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = {
+        torch.int16: np.int16,
+        torch.int32: np.int32,
+        torch.int64: np.int64,
+    }
+    return types[dtype]
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(
+    ('n', 'dtype', 'max_val'),
+    SORT_TESTS,
+)
+def test_sort(n: int, dtype: torch.dtype, max_val: Optional[int]):
+    if max_val is None:
+        max_val = np.iinfo(torch_to_numpy_dtype(dtype)).max
+    end_bit = int(np.ceil(np.log2(max_val)))
+    x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+    out, indices = ops.sort(x, end_bit)
+    expected_out, expected_indices = torch.sort(x)
+    assert torch.all(torch.eq(out, expected_out))
+
+    # NOTE: The indices can be in different order depending
+    # on sort stability if multiple values in the array are
+    # equal.
+    data = torch.empty_like(x)
+    data.scatter_(0, indices.long(), out)
+    expected_data = torch.empty_like(x)
+    expected_data.scatter_(0, expected_indices, expected_out)
+    assert torch.all(torch.eq(data, expected_data))
```
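Because the underlying sort is not required to be stable, `test_sort` avoids comparing `indices` element-wise; instead it verifies that `(out, indices)` describe a valid permutation by scattering the sorted values back to their source positions, which must reproduce the input regardless of how ties were ordered. The same check in miniature, with `torch.sort` standing in for `ops.sort`:

```python
import torch

x = torch.tensor([3, 1, 3, 2])
out, indices = torch.sort(x)  # stand-in for ops.sort(x, end_bit)

# Scattering sorted values back through their source indices must
# reconstruct the input, however ties were broken.
restored = torch.empty_like(x).scatter_(0, indices, out)
assert torch.equal(restored, x)
```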
    	
tests/ops/topology_test.py ADDED

```diff
@@ -0,0 +1,81 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from megablocks import ops
+
+TOPOLOGY_TESTS = (
+    (1024, 1536, 2),
+    (1024, 1536, 4),
+    (1024, 1536, 8),
+    (1024, 1536, 16),
+    (1024, 1536, 32),
+    (1024, 1536, 64),
+    (1024, 1536, 128),
+    (1024, 1536, 256),
+    (1024, 1536, 512),
+    (16384, 768, 2),
+    (16384, 768, 4),
+    (16384, 768, 8),
+    (16384, 768, 16),
+    (16384, 768, 32),
+    (16384, 768, 64),
+    (16384, 768, 128),
+    (16384, 768, 256),
+    (16384, 768, 512),
+    (16384, 768, 1024),
+    (8, 14336, 8),
+)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
+def test_topology(sl: int, hs: int, ne: int):
+    # Create the data and indices.
+    blocking = 128
+    assert hs % blocking == 0
+
+    # Randomly assign tokens to experts.
+    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
+    tokens_per_expert = ops.histogram(top_expert, ne)
+    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
+    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
+
+    # Dimensions for the output indices.
+    output_block_rows = int(padded_bins[-1]) // blocking
+    output_block_columns = hs // blocking
+
+    def topology(
+        padded_bins: torch.Tensor,
+        blocking: torch.Tensor,
+        rows: int,
+        columns: int,
+    ):
+        padded_bins = padded_bins.cpu().numpy()
+
+        out = np.zeros([rows * columns])
+        start = 0
+        for i in range(padded_bins.shape[0]):
+            end = padded_bins[i] // blocking
+            while start < end:
+                for j in range(columns):
+                    out[start * columns + j] = j + i * columns
+                start += 1
+        return torch.from_numpy(out).cuda().short()
+
+    out = ops.topology(
+        padded_bins,
+        blocking,
+        output_block_rows,
+        output_block_columns,
+    )
+    expected_out = topology(
+        padded_bins,
+        blocking,
+        output_block_rows,
+        output_block_columns,
+    )
+    assert torch.all(torch.eq(out, expected_out))
```
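What the reference loop builds is the block-column index array of a block-sparse layout: each expert owns `padded_tokens / blocking` block rows, and every such row is dense across that expert's `hs / blocking` block columns. A tiny worked example with made-up sizes:

```python
import torch

blocking = 128
hs = 384                                   # 3 block columns per expert
padded_tokens = torch.tensor([256, 128])   # per-expert, already block-aligned

columns = hs // blocking
rows_per_expert = padded_tokens // blocking

# For expert i, every one of its block rows holds block columns
# [i * columns, (i + 1) * columns) of the stacked expert weights.
out = torch.cat([
    torch.arange(i * columns, (i + 1) * columns).repeat(int(r))
    for i, r in enumerate(rows_per_expert)
])
assert out.tolist() == [0, 1, 2, 0, 1, 2, 3, 4, 5]
```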
    	
tests/test_mb_moe.py CHANGED

```diff
@@ -1,6 +1,48 @@
+import torch
 import megablocks
 
 def test_import():
     """Simple test to check if the module can be imported."""
     print("megablocks_moe module imported successfully.")
     print("Available functions:", dir(megablocks))
+
+    expected_functions = [
+        "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP",
+        "SparseGLU", "SparseMLP", "argsort",
+        "backend", "cumsum", "dMoE", "exclusive_cumsum",
+        "get_load_balancing_loss", "grouped_gemm_util", "histogram",
+        "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward",
+        "replicate_forward", "sort", "torch"
+    ]
+
+    # Check if all expected functions are available
+    for func in expected_functions:
+        assert func in dir(megablocks), f"Missing function: {func}"
+
+# exclusive_cumsum
+def test_exclusive_cumsum():
+    """Test exclusive cumulative sum."""
+    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
+    out = torch.empty_like(x)
+    megablocks.exclusive_cumsum(x, 0, out)
+    expected = torch.tensor([0, 1, 3, 6], dtype=torch.float32).cuda()
+    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
+    print("cumsum output:", out)
+
+# inclusive_cumsum
+def test_inclusive_cumsum():
+    """Test inclusive cumulative sum."""
+    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
+    out = torch.empty_like(x)
+    megablocks.inclusive_cumsum(x, dim=0, out=out)
+    expected = torch.tensor([1, 3, 6, 10], dtype=torch.float32).cuda()
+    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
+
+# histogram
+def test_histogram():
+    """Test histogram operation."""
+    x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
+    num_bins = 3
+    hist = megablocks.histogram(x, num_bins)
+    expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda()
+    assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}"
```
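These two cumsum tests pin down the off-by-one relationship between the variants: the exclusive sum at position `i` equals the inclusive sum at position `i - 1`, with a leading 0. A CPU sketch of the expected values using `torch.cumsum`:

```python
import torch

x = torch.tensor([1, 2, 3, 4])
inclusive = torch.cumsum(x, dim=0)  # [1, 3, 6, 10]
exclusive = inclusive - x           # [0, 1, 3, 6]
assert torch.equal(inclusive, torch.tensor([1, 3, 6, 10]))
assert torch.equal(exclusive, torch.tensor([0, 1, 3, 6]))
```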
    	
torch-ext/megablocks/__init__.py CHANGED

```diff
@@ -24,7 +24,9 @@ def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
     Returns:
         The output tensor
     """
-
+    result = ops.exclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
 
 
 def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
@@ -39,7 +41,9 @@ def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
     Returns:
         The output tensor
     """
-
+    result = ops.inclusive_cumsum(x, dim)
+    out.copy_(result)
+    return out
 
 
 def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
```
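Both wrappers follow the same out-parameter pattern: compute the result functionally, then `Tensor.copy_` it into the caller's preallocated buffer so existing views of `out` observe the update, and return `out` for chaining. A generic sketch of the pattern, with `torch.cumsum` standing in for the kernel:

```python
import torch

def cumsum_into(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    result = torch.cumsum(x, dim)  # stand-in for ops.inclusive_cumsum(x, dim)
    out.copy_(result)              # in-place copy preserves out's storage
    return out

buf = torch.empty(4, dtype=torch.int64)
view = buf[:]  # aliases the same storage as buf
cumsum_into(torch.tensor([1, 2, 3, 4]), 0, buf)
assert view.tolist() == [1, 3, 6, 10]  # the view sees the written result
```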
    	
torch-ext/megablocks/ops/cumsum.py CHANGED

```diff
@@ -11,7 +11,7 @@ import torch
 # instructions for building the c++ operations.
 try:
     # import megablocks_ops as ops  # type: ignore
-
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
```
    	
torch-ext/megablocks/ops/histogram.py CHANGED

```diff
@@ -10,7 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
```
    	
torch-ext/megablocks/ops/replicate.py CHANGED

```diff
@@ -10,8 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-
-    import megablocks._ops as ops  # type: ignore
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
```
    	
torch-ext/megablocks/ops/sort.py CHANGED

```diff
@@ -10,8 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-
-    import megablocks._ops as ops  # type: ignore
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
```
    	
torch-ext/megablocks/ops/topology.py CHANGED

```diff
@@ -10,8 +10,7 @@ import torch
 # Wrap this in a try-block with better error message and
 # instructions for building the c++ operations.
 try:
-
-    import megablocks._ops as ops  # type: ignore
+    from megablocks._ops import ops  # type: ignore
 except ModuleNotFoundError as e:
     raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
 
```
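The recurring one-line change in these five files is not cosmetic: the compiled extension module `megablocks._ops` exposes the kernels on an `ops` attribute (the same attribute the tests import as `backend`), so binding the module itself with `import megablocks._ops as ops` would leave the kernels one attribute too deep for call sites written as `ops.sort(...)`. A toy illustration of the difference; the module contents here are hypothetical stand-ins:

```python
import types

# Toy stand-in for the compiled extension: a module whose `ops`
# attribute is the namespace that actually holds the kernels.
_ops = types.ModuleType("megablocks._ops")
_ops.ops = types.SimpleNamespace(sort=lambda x: sorted(x))

# `import megablocks._ops as ops` binds the module itself, so the
# kernels sit one attribute deeper: ops.ops.sort(...).
module_binding = _ops
assert module_binding.ops.sort([3, 1]) == [1, 3]

# `from megablocks._ops import ops` binds the inner namespace,
# which is what the call sites expect: ops.sort(...).
namespace_binding = _ops.ops
assert namespace_binding.sort([3, 1]) == [1, 3]
```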
    	
torch-ext/torch_binding.cpp CHANGED

```diff
@@ -34,22 +34,22 @@ torch::Tensor histogram_wrapper(torch::Tensor x, int64_t num_bins) {
 torch::Tensor indices_wrapper(torch::Tensor padded_bins,
                               int64_t block_size,
                               int64_t output_block_rows,
-                              int64_t output_block_columns
-
+                              int64_t output_block_columns,
+                              torch::Tensor out) {
   megablocks::indices(padded_bins, block_size, output_block_rows, output_block_columns, out);
   return out;
 }
 
 
 
-//
-//
-//
-//
-
-
-
-
+// Forward pass: replicate values from x according to bin sizes
+// void replicate_forward(torch::Tensor x,
+//   torch::Tensor bins,
+//   torch::Tensor out);
+torch::Tensor replicate_forward_wrapper(torch::Tensor x, torch::Tensor bins, torch::Tensor out) {
+  megablocks::replicate_forward(x, bins, out);
+  return out;
+}
 
 // // Backward pass: reduce gradients back to bins using segmented reduction
 // void replicate_backward(torch::Tensor grad,
@@ -90,11 +90,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("histogram(Tensor x, int num_bins) -> Tensor");
   ops.impl("histogram", torch::kCUDA, &histogram_wrapper);
 
-  ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns) -> Tensor");
+  ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns, Tensor(a!) out) -> Tensor(a!)");
   ops.impl("indices", torch::kCUDA, &indices_wrapper);
 
-
-
+  ops.def("replicate_forward(Tensor x, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
+  ops.impl("replicate_forward", torch::kCUDA, &replicate_forward_wrapper);
 
   ops.def("replicate_backward(Tensor grad, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
   ops.impl("replicate_backward", torch::kCUDA, &replicate_backward_wrapper);
```
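In the schema strings, `Tensor(a!)` declares that the argument is mutated in place and that the returned tensor aliases it, which is the same contract as the `out=` convention on built-in ops. A small sketch of what that contract means, shown with `torch.add` rather than the extension ops, since the registered namespace name comes from `TORCH_EXTENSION_NAME` at build time and is not visible in this diff:

```python
import torch

# Tensor(a!) promises what torch.add's out= argument already does:
# the op writes into the caller's buffer and returns a tensor that
# aliases the same storage.
x = torch.tensor([1.0, 2.0])
out = torch.empty(2)
result = torch.add(x, x, out=out)
assert result.data_ptr() == out.data_ptr()  # same storage, as (a!) declares
```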