drbh committed

Commit 63599de · 1 Parent(s): ef966b6

fix: fully vendor stk and fix imports
Files changed:
- flake.lock +7 -41
- flake.nix +15 -43
- torch-ext/megablocks/_layers/activation_fn.py +1 -1
- torch-ext/megablocks/_layers/dmoe.py +9 -2
- torch-ext/megablocks/_layers/gelu.py +10 -1
- torch-ext/megablocks/_layers/glu.py +11 -1
- torch-ext/megablocks/_layers/mlp.py +12 -3
- torch-ext/megablocks/ops/matmul_benchmark.py +13 -1
- torch-ext/megablocks/stk/__init__.py +7 -0
- torch-ext/megablocks/stk/backend/__init__.py +0 -0
- torch-ext/megablocks/stk/backend/autocast.py +37 -0
- torch-ext/megablocks/stk/backend/sputnik.py +316 -0
- torch-ext/megablocks/stk/backend/triton_kernels.py +393 -0
- torch-ext/megablocks/stk/matrix.py +329 -0
- torch-ext/megablocks/stk/ops/__init__.py +3 -0
- torch-ext/megablocks/stk/ops/eltwise_ops.py +28 -0
- torch-ext/megablocks/stk/ops/eltwise_ops_test.py +86 -0
- torch-ext/megablocks/stk/ops/linear_ops.py +59 -0
- torch-ext/megablocks/stk/ops/linear_ops_test.py +216 -0
- torch-ext/megablocks/stk/ops/matrix_ops.py +98 -0
- torch-ext/megablocks/stk/ops/matrix_ops_test.py +62 -0
- torch-ext/megablocks/stk/random/__init__.py +2 -0
- torch-ext/megablocks/stk/random/random_ops.py +36 -0
- torch-ext/megablocks/stk/random/random_ops_test.py +73 -0
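For orientation, a minimal sketch of what the commit means for consumers: the sparse-kernel helpers no longer come from a separately installed stanford-stk distribution but from the vendored copy inside megablocks. This is illustrative only and assumes the built `megablocks` package is importable; the authoritative changes are the per-file diffs below.

# Hedged sketch, assuming the built `megablocks` package from this repo is importable.
from megablocks.stk import Matrix   # vendored: torch-ext/megablocks/stk/matrix.py
from megablocks import stk          # also exposes stk.ops and stk.random

print(Matrix is stk.Matrix)         # True: same vendored class, no external stanford-stk needed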
    	
flake.lock CHANGED

@@ -1,21 +1,5 @@
 {
   "nodes": {
-    "composer": {
-      "flake": false,
-      "locked": {
-        "lastModified": 1749592532,
-        "narHash": "sha256-VKfSWtf+Z20nP1cHiBNwFzYCuMGL0xelvd6HMyDnIhc=",
-        "owner": "mosaicml",
-        "repo": "composer",
-        "rev": "0eec49da42e7f617329f035853800211f0a54ca3",
-        "type": "github"
-      },
-      "original": {
-        "owner": "mosaicml",
-        "repo": "composer",
-        "type": "github"
-      }
-    },
     "flake-compat": {
       "locked": {
         "lastModified": 1747046372,
@@ -89,10 +73,11 @@
       "nixpkgs": "nixpkgs"
     },
     "locked": {
-      "lastModified": …
+      "lastModified": 1750234878,
+      "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
       "owner": "huggingface",
       "repo": "hf-nix",
-      "rev": "…
+      "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
       "type": "github"
     },
     "original": {
@@ -113,16 +98,15 @@
       ]
     },
     "locked": {
-      "lastModified": …
-      "narHash": "sha256…
+      "lastModified": 1751014803,
+      "narHash": "sha256-9Xfq2k3uPfB602NwQF+zAY2GQZiKUN1G7Q6XiDCUR8Y=",
       "owner": "huggingface",
       "repo": "kernel-builder",
-      "rev": "…
+      "rev": "bbc4e712ff2046e217818e97de2201e2b996756e",
       "type": "github"
     },
     "original": {
       "owner": "huggingface",
-      "ref": "support-custom-python-libraries-in-dev-shell-nixland",
       "repo": "kernel-builder",
       "type": "github"
     }
@@ -145,25 +129,7 @@
     },
     "root": {
       "inputs": {
-        "…
-        "kernel-builder": "kernel-builder",
-        "stk": "stk"
-      }
-    },
-    "stk": {
-      "flake": false,
-      "locked": {
-        "lastModified": 1724272107,
-        "narHash": "sha256-f6eydO4u6jasvepP25a6jacSvoUNyfKW51FxahMtz1Q=",
-        "owner": "stanford-futuredata",
-        "repo": "stk",
-        "rev": "736313768ef697ce13a0594a41b2512a0fbc9884",
-        "type": "github"
-      },
-      "original": {
-        "owner": "stanford-futuredata",
-        "repo": "stk",
-        "type": "github"
+        "kernel-builder": "kernel-builder"
       }
     },
     "systems": {
    	
flake.nix CHANGED

@@ -1,52 +1,24 @@
 {
   description = "Flake for megablocks_moe kernel";
-
-  inputs = {
-    kernel-builder.url = "github:huggingface/kernel-builder/support-custom-python-libraries-in-dev-shell-nixland";
-    # Add libraries as inputs
-    composer = {
-      url = "github:mosaicml/composer";
-      flake = false;
-    };
-    stk = {
-      url = "github:stanford-futuredata/stk";
-      flake = false;
-    };
-    …
-    #   url = "github:tgale96/grouped_gemm";
-    #   flake = false;
-    # };
+
+  inputs = {
+    kernel-builder.url = "github:huggingface/kernel-builder";
   };
-
-  outputs = …
-    …
-    # grouped_gemm,
-  }:
+
+  outputs =
+    {
+      self,
+      kernel-builder,
+    }:
   kernel-builder.lib.genFlakeOutputs {
     path = ./.;
     rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
-
-    …
-    };
-
-    pythonTestDeps = [
-      "tqdm"
-      "py-cpuinfo"
-      "importlib-metadata"
-      "torchmetrics"
-      "composer"
-      "stk"
-      # "grouped_gemm"
-      # "yahp" # may be needed for some testing plugin
+
+    pythonCheckInputs = pkgs: with pkgs; [
+      tqdm
+      py-cpuinfo
+      importlib-metadata
+      torchmetrics
     ];
   };
-}
+}
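The dev-shell test dependencies move from the removed pythonTestDeps list to pythonCheckInputs, and composer and stk drop out of the flake entirely because stk is now vendored and composer is no longer an input. A small hedged check of the remaining deps; the Python module names are assumed mappings of the Nix package names (e.g. py-cpuinfo is assumed to expose the cpuinfo module).

# Hedged sketch: verify the test-time deps listed in pythonCheckInputs are importable.
import importlib

for module in ("tqdm", "cpuinfo", "importlib_metadata", "torchmetrics"):
    try:
        importlib.import_module(module)
        print(f"{module}: ok")
    except ImportError as err:
        print(f"{module}: missing ({err})")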
    	
torch-ext/megablocks/_layers/activation_fn.py CHANGED

@@ -4,7 +4,7 @@
 from typing import Any, Callable, Union
 
 import torch
-from stk import Matrix
+from ..stk import Matrix
 
 
 def act_fn(
    	
torch-ext/megablocks/_layers/dmoe.py CHANGED

@@ -2,15 +2,22 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import numpy as np
-import stk.ops
 import torch
-
+
+# try:
+#     import stk.ops
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
+#     )
 
 # import megablocks.ops as ops
 # # from megablocks.ops import ops
 # from megablocks.layers import common, dmlp_registry, moe, mpu
 # from megablocks.layers.arguments import Arguments
 
+from .. import stk
 from .. import ops
 from . import common, dmlp_registry, moe, mpu
 from .arguments import Arguments
    	
torch-ext/megablocks/_layers/gelu.py CHANGED

@@ -1,7 +1,16 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-…
+# try:
+#     import stk
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+#     )
+
+from .. import stk
+
 import torch
 import torch.nn.functional as F
 
    	
torch-ext/megablocks/_layers/glu.py CHANGED

@@ -1,7 +1,17 @@
 # Copyright 2024 Databricks
 # SPDX-License-Identifier: Apache-2.0
 
-import stk.ops
+# import stk.ops
+# try:
+#     import stk.ops
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
+#     )
+
+from .. import stk
+
 import torch
 
 # from megablocks import grouped_gemm_util as gg
    	
torch-ext/megablocks/_layers/mlp.py CHANGED

@@ -3,9 +3,18 @@
 
 from typing import Any
 
-…
-import stk
-import stk.…
+# try:
+#     import stk
+#     import stk.backend.triton_kernels
+#     import stk.ops
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
+#     )
+
+from .. import stk
+
 import torch
 from packaging import version
 
    	
torch-ext/megablocks/ops/matmul_benchmark.py CHANGED

@@ -3,7 +3,19 @@
 
 import unittest
 
-…
+
+# import stk
+
+# try:
+#     import stk
+# except ImportError:
+#     import warnings
+#     warnings.warn(
+#         'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
+#     )
+
+from .. import stk
+
 import torch
 from absl.testing import parameterized
 
    	
torch-ext/megablocks/stk/__init__.py ADDED

@@ -0,0 +1,7 @@
# import stk.random
# import stk.ops
# from stk.matrix import Matrix

from . import random
from . import ops
from .matrix import Matrix
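The vendored package keeps upstream stk's layout (matrix, ops, random, backend), so every former `import stk...` now resolves inside megablocks itself. A hedged sketch that walks the vendored module paths; it assumes the built `megablocks` package (and triton, which the backend imports) is importable.

# Hedged sketch: confirm the vendored layout resolves without an external stanford-stk install.
import importlib

for name in (
    "megablocks.stk",                         # top-level vendored package
    "megablocks.stk.ops",                     # eltwise / linear / matrix ops
    "megablocks.stk.random",                  # random block-sparse topology helpers
    "megablocks.stk.backend.triton_kernels",  # Triton kernels backing the ops
):
    importlib.import_module(name)             # raises ImportError if the layout is missing
    print("ok:", name)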
    	
torch-ext/megablocks/stk/backend/__init__.py ADDED

(empty file)
    	
torch-ext/megablocks/stk/backend/autocast.py ADDED

@@ -0,0 +1,37 @@
import functools
import torch


def _is_eligible(x):
    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)


def _cast(x, dtype):
    if isinstance(x, torch.Tensor) and _is_eligible(x):
        return x.to(dtype)
    elif isinstance(x, map):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
    elif isinstance(x, list) or isinstance(x, tuple):
        return type(x)(map(lambda y: _cast(y, dtype), x))
    return x


def custom_fwd(fwd):
    """Wrap a custom autograd function that always uses autocast dtype."""

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        if torch.is_autocast_enabled():
            with torch.autocast(device_type="cuda", enabled=False):
                dtype = torch.get_autocast_gpu_dtype()
                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
        return fwd(*args, **kwargs)
    return decorate_fwd


def custom_bwd(bwd):
    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with torch.autocast(device_type="cuda", enabled=False):
            return bwd(*args, **kwargs)
    return decorate_bwd
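These helpers wrap the forward/backward of a custom torch.autograd.Function so that, under autocast, inputs are cast to the active autocast dtype and autocast itself is disabled inside the op (sputnik.py below applies them to its sparse matmul Functions). A hedged usage sketch with a toy Function; the import path assumes the built package is importable as megablocks, and the toy op is not part of the commit.

# Hedged usage sketch for the vendored autocast helpers (toy function, illustrative only).
import torch
from megablocks.stk.backend.autocast import custom_fwd, custom_bwd  # assumed import path


class Scale(torch.autograd.Function):
    """Toy op: multiplies by 2, running in the autocast dtype on CUDA."""

    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        return 2.0 * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        return 2.0 * grad


if torch.cuda.is_available():
    x = torch.randn(4, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = Scale.apply(x)      # custom_fwd casts x to float16 and disables autocast inside
    print(x.dtype, y.dtype)     # torch.float32 torch.float16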
    	
torch-ext/megablocks/stk/backend/sputnik.py ADDED

@@ -0,0 +1,316 @@
import torch

from ..backend import triton_kernels as backend
from ..backend.autocast import custom_bwd, custom_fwd


def _standardize_shape(x, transpose):
    if transpose:
        return torch.Size((x[1], x[0]))
    return x


def _sparse_transpose(x):
    return (torch.Size((x[0][1], x[0][0])), ) + x[1:]


def _transpose_helper(x, transpose):
    if isinstance(x, torch.Tensor):
        return x.t() if transpose else x
    if transpose:
        x = _sparse_transpose(x)
    return x + (transpose,)


def _wrap(x):
    if isinstance(x, torch.Tensor):
        return (x,)
    return x


def _is_transposed(x):
    return (not x.is_contiguous() and
            x.stride()[0] == 1 and
            x.stride()[1] == x.size()[0])


def _call_helper(op, out, a, b, trans_a, trans_b):
    args = (_wrap(_transpose_helper(a, trans_a)) +
            _wrap(_transpose_helper(b, trans_b)))
    if isinstance(out, tuple):
        args = args + out
    return op(*args)


def _preprocess_inputs(lhs, rhs, dy):
    if isinstance(lhs, torch.Tensor) and _is_transposed(lhs):
        lhs = lhs.t()
    if isinstance(rhs, torch.Tensor) and _is_transposed(rhs):
        rhs = rhs.t()
    if (isinstance(dy, torch.Tensor) and
        not dy.is_contiguous() and
        not _is_transposed(dy)):
        dy = dy.contiguous()
    if isinstance(dy, tuple) and not dy[1].is_contiguous():
        dy = (dy[0], dy[1].contiguous()) + dy[2:]
    return lhs, rhs, dy


def _postprocess_outputs(x, transpose, grad):
    if isinstance(x, torch.Tensor) and transpose:
        return grad.t()
    return grad


def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
    lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)

    a, b = (rhs, dy) if trans_lhs else (dy, rhs)
    trans_a = trans_lhs and trans_rhs
    trans_b = trans_lhs or not trans_rhs
    out = _call_helper(op, lhs, a, b, trans_a, trans_b)
    return _postprocess_outputs(lhs, trans_lhs, out)


def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs):
    lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy)

    a, b = (dy, lhs) if trans_rhs else (lhs, dy)
    trans_a = not trans_lhs or trans_rhs
    trans_b = trans_lhs and trans_rhs
    out = _call_helper(op, rhs, a, b, trans_a, trans_b)
    return _postprocess_outputs(rhs, trans_rhs, out)


class DSD(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx,
                shape,
                data,
                offsets,
                row_indices,
                column_indices,
                offsets_t,
                column_indices_t,
                block_offsets_t,
                transpose_a,
                rhs):
        ctx.save_for_backward(data,
                              offsets,
                              row_indices,
                              column_indices,
                              offsets_t,
                              column_indices_t,
                              block_offsets_t,
                              rhs)
        ctx.shape = _standardize_shape(shape, transpose_a)
        ctx.transpose_a = transpose_a

        out = torch.empty(
            (shape[0], rhs.size()[1]),
            dtype=rhs.dtype,
            device=rhs.device)

        backend.dsd(shape,
                    data,
                    offsets,
                    row_indices,
                    column_indices,
                    offsets_t,
                    column_indices_t,
                    block_offsets_t,
                    transpose_a,
                    rhs,
                    out)
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, dy):
        saved_tensors = ctx.saved_tensors
        lhs = (ctx.shape,) + saved_tensors[:-1]
        rhs = saved_tensors[-1]
        trans_a = ctx.transpose_a
        trans_b = _is_transposed(rhs)

        ddata = None
        if ctx.needs_input_grad[1]:
            ddata = _lhs_gradient(sdd,
                                  lhs,
                                  rhs,
                                  dy,
                                  trans_a,
                                  trans_b)
        drhs = None
        if ctx.needs_input_grad[-1]:
            op = dds if trans_b else dsd
            drhs = _rhs_gradient(op,
                                 lhs,
                                 rhs,
                                 dy,
                                 trans_a,
                                 trans_b)
        return None, ddata, None, None, None, None, None, None, None, drhs


dsd = DSD.apply


class DDS(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx,
                lhs,
                shape,
                data,
                offsets,
                row_indices,
                column_indices,
                offsets_t,
                column_indices_t,
                block_offsets_t,
                transpose_b):
        ctx.save_for_backward(lhs,
                              data,
                              offsets,
                              row_indices,
                              column_indices,
                              offsets_t,
                              column_indices_t,
                              block_offsets_t)
        ctx.shape = _standardize_shape(shape, transpose_b)
        ctx.transpose_b = transpose_b
        out = torch.empty((lhs.size()[0], shape[1]),
                          dtype=lhs.dtype,
                          device=lhs.device)
        backend.dds(lhs,
                    shape,
                    data,
                    offsets,
                    row_indices,
                    column_indices,
                    offsets_t,
                    column_indices_t,
                    block_offsets_t,
                    transpose_b,
                    out)
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, dy):
        saved_tensors = ctx.saved_tensors
        lhs = saved_tensors[0]
        rhs = (ctx.shape,) + saved_tensors[1:]
        trans_a = _is_transposed(lhs)
        trans_b = ctx.transpose_b

        dlhs = None
        if ctx.needs_input_grad[0]:
            op = dsd if trans_a else dds
            dlhs = _lhs_gradient(op,
                                 lhs,
                                 rhs,
                                 dy,
                                 trans_a,
                                 trans_b)
        ddata = None
        if ctx.needs_input_grad[2]:
            ddata = _rhs_gradient(sdd,
                                  lhs,
                                  rhs,
                                  dy,
                                  trans_a,
                                  trans_b)
        return dlhs, None, ddata, None, None, None, None, None, None, None


dds = DDS.apply


class SDD(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx,
                lhs,
                rhs,
                shape,
                data,
                offsets,
                row_indices,
                column_indices,
                offsets_t,
                column_indices_t,
                block_offsets_t):
        ctx.save_for_backward(
            lhs,
            rhs,
            offsets,
            row_indices,
            column_indices,
            offsets_t,
            column_indices_t,
            block_offsets_t)
        ctx.shape = shape
        out = torch.empty(
            data.shape,
            dtype=lhs.dtype,
            device=lhs.device)
        backend.sdd(lhs,
                    rhs,
                    shape,
                    out,
                    offsets,
                    row_indices,
                    column_indices)
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, dy):
        saved_tensors = ctx.saved_tensors
        lhs, rhs = saved_tensors[:2]
        dy = (ctx.shape, dy) + saved_tensors[2:]
        trans_a = _is_transposed(lhs)
        trans_b = _is_transposed(rhs)

        dlhs = None
        if ctx.needs_input_grad[0]:
            op = dds if trans_a else dsd
            dlhs = _lhs_gradient(op,
                                 lhs,
                                 rhs,
                                 dy,
                                 trans_a,
                                 trans_b)
        drhs = None
        if ctx.needs_input_grad[1]:
            op = dsd if trans_b else dds
            drhs = _rhs_gradient(op,
                                 lhs,
                                 rhs,
                                 dy,
                                 trans_a,
                                 trans_b)
        return dlhs, drhs, None, None, None, None, None, None, None, None


sdd = SDD.apply

class RowIndices(torch.autograd.Function):

    @staticmethod
    def forward(ctx, shape, data, offsets, column_indices):
        out = torch.empty(
            column_indices.shape,
            dtype=column_indices.dtype,
            device=column_indices.device)
        backend.row_indices(shape, data, offsets, column_indices, out)
        return out


row_indices = RowIndices.apply
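The backward passes above choose between the dsd/dds/sdd kernels by checking whether each dense operand arrived as a transposed view, using the stride-based _is_transposed test. A small hedged sketch of that check; the import path assumes the built package (and triton, which this module imports) is available.

# Hedged sketch: exercises the transposition check used by the vendored backend.
import torch
from megablocks.stk.backend.sputnik import _is_transposed  # assumed import path

a = torch.randn(4, 8)
print(_is_transposed(a))      # False: contiguous row-major tensor
print(_is_transposed(a.t()))  # True: stride pattern of a transposed view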
    	
torch-ext/megablocks/stk/backend/triton_kernels.py ADDED

@@ -0,0 +1,393 @@
import torch
import triton
import triton.language as tl
from dataclasses import dataclass

@dataclass
class TritonConfig:
    BLOCK_M: int = 128
    BLOCK_N: int = 128
    BLOCK_K: int = 32
    BLOCK_SIZE: int = 128
    NUM_STAGES: int = 4
    NUM_WARPS: int = 4

def _validate_matmul_dims(M: int, K: int, N: int):
    error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}"
    assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M)
    assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K)
    assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N)

@triton.autotune(
    configs=[
        # basic configs for compute-bound matmuls
        triton.Config({
            'BLOCK_M': TritonConfig.BLOCK_M,
            'BLOCK_N': TritonConfig.BLOCK_N,
            'BLOCK_K': TritonConfig.BLOCK_K,
            'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
        }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _sdd_kernel(A, B, C, M, N, K,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            row_indices, column_indices,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
            BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
            ):
    # matrix multiplication
    pid = tl.program_id(0)
    pid_m = tl.load(row_indices + pid)
    pid_n = tl.load(column_indices + pid)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
…
         | 
| 51 | 
            +
                # pointers
         | 
| 52 | 
            +
                A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
         | 
| 53 | 
            +
                B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
         | 
| 54 | 
            +
                # do matrix multiplication
         | 
| 55 | 
            +
                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
         | 
| 56 | 
            +
                for k in range(0, tl.cdiv(K, BLOCK_K)):
         | 
| 57 | 
            +
                    a = tl.load(A)
         | 
| 58 | 
            +
                    b = tl.load(B)
         | 
| 59 | 
            +
                    acc += tl.dot(a, b)
         | 
| 60 | 
            +
                    A += BLOCK_K * stride_ak
         | 
| 61 | 
            +
                    B += BLOCK_K * stride_bk
         | 
| 62 | 
            +
    # Store to sparse matrix
         | 
| 63 | 
            +
                acc = acc.to(C.dtype.element_ty)
         | 
| 64 | 
            +
                BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
         | 
| 65 | 
            +
                cm = tl.arange(0, BLOCK_M)
         | 
| 66 | 
            +
                cn = tl.arange(0, BLOCK_N)
         | 
| 67 | 
            +
                C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
         | 
| 68 | 
            +
                tl.store(C, acc, mask=True)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            @triton.autotune(
         | 
| 71 | 
            +
                configs=[
         | 
| 72 | 
            +
                    # basic configs for compute-bound matmuls
         | 
| 73 | 
            +
                    triton.Config({
         | 
| 74 | 
            +
                        'BLOCK_M': TritonConfig.BLOCK_M,
         | 
| 75 | 
            +
                        'BLOCK_N': TritonConfig.BLOCK_N,
         | 
| 76 | 
            +
                        'BLOCK_K': TritonConfig.BLOCK_K,
         | 
| 77 | 
            +
                        'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
         | 
| 78 | 
            +
                    }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
         | 
| 79 | 
            +
                ],
         | 
| 80 | 
            +
                key=['M', 'N', 'K'],
         | 
| 81 | 
            +
            )
         | 
| 82 | 
            +
            @triton.jit
         | 
| 83 | 
            +
            def _dsd_kernel(A, B, C, M, N, K,
         | 
| 84 | 
            +
                        stride_am, stride_ak,
         | 
| 85 | 
            +
                        stride_bk, stride_bn,
         | 
| 86 | 
            +
                        stride_cm, stride_cn,
         | 
| 87 | 
            +
                        row_indices, column_indices, offsets,
         | 
| 88 | 
            +
                        block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
         | 
| 89 | 
            +
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
         | 
| 90 | 
            +
                        BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
         | 
| 91 | 
            +
                        ):
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                # matrix multiplication
         | 
| 94 | 
            +
                pid_m = tl.program_id(0)
         | 
| 95 | 
            +
                pid_n = tl.program_id(1)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                num_pid_m = tl.num_programs(0)
         | 
| 98 | 
            +
                num_pid_n = tl.num_programs(1)
         | 
| 99 | 
            +
                pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                start_inx = tl.load(offsets + pid_m)
         | 
| 102 | 
            +
                end_inx = tl.load(offsets + pid_m + 1)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                # pointers to sparse matrix
         | 
| 105 | 
            +
    rm = tl.arange(0, BLOCK_M)
         | 
| 106 | 
            +
                rak = tl.arange(0, BLOCK_K)
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                # pointers to dense matrix
         | 
| 111 | 
            +
                rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
         | 
| 112 | 
            +
                rbk = tl.arange(0, BLOCK_K)
         | 
| 113 | 
            +
                B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                # do matrix multiplication
         | 
| 116 | 
            +
                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
         | 
| 117 | 
            +
                nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
         | 
| 120 | 
            +
                ak_sub_incr = BLOCK_K * stride_ak
         | 
| 121 | 
            +
                bk_sub_incr = BLOCK_K * stride_bk
         | 
| 122 | 
            +
                bk_block_incr = BLOCK_SIZE * stride_bk
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                for k in range(nsub_blocks * (end_inx - start_inx)):
         | 
| 125 | 
            +
                    sub_block_inx = k % nsub_blocks
         | 
| 126 | 
            +
                    block_inx = k // nsub_blocks
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    if trans_A:
         | 
| 129 | 
            +
                        ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
         | 
| 130 | 
            +
                    else:
         | 
| 131 | 
            +
                        ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    a = tl.load(ptr_A)
         | 
| 136 | 
            +
                    b = tl.load(ptr_B)
         | 
| 137 | 
            +
                    acc += tl.dot(a, b)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                acc = acc.to(C.dtype.element_ty)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
         | 
| 142 | 
            +
                cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
         | 
| 145 | 
            +
                tl.store(C, acc, mask=True)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
            @triton.autotune(
         | 
| 148 | 
            +
                configs=[
         | 
| 149 | 
            +
                    # basic configs for compute-bound matmuls
         | 
| 150 | 
            +
                    triton.Config({
         | 
| 151 | 
            +
                        'BLOCK_M': TritonConfig.BLOCK_M,
         | 
| 152 | 
            +
                        'BLOCK_N': TritonConfig.BLOCK_N,
         | 
| 153 | 
            +
                        'BLOCK_K': TritonConfig.BLOCK_K,
         | 
| 154 | 
            +
                        'BLOCK_SIZE': TritonConfig.BLOCK_SIZE
         | 
| 155 | 
            +
                    }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS),
         | 
| 156 | 
            +
                ],
         | 
| 157 | 
            +
                key=['M', 'N', 'K'],
         | 
| 158 | 
            +
            )
         | 
| 159 | 
            +
            @triton.jit
         | 
| 160 | 
            +
            def _dds_kernel(A, B, C, M, N, K,
         | 
| 161 | 
            +
                        stride_am, stride_ak,
         | 
| 162 | 
            +
                        stride_bk, stride_bn,
         | 
| 163 | 
            +
                        stride_cm, stride_cn,
         | 
| 164 | 
            +
                        row_indices, column_indices, offsets,
         | 
| 165 | 
            +
                        block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr,
         | 
| 166 | 
            +
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
         | 
| 167 | 
            +
                        BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr,
         | 
| 168 | 
            +
                        ):
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                # matrix multiplication
         | 
| 171 | 
            +
                pid_m = tl.program_id(0)
         | 
| 172 | 
            +
                pid_n = tl.program_id(1)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                num_pid_m = tl.num_programs(0)
         | 
| 175 | 
            +
                num_pid_n = tl.num_programs(1)
         | 
| 176 | 
            +
                pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                start_inx = tl.load(offsets + pid_n)
         | 
| 179 | 
            +
                end_inx = tl.load(offsets + pid_n + 1)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                # pointers to dense matrix
         | 
| 182 | 
            +
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
         | 
| 183 | 
            +
                rak = tl.arange(0, BLOCK_K)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                A += (rm[:, None] * stride_am + rak[None, :] * stride_ak)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                # pointers to sparse matrix
         | 
| 188 | 
            +
                rn = tl.arange(0, BLOCK_N)
         | 
| 189 | 
            +
                rbk = tl.arange(0, BLOCK_K)
         | 
| 190 | 
            +
                B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn)
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                # do matrix multiplication
         | 
| 193 | 
            +
                acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
         | 
| 194 | 
            +
                nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                ak_sub_incr = BLOCK_K * stride_ak
         | 
| 199 | 
            +
                ak_block_incr = BLOCK_SIZE * stride_ak
         | 
| 200 | 
            +
                bk_sub_incr = BLOCK_K * stride_bk
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                for k in range(nsub_blocks * (end_inx - start_inx)):
         | 
| 203 | 
            +
                    sub_block_inx = k % nsub_blocks
         | 
| 204 | 
            +
                    block_inx = k // nsub_blocks
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                    if trans_B:
         | 
| 207 | 
            +
                        ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
         | 
| 208 | 
            +
                    else:
         | 
| 209 | 
            +
                        ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr
         | 
| 212 | 
            +
                    a = tl.load(ptr_A)
         | 
| 213 | 
            +
                    b = tl.load(ptr_B)
         | 
| 214 | 
            +
                    acc += tl.dot(a, b)
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                acc = acc.to(C.dtype.element_ty)
         | 
| 217 | 
            +
                cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
         | 
| 218 | 
            +
                cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
         | 
| 219 | 
            +
                C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn)
         | 
| 220 | 
            +
                tl.store(C, acc, mask=True)
         | 
| 221 | 
            +
             | 
| 222 | 
            +
            def dsd(shape,
         | 
| 223 | 
            +
                    data,
         | 
| 224 | 
            +
                    offsets,
         | 
| 225 | 
            +
                    row_indices,
         | 
| 226 | 
            +
                    column_indices,
         | 
| 227 | 
            +
                    offsets_t,
         | 
| 228 | 
            +
                    column_indices_t,
         | 
| 229 | 
            +
                    block_offsets_t,
         | 
| 230 | 
            +
                    transpose_a,
         | 
| 231 | 
            +
                    rhs,
         | 
| 232 | 
            +
                    out
         | 
| 233 | 
            +
                ):
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                device = rhs.device
         | 
| 236 | 
            +
                trans_A = transpose_a
         | 
| 237 | 
            +
                trans_B = False
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                if rhs.stride(0) > 1 and rhs.stride(1) > 1:
         | 
| 240 | 
            +
                    trans_B = True
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                # checks constraints
         | 
| 243 | 
            +
                assert shape[1] == rhs.shape[0], "incompatible dimensions"
         | 
| 244 | 
            +
                M, K = shape
         | 
| 245 | 
            +
                _, N = rhs.shape
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                _validate_matmul_dims(M, K, N)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                # accumulator types
         | 
| 250 | 
            +
                ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                stride_am, stride_ak = data.stride(1), data.stride(2)
         | 
| 253 | 
            +
                stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
         | 
| 254 | 
            +
    a_column_indices = column_indices
         | 
| 255 | 
            +
                a_offsets = offsets
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                # launch kernel
         | 
| 258 | 
            +
                grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                if trans_A:
         | 
| 261 | 
            +
                    stride_am, stride_ak = data.stride(2), data.stride(1)
         | 
| 262 | 
            +
                    a_column_indices, a_offsets = column_indices_t, offsets_t
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                if trans_B:
         | 
| 265 | 
            +
                    stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                _dsd_kernel[grid](
         | 
| 268 | 
            +
                    data.data, rhs, out, M, N, K,
         | 
| 269 | 
            +
                    stride_am, stride_ak,
         | 
| 270 | 
            +
                    stride_bk, stride_bn,
         | 
| 271 | 
            +
                    out.stride(0), out.stride(1),
         | 
| 272 | 
            +
                    row_indices, a_column_indices, a_offsets,
         | 
| 273 | 
            +
                    block_offsets_t, trans_A, trans_B,
         | 
| 274 | 
            +
                    GROUP_M=128, ACC_TYPE=ACC_TYPE
         | 
| 275 | 
            +
                )
         | 
| 276 | 
            +
                # return out
         | 
| 277 | 
            +
             | 
| 278 | 
            +
            def dds(lhs,
         | 
| 279 | 
            +
                    shape,
         | 
| 280 | 
            +
                    data,
         | 
| 281 | 
            +
                    offsets,
         | 
| 282 | 
            +
                    row_indices,
         | 
| 283 | 
            +
                    column_indices,
         | 
| 284 | 
            +
                    offsets_t,
         | 
| 285 | 
            +
                    column_indices_t,
         | 
| 286 | 
            +
                    block_offsets_t,
         | 
| 287 | 
            +
                    transpose_b,
         | 
| 288 | 
            +
                    out
         | 
| 289 | 
            +
                ):
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                device = lhs.device
         | 
| 292 | 
            +
                trans_B = transpose_b
         | 
| 293 | 
            +
                trans_A = False
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                if lhs.stride(0) > 1 and lhs.stride(1) > 1:
         | 
| 296 | 
            +
                    trans_A = True
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                # checks constraints
         | 
| 299 | 
            +
                assert lhs.shape[1] == shape[0], "incompatible dimensions"
         | 
| 300 | 
            +
                M, K = lhs.shape
         | 
| 301 | 
            +
                _, N = shape
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                _validate_matmul_dims(M, K, N)
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                # accumulator types
         | 
| 306 | 
            +
                ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
         | 
| 309 | 
            +
                stride_bk, stride_bn = data.stride(1), data.stride(2)
         | 
| 310 | 
            +
    b_column_indices = column_indices_t
         | 
| 311 | 
            +
                b_offsets = offsets_t
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                # launch kernel
         | 
| 314 | 
            +
                grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N']))
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                if trans_A:
         | 
| 317 | 
            +
                    stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
         | 
| 318 | 
            +
                if trans_B:
         | 
| 319 | 
            +
                    stride_bk, stride_bn = data.stride(2), data.stride(1)
         | 
| 320 | 
            +
                    b_column_indices, b_offsets = column_indices, offsets
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                _dds_kernel[grid](
         | 
| 323 | 
            +
                    lhs, data, out, M, N, K,
         | 
| 324 | 
            +
                    stride_am, stride_ak,
         | 
| 325 | 
            +
                    stride_bk, stride_bn,
         | 
| 326 | 
            +
                    out.stride(0), out.stride(1),
         | 
| 327 | 
            +
                    row_indices, b_column_indices, b_offsets,
         | 
| 328 | 
            +
                    block_offsets_t, trans_A, trans_B,
         | 
| 329 | 
            +
                    GROUP_M=128, ACC_TYPE=ACC_TYPE
         | 
| 330 | 
            +
                )
         | 
| 331 | 
            +
             | 
| 332 | 
            +
            def sdd(lhs,
         | 
| 333 | 
            +
                    rhs,
         | 
| 334 | 
            +
                    shape,
         | 
| 335 | 
            +
                    out,
         | 
| 336 | 
            +
                    offsets,
         | 
| 337 | 
            +
                    row_indices,
         | 
| 338 | 
            +
                    column_indices
         | 
| 339 | 
            +
                ):
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                device = out.device
         | 
| 342 | 
            +
                trans_A = False
         | 
| 343 | 
            +
                trans_B = False
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                if lhs.stride(0) > 1 and lhs.stride(1) > 1:
         | 
| 346 | 
            +
                    trans_A = True
         | 
| 347 | 
            +
                if rhs.stride(0) > 1 and rhs.stride(1) > 1:
         | 
| 348 | 
            +
                    trans_B = True
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                # checks constraints
         | 
| 351 | 
            +
                assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions"
         | 
| 352 | 
            +
                M, K = lhs.shape
         | 
| 353 | 
            +
                _, N = rhs.shape
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                _validate_matmul_dims(M, K, N)
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                # accumulator types
         | 
| 358 | 
            +
                ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                # launch kernel
         | 
| 361 | 
            +
                nnz_blocks = len(row_indices)
         | 
| 362 | 
            +
                grid = lambda META: (nnz_blocks,)
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                stride_am, stride_ak = lhs.stride(0), lhs.stride(1)
         | 
| 365 | 
            +
                stride_bk, stride_bn = rhs.stride(0), rhs.stride(1)
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                if trans_A:
         | 
| 368 | 
            +
                    stride_am, stride_ak = lhs.stride(1), lhs.stride(0)
         | 
| 369 | 
            +
                if trans_B:
         | 
| 370 | 
            +
                    stride_bk, stride_bn = rhs.stride(1), rhs.stride(0)
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                _sdd_kernel[grid](
         | 
| 373 | 
            +
                    lhs, rhs, out, M, N, K,
         | 
| 374 | 
            +
                    stride_am, stride_ak,
         | 
| 375 | 
            +
                    stride_bk, stride_bn,
         | 
| 376 | 
            +
                    out.stride(1), out.stride(2),
         | 
| 377 | 
            +
                    row_indices, column_indices,
         | 
| 378 | 
            +
                    GROUP_M=128, ACC_TYPE=ACC_TYPE
         | 
| 379 | 
            +
                    )
         | 
| 380 | 
            +
             | 
| 381 | 
            +
            @triton.jit
         | 
| 382 | 
            +
            def _row_indices_kernel(offsets, out):
         | 
| 383 | 
            +
                pid = tl.program_id(0)
         | 
| 384 | 
            +
                row_offset = tl.load(offsets + pid)
         | 
| 385 | 
            +
                nnz_blocks = tl.load(offsets + pid + 1) - row_offset
         | 
| 386 | 
            +
                for nnz_block in range(nnz_blocks):
         | 
| 387 | 
            +
                    tl.store(out + row_offset + nnz_block, pid)
         | 
| 388 | 
            +
             | 
| 389 | 
            +
            def row_indices(
         | 
| 390 | 
            +
                shape, data, offsets, column_indices, out
         | 
| 391 | 
            +
            ):
         | 
| 392 | 
            +
                block_rows = len(offsets) - 1
         | 
| 393 | 
            +
                _row_indices_kernel[(block_rows, )](offsets, out)
         | 
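
Note: the kernels above implement blocked SDD (dense x dense -> sparse), DSD (sparse x dense -> dense) and DDS (dense x sparse -> dense) products over BCSR metadata, and `row_indices` expands the per-block-row offsets into one block-row index per nonzero block. A minimal CPU sketch of that expansion, for reference only (the function name is illustrative and not part of this commit):

import torch

def row_indices_reference(offsets: torch.Tensor) -> torch.Tensor:
    # Number of nonzero blocks in each block row (CSR-style offsets).
    counts = offsets[1:] - offsets[:-1]
    # Repeat each block-row index once per nonzero block in that row.
    return torch.repeat_interleave(
        torch.arange(len(counts), dtype=torch.int32), counts.long())

# offsets [0, 2, 3, 5] -> row indices [0, 0, 1, 2, 2]
print(row_indices_reference(torch.tensor([0, 2, 3, 5], dtype=torch.int32)))
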
    	
        torch-ext/megablocks/stk/matrix.py
    ADDED
    
    | @@ -0,0 +1,329 @@ | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            # 1. Add heavyweight (data) validation helper.
         | 
| 5 | 
            +
            # 2. Add construction helpers
         | 
| 6 | 
            +
            # 3. Make indentation consistent
         | 
| 7 | 
            +
            # 4. Replace asserts with descriptive errors.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            ##
         | 
| 10 | 
            +
            ### Validation helpers.
         | 
| 11 | 
            +
            ##
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def _validate_matrix(shape, data, row_indices, column_indices, offsets):
         | 
| 15 | 
            +
                # Data should be [nnz, block_size, block_size]
         | 
| 16 | 
            +
                if data.dim() == 1:
         | 
| 17 | 
            +
                    data = torch.reshape(data, [data.numel(), 1, 1])
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                # Blocks should be square.
         | 
| 20 | 
            +
                if data.shape[-2] != data.shape[-1]:
         | 
| 21 | 
            +
                    raise ValueError(
         | 
| 22 | 
            +
                        "Expected square blocking in data. "
         | 
| 23 | 
            +
                        f"Got block shape {[data.shape[-2], data.shape[-1]]}")
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                # Flatten batch dimensions on data - original shape preserved
         | 
| 26 | 
            +
                # in shape argument.
         | 
| 27 | 
            +
                block_size = data.shape[-1]
         | 
| 28 | 
            +
                data = data.view([-1, block_size, block_size])
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                if data.dim() != 3:
         | 
| 31 | 
            +
                    raise ValueError(
         | 
| 32 | 
            +
                        "Expected 3D shape for data (nnz, block, block). "
         | 
| 33 | 
            +
                        f"Got shape {data.dim()}D shape.")
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                block_size = data.shape[1]
         | 
| 36 | 
            +
                if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
         | 
| 37 | 
            +
                    raise ValueError(
         | 
| 38 | 
            +
                        "Matrix shape must be dividible by blocking. "
         | 
| 39 | 
            +
                        f"Got shape {shape} with "
         | 
| 40 | 
            +
                        f"{[block_size, block_size]} blocking.")
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                if np.prod(shape) < data.numel():
         | 
| 43 | 
            +
                    raise ValueError(
         | 
| 44 | 
            +
                        "Invalid matrix. Number of nonzeros exceeds matrix capacity "
         | 
| 45 | 
            +
                        f"({data.numel()} v. {np.prod(shape)})")
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                if row_indices.dim() != 1:
         | 
| 48 | 
            +
                    raise ValueError(
         | 
| 49 | 
            +
                        f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                if column_indices.dim() != 1:
         | 
| 52 | 
            +
                    raise ValueError(
         | 
| 53 | 
            +
                        f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                if offsets.dim() != 1:
         | 
| 56 | 
            +
                    raise ValueError(
         | 
| 57 | 
            +
                        f"Expected 1D offsets. Got {offsets.dim()}D offsets.")
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                if row_indices.numel() != data.shape[0]:
         | 
| 60 | 
            +
                    raise ValueError(
         | 
| 61 | 
            +
                        "Expected 1 index per nonzero block. "
         | 
| 62 | 
            +
                        f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                if column_indices.numel() != data.shape[0]:
         | 
| 65 | 
            +
                    raise ValueError(
         | 
| 66 | 
            +
                        "Expected 1 index per nonzero block. "
         | 
| 67 | 
            +
                        f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                block_rows = np.prod(shape[:-1]) / block_size
         | 
| 70 | 
            +
                if offsets.numel() != block_rows + 1:
         | 
| 71 | 
            +
                    raise ValueError(
         | 
| 72 | 
            +
                        "Expected one offset per block row plus one. "
         | 
| 73 | 
            +
                        f"Got {offsets.numel()} offsets with {block_rows} block rows.")
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                is_cuda = (data.is_cuda and
         | 
| 76 | 
            +
                           row_indices.is_cuda and
         | 
| 77 | 
            +
                           column_indices.is_cuda and
         | 
| 78 | 
            +
                           offsets.is_cuda)
         | 
| 79 | 
            +
                is_cpu = (not data.is_cuda and
         | 
| 80 | 
            +
                          not row_indices.is_cuda and
         | 
| 81 | 
            +
                          not column_indices.is_cuda and
         | 
| 82 | 
            +
                          not offsets.is_cuda)
         | 
| 83 | 
            +
                if not (is_cuda or is_cpu):
         | 
| 84 | 
            +
                    raise ValueError(
         | 
| 85 | 
            +
                        "Expected data & meta-data on common device. "
         | 
| 86 | 
            +
                        f"Got data on {data.device}, row_indices on {row_indices.device} "
         | 
| 87 | 
            +
                        f"column_indices on {column_indices.device} and "
         | 
| 88 | 
            +
                        f"offsets on {offsets.device}.")
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                if data.dtype != torch.float16:
         | 
| 91 | 
            +
                    raise ValueError(
         | 
| 92 | 
            +
                        f"Expected float16 data. Got {data.dtype} data.")
         | 
| 93 | 
            +
                if row_indices.dtype != torch.int16:
         | 
| 94 | 
            +
                    raise ValueError(
         | 
| 95 | 
            +
                        f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
         | 
| 96 | 
            +
                if column_indices.dtype != torch.int16:
         | 
| 97 | 
            +
                    raise ValueError(
         | 
| 98 | 
            +
                        f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
         | 
| 99 | 
            +
                if offsets.dtype != torch.int32:
         | 
| 100 | 
            +
                    raise ValueError(
         | 
| 101 | 
            +
                        f"Expected int32 offsets. Got {offsets.dtype} offsets.")
         | 
| 102 | 
            +
                return data
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            def _transpose(size, data, row_indices, column_indices, offsets):
         | 
| 106 | 
            +
                block_columns = size[1] // data.shape[1]
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                # Sort row indices by column indices to get the transposed matrix's
         | 
| 109 | 
            +
                # column indices.
         | 
| 110 | 
            +
                gather_indices = column_indices.argsort()
         | 
| 111 | 
            +
                column_indices_t = row_indices.gather(0, gather_indices)
         | 
| 112 | 
            +
                block_offsets_t = gather_indices.int()
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                # NOTE: Histogram is not implemented for any integer type on CPU. Do
         | 
| 115 | 
            +
                # the histogram in 32-bit float, which can exactly represent 16-bit
         | 
| 116 | 
            +
                # integers.
         | 
| 117 | 
            +
                column_indices_float = column_indices.float()
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
         | 
| 120 | 
            +
                nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
         | 
| 121 | 
            +
                nnz_per_column = nnz_per_column.int()
         | 
| 122 | 
            +
                offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
         | 
| 123 | 
            +
                return column_indices_t, offsets_t, block_offsets_t
         | 
| 124 | 
            +
             | 
| 125 | 
            +
             | 
| 126 | 
            +
            class Matrix(torch.nn.Module):
         | 
| 127 | 
            +
                """A matrix stored in sparse format.
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                Underlying format is block compressed sparse row (BCSR).
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                TODO(tgale): Make this mirror torch.Tensor API as much as possible.
         | 
| 132 | 
            +
                """
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                def __init__(self,
         | 
| 135 | 
            +
                             size,
         | 
| 136 | 
            +
                             data,
         | 
| 137 | 
            +
                             row_indices,
         | 
| 138 | 
            +
                             column_indices,
         | 
| 139 | 
            +
                             offsets,
         | 
| 140 | 
            +
                             column_indices_t=None,
         | 
| 141 | 
            +
                             offsets_t=None,
         | 
| 142 | 
            +
                             block_offsets_t=None):
         | 
| 143 | 
            +
                    super().__init__()
         | 
| 144 | 
            +
                    self._size = size
         | 
| 145 | 
            +
                    self._data = data
         | 
| 146 | 
            +
                    self._row_indices = row_indices
         | 
| 147 | 
            +
                    self._column_indices = column_indices
         | 
| 148 | 
            +
                    self._offsets = offsets
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    # Produce the transpose meta-data if it is not passed in.
         | 
| 151 | 
            +
                    if ((column_indices_t is None) or (offsets_t is None) or
         | 
| 152 | 
            +
                        (block_offsets_t is None)):
         | 
| 153 | 
            +
                        column_indices_t, offsets_t, block_offsets_t = _transpose(
         | 
| 154 | 
            +
                            size, data, row_indices, column_indices, offsets)
         | 
| 155 | 
            +
                    self._column_indices_t = column_indices_t
         | 
| 156 | 
            +
                    self._offsets_t = offsets_t
         | 
| 157 | 
            +
                    self._block_offsets_t = block_offsets_t
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    self._transposed = False
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    # Validate that our metadata will not overflow.
         | 
| 162 | 
            +
                    max_dim = np.iinfo(np.int16).max * self.blocking
         | 
| 163 | 
            +
                    if column_indices.dtype == torch.int16:
         | 
| 164 | 
            +
                        if size[0] > max_dim or size[1] > max_dim:
         | 
| 165 | 
            +
                            raise ValueError(
         | 
| 166 | 
            +
                                "Sparse matrix with shape {size} exceeds representable "
         | 
| 167 | 
            +
                                "size with 16-bit indices.")
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                def validate(self):
         | 
| 170 | 
            +
                    _validate_matrix(self._size,
         | 
| 171 | 
            +
                                     self._data,
         | 
| 172 | 
            +
                                     self._row_indices,
         | 
| 173 | 
            +
                                     self._column_indices,
         | 
| 174 | 
            +
                                     self._offsets)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    # TODO(tgale): Add heavyweight data validation.
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                def to(self, device):
         | 
| 179 | 
            +
                    # TODO(tgale): Handle type conversions here. We
         | 
| 180 | 
            +
                    # need to set the appropriate meta-data type for
         | 
| 181 | 
            +
                    # the given floating-point type.
         | 
| 182 | 
            +
                    self._data = self._data.to(device)
         | 
| 183 | 
            +
                    self._row_indices = self._row_indices.to(device)
         | 
| 184 | 
            +
                    self._column_indices = self._column_indices.to(device)
         | 
| 185 | 
            +
                    self._offsets = self._offsets.to(device)
         | 
| 186 | 
            +
                    self._column_indices_t = self._column_indices_t.to(device)
         | 
| 187 | 
            +
                    self._offsets_t = self._offsets_t.to(device)
         | 
| 188 | 
            +
                    self._block_offsets_t = self._block_offsets_t.to(device)
         | 
| 189 | 
            +
                    return self
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                def cuda(self):
         | 
| 192 | 
            +
                    return self.to(torch.cuda.current_device())
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                def clone(self):
         | 
| 195 | 
            +
                    return Matrix(
         | 
| 196 | 
            +
                        self.size(),
         | 
| 197 | 
            +
                        self.data.clone(),
         | 
| 198 | 
            +
                        self.row_indices.clone(),
         | 
| 199 | 
            +
                        self.column_indices.clone(),
         | 
| 200 | 
            +
                        self.offsets.clone(),
         | 
| 201 | 
            +
                        self.column_indices_t.clone(),
         | 
| 202 | 
            +
                        self.offsets_t.clone(),
         | 
| 203 | 
            +
                        self.block_offsets_t.clone())
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                def t(self):
         | 
| 206 | 
            +
                    if self.dim() != 2:
         | 
| 207 | 
            +
                        raise ValueError(
         | 
| 208 | 
            +
                            "t() expects a tensor with <= 2 dimensions, "
         | 
| 209 | 
            +
                            f"but self is {self.dim()}D.")
         | 
| 210 | 
            +
                    out = Matrix(self.size(),
         | 
| 211 | 
            +
                                 self.data,
         | 
| 212 | 
            +
                                 self.row_indices,
         | 
| 213 | 
            +
                                 self.column_indices,
         | 
| 214 | 
            +
                                 self.offsets,
         | 
| 215 | 
            +
                                 self.column_indices_t,
         | 
| 216 | 
            +
                                 self.offsets_t,
         | 
| 217 | 
            +
                                 self.block_offsets_t)
         | 
| 218 | 
            +
                    out._transposed = not self._transposed
         | 
| 219 | 
            +
                    out._size = torch.Size((self._size[1], self._size[0]))
         | 
| 220 | 
            +
                    return out
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def contiguous(self):
         | 
| 223 | 
            +
                    raise ValueError("Not yet implemented.")
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                def is_contiguous(self):
         | 
| 226 | 
            +
                    return not self._transposed
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                @property
         | 
| 229 | 
            +
                def is_cuda(self):
         | 
| 230 | 
            +
                    return self._data.is_cuda
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                @property
         | 
| 233 | 
            +
                def device(self):
         | 
| 234 | 
            +
                    return self._data.device
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                def size(self):
         | 
| 237 | 
            +
                    return self._size
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                @property
         | 
| 240 | 
            +
                def shape(self):
         | 
| 241 | 
            +
                    return self.size()
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                def dim(self):
         | 
| 244 | 
            +
                    return len(self._size)
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                @property
         | 
| 247 | 
            +
                def data(self):
         | 
| 248 | 
            +
                    return self._data
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                @property
         | 
| 251 | 
            +
                def row_indices(self):
         | 
| 252 | 
            +
                    return self._row_indices
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                @property
         | 
| 255 | 
            +
                def column_indices(self):
         | 
| 256 | 
            +
                    return self._column_indices
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                @property
         | 
| 259 | 
            +
                def offsets(self):
         | 
| 260 | 
            +
                    return self._offsets
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                @property
         | 
| 263 | 
            +
                def offsets_t(self):
         | 
| 264 | 
            +
                    return self._offsets_t
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                @property
         | 
| 267 | 
            +
                def column_indices_t(self):
         | 
| 268 | 
            +
                    return self._column_indices_t
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                @property
         | 
| 271 | 
            +
                def block_offsets_t(self):
         | 
| 272 | 
            +
                    return self._block_offsets_t
         | 
| 273 | 
            +
             | 
| 274 | 
            +
    @property
    def dtype(self):
        return self.data.dtype

    @property
    def nnz(self):
        return self.data.numel()

    @property
    def blocking(self):
        return self.data.shape[1]

    @property
    def requires_grad(self):
        return self.data.requires_grad

    def requires_grad_(self, x):
        self.data.requires_grad_(x)
        return self

    def view(self, *shape):
        assert self.is_contiguous()
        if shape[-1] != self.size()[-1]:
            raise ValueError(
                "Can't change view on compressed dimension. "
                f"{self.size()[-1]} v. {shape[-1]}.")
        if np.prod(shape) != np.prod(self.size()):
            raise ValueError(
                "Mismatch in numel of Matrix and new shape. "
                f"{np.prod(self.size())} v. {np.prod(shape)}")
        return Matrix(shape,
                      self.data,
                      self.row_indices,
                      self.column_indices,
                      self.offsets,
                      self.column_indices_t,
                      self.offsets_t,
                      self.block_offsets_t)

    @property
    def grad(self):
        # TODO(tgale): Make sure this mirrors torch.Tensor
        # behavior in the case where we ask for the gradient
        # of a non-contiguous tensor.
        size = self.size()
        if not self.is_contiguous():
            size = torch.Size((size[1], size[0]))
        out = Matrix(size,
                     self.data.grad,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        return out if self.is_contiguous() else out.t()
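(Not part of this commit.) A minimal sketch of how these accessors behave once a block-sparse matrix exists, assuming the vendored package is importable as `megablocks.stk` and the matrix is built via `stk.ops.to_sparse`:

    import torch
    from megablocks import stk  # assumed vendored import path

    x = torch.randn(256, 512).type(torch.float16)
    m = stk.ops.to_sparse(x, blocking=128)
    m.blocking               # 128, taken from data.shape[1]
    m.nnz                    # number of stored values across kept blocks
    m.dtype                  # torch.float16
    m2 = m.view(2, 128, 512) # reshape; the compressed (last) dim must stay 512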
    	
        torch-ext/megablocks/stk/ops/__init__.py
    ADDED
    
@@ -0,0 +1,3 @@
from .linear_ops import dds, dsd, sdd
from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse
from .eltwise_ops import mul
    	
        torch-ext/megablocks/stk/ops/eltwise_ops.py
    ADDED
    
@@ -0,0 +1,28 @@
from ..matrix import Matrix

def mul(a, b):
    """Performs element-wise multiplication of matrices a and b.

    It is the user's responsibility to make sure that a and b
    follow the same matrix topology. This function assumes it is safe
    to use the topology of a.

    Args:
        a: stk.Matrix.
        b: stk.Matrix with a's matrix topology.

    Returns:
        stk.Matrix where the entries correspond to torch.mul(a, b).
    """
    assert isinstance(a, Matrix)
    assert isinstance(b, Matrix)
    assert a.size() == b.size()

    return Matrix(a.size(),
                  a.data * b.data,
                  a.row_indices,
                  a.column_indices,
                  a.offsets,
                  a.column_indices_t,
                  a.offsets_t,
                  a.block_offsets_t)
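(Illustrative only, not part of the commit.) Two matrices that share a topology can be multiplied entry-wise without touching the index arrays; this mirrors how the test below builds its operands and assumes the package is importable as `stk`, as in the tests:

    import torch
    import stk

    x = (torch.randn(256, 256) * 0.1).type(torch.float16)
    a = stk.ops.to_sparse(x, blocking=128)
    # Reuse a's topology with fresh values so the nonzero pattern matches exactly.
    b = stk.Matrix(a.size(), torch.ones_like(a.data),
                   a.row_indices, a.column_indices, a.offsets)
    out = stk.ops.mul(a, b)          # same topology as a, values a.data * b.data
    dense_out = stk.ops.to_dense(out)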
    	
        torch-ext/megablocks/stk/ops/eltwise_ops_test.py
    ADDED
    
@@ -0,0 +1,86 @@
import unittest
import itertools
import torch
from absl.testing import parameterized

import stk
from stk.ops.linear_ops_test import allclose, _dense_and_sparse

_MATRIX_SIZES = (
    (128, 128, 0.0),
    (256, 256, 0.5),
    (2048, 1024, 0.8),
    (512, 128, 0.0),
    (128, 512, 0.0),
    (1024, 512, 0.0),
    (1024, 512, 0.5),
    (1024, 512, 0.75),
    (512, 1024, 0.0),
    (512, 1024, 0.5),
    (512, 1024, 0.75),
    (1024, 1024, 0.0),
    (1024, 1024, 0.5),
    (1024, 1024, 0.75),
)

_DTYPE = (
    torch.float16, torch.bfloat16
)

def _generate_testcases():
    testcases = itertools.product(_MATRIX_SIZES, _DTYPE)
    testcases = [(*size, 128, dtype) for
        (size, dtype) in testcases]
    return testcases

_ELTWISE_OP_TESTS = _generate_testcases()

def _dense_and_sparse_like(x, std=0.1):
    dense_data = torch.randn_like(x.data, device=x.device) * std
    sparse = stk.Matrix(x.size(),
                        dense_data,
                        x.row_indices,
                        x.column_indices,
                        x.offsets)
    dense = stk.ops.to_dense(sparse)

    return (dense.requires_grad_(True),
            sparse.requires_grad_(True))

@parameterized.parameters(_ELTWISE_OP_TESTS)
class EltwiseOpsTest(parameterized.TestCase):

    def testEltwiseMul(self, m, n, sparsity, blocking, dtype):

        a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype)
        b_dense, b = _dense_and_sparse_like(a)

        out = stk.ops.mul(a, b)
        expected_out = torch.mul(a_dense, b_dense)

        # Compute the gradients w.r.t. the inputs.
        expected_out.sum().backward()
        stk.ops.sum(out).backward()

        # Validate the results.
        out = stk.ops.to_dense(out)
        self.assertEqual(out.dim(), 2)
        self.assertEqual(expected_out.size(), out.size())
        self.assertTrue(allclose(out, expected_out))

        # LHS gradient.
        grad = stk.ops.to_dense(a.grad)
        expected_grad = a_dense.grad
        self.assertEqual(grad.dim(), 2)
        self.assertEqual(expected_grad.size(), grad.size())
        self.assertTrue(allclose(grad, expected_grad))

        # RHS gradient.
        grad = stk.ops.to_dense(b.grad)
        expected_grad = b_dense.grad
        self.assertEqual(grad.dim(), 2)
        self.assertEqual(expected_grad.size(), grad.size())
        self.assertTrue(allclose(grad, expected_grad))

if __name__ == '__main__':
    unittest.main()
    	
        torch-ext/megablocks/stk/ops/linear_ops.py
    ADDED
    
@@ -0,0 +1,59 @@
import torch

from ..backend import sputnik
from ..matrix import Matrix


def dsd(a, b):
    assert isinstance(a, Matrix)
    assert isinstance(b, torch.Tensor)
    return sputnik.dsd(
        a.size(),
        a.data, a.offsets,
        a.row_indices,
        a.column_indices,
        a.offsets_t,
        a.column_indices_t,
        a.block_offsets_t,
        not a.is_contiguous(),
        b)


def dds(a, b):
    assert isinstance(a, torch.Tensor)
    assert isinstance(b, Matrix)
    return sputnik.dds(
        a,
        b.size(),
        b.data, b.offsets,
        b.row_indices,
        b.column_indices,
        b.offsets_t,
        b.column_indices_t,
        b.block_offsets_t,
        not b.is_contiguous())


def sdd(a, b, topo):
    assert isinstance(a, torch.Tensor)
    assert isinstance(b, torch.Tensor)
    assert isinstance(topo, Matrix)
    assert topo.is_contiguous()
    out = sputnik.sdd(
        a, b,
        topo.size(),
        topo.data,
        topo.offsets,
        topo.row_indices,
        topo.column_indices,
        topo.offsets_t,
        topo.column_indices_t,
        topo.block_offsets_t)
    return Matrix(topo.size(),
                  out,
                  topo.row_indices,
                  topo.column_indices,
                  topo.offsets,
                  topo.column_indices_t,
                  topo.offsets_t,
                  topo.block_offsets_t)
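(Sketch for context, not part of the commit.) The three matmul variants compose as follows: `sdd` produces a block-sparse output on a given topology, while `dsd`/`dds` consume a block-sparse operand and return dense tensors. The operands live on a CUDA device here, as the tests below do, and the import path is assumed to be `stk` as in the tests:

    import torch
    import stk

    cuda = torch.device("cuda")
    a = torch.randn(512, 256, device=cuda, dtype=torch.float16)
    b = torch.randn(256, 512, device=cuda, dtype=torch.float16)

    # Block-sparse output topology for a 512x512 result; 75% of blocks dropped.
    topo = stk.random.mask(512, 512, 0.75, blocking=128).to(cuda)

    s = stk.ops.sdd(a, b, topo)    # sparse = dense @ dense, restricted to topo
    d1 = stk.ops.dsd(s, b.t())     # dense  = sparse @ dense -> (512, 256)
    d2 = stk.ops.dds(a.t(), s)     # dense  = dense @ sparse -> (256, 512)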
    	
        torch-ext/megablocks/stk/ops/linear_ops_test.py
    ADDED
    
@@ -0,0 +1,216 @@
import unittest
import itertools
import numpy as np
import torch
from absl.testing import parameterized

import stk


def allclose(x, y, pct=0.25):
    mask = torch.isclose(x, y, rtol=5e-2)
    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
    if pct_diff > pct:
        print("{:.2f}% of values not close.".format(pct_diff))
        return False
    return True


# An assortment of problems designed to make sure
# the bindings are operating correctly.
_MATRIX_SIZES = (
    (128, 128, 128, 0.0),
    (256, 256, 256, 0.5),
    (2048, 1024, 512, 0.8),
    (512, 128, 128, 0.0),
    (128, 128, 512, 0.0),
    (1024, 512, 512, 0.0),
    (1024, 512, 512, 0.5),
    (1024, 512, 512, 0.75),
    (512, 512, 1024, 0.0),
    (512, 512, 1024, 0.5),
    (512, 512, 1024, 0.75),
    (1024, 1024, 1024, 0.0),
    (1024, 1024, 1024, 0.5),
    (1024, 1024, 1024, 0.75),
)

_TRANSPOSE = (
    (False, False),
    (False, True),
    (True, False),
    (True, True),
)

_DTYPE = (
    torch.float16, torch.bfloat16
)

def _generate_testcases():
    testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE)
    testcases = [(*size, *trans, 128, dtype) for
        (size, trans, dtype) in testcases]
    return testcases

_LINEAR_OP_TESTS = _generate_testcases()

def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1):
    mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
    dense = (torch.randn(rows, cols) * std * mask).type(dtype)
    sparse = stk.ops.to_sparse(dense, blocking)
    cuda_device = torch.device("cuda")
    return (dense.to(cuda_device).requires_grad_(True),
            sparse.to(cuda_device).requires_grad_(True))


def _dense(rows, cols, dtype, std=0.1):
    cuda_device = torch.device("cuda")
    out = (torch.randn(rows, cols) * std).type(dtype)
    return out.to(cuda_device).requires_grad_(True)


def _dense_2x(rows, cols, dtype):
    a = _dense(rows, cols, dtype)
    return a, a.detach().requires_grad_(True)


def _with_transpose(op, a, b, trans_a, trans_b):
    a = a.t() if trans_a else a
    b = b.t() if trans_b else b
    return op(a, b)


def _mmm(a, b, topo):
    mask = stk.ops.to_dense(stk.ops.ones_like(topo))
    return torch.mm(a, b) * mask


def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b):
    a = a.t() if trans_a else a
    b = b.t() if trans_b else b
    return op(a, b, topo)


def _mask(x, mask):
    mask = stk.ops.to_dense(stk.ops.ones_like(mask))
    return x * mask


@parameterized.parameters(*_LINEAR_OP_TESTS)
class LinearOpsTest(parameterized.TestCase):

    def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
        # Construct the operands.
        a_shape = (k, m) if trans_a else (m, k)
        a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype)
        b_shape = (n, k) if trans_b else (k, n)
        b, bcp = _dense_2x(*b_shape, dtype)

        # Execute the matmul.
        out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b)
        expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b)

        # Compute the gradients w.r.t. the inputs.
        expected_out.sum().backward()
        out.sum().backward()

        # Validate the results.
        self.assertEqual(out.dim(), 2)
        self.assertEqual(expected_out.size()[0], out.size()[0])
        self.assertEqual(expected_out.size()[1], out.size()[1])
        self.assertTrue(allclose(out, expected_out))

        # LHS gradient.
        grad = stk.ops.to_dense(a.grad)
        expected_grad = _mask(a_dense.grad, a.grad)
        self.assertEqual(grad.dim(), 2)
        self.assertEqual(expected_grad.size()[0], grad.size()[0])
        self.assertEqual(expected_grad.size()[1], grad.size()[1])
        self.assertTrue(allclose(grad, expected_grad))

        # RHS gradient.
        grad = b.grad
        expected_grad = bcp.grad
        self.assertEqual(grad.dim(), 2)
        self.assertEqual(expected_grad.size()[0], grad.size()[0])
        self.assertEqual(expected_grad.size()[1], grad.size()[1])
        self.assertTrue(allclose(grad, expected_grad))

    def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
        # Construct the operands.
        a_shape = (k, m) if trans_a else (m, k)
        a, acp = _dense_2x(*a_shape, dtype)
        b_shape = (n, k) if trans_b else (k, n)
        b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype)

        # Execute the matmul.
        out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b)
        expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b)

        # Compute the gradients w.r.t. the inputs.
        expected_out.sum().backward()
        out.sum().backward()

        # Validate the results.
        self.assertEqual(out.dim(), 2)
        self.assertEqual(expected_out.size()[0], out.size()[0])
        self.assertEqual(expected_out.size()[1], out.size()[1])
        self.assertTrue(allclose(out, expected_out))

        # LHS gradient.
        grad = a.grad
        expected_grad = acp.grad
        self.assertEqual(grad.dim(), 2)
        self.assertEqual(expected_grad.size()[0], grad.size()[0])
        self.assertEqual(expected_grad.size()[1], grad.size()[1])
        self.assertTrue(allclose(grad, expected_grad))

        # RHS gradient.
        grad = stk.ops.to_dense(b.grad)
        expected_grad = _mask(b_dense.grad, b.grad)
        self.assertEqual(grad.dim(), 2)
        self.assertEqual(expected_grad.size()[0], grad.size()[0])
        self.assertEqual(expected_grad.size()[1], grad.size()[1])
        self.assertTrue(allclose(grad, expected_grad))

    def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype):
        # Construct the operands.
        a_shape = (k, m) if trans_a else (m, k)
        a, acp = _dense_2x(*a_shape, dtype)
        b_shape = (n, k) if trans_b else (k, n)
        b, bcp = _dense_2x(*b_shape, dtype)
        _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype)

        # Execute the matmul.
        out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b)
        expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b)

        # Compute the gradients w.r.t. the inputs.
        expected_out.sum().backward()
        stk.ops.sum(out).backward()

        # Validate the results.
        out = stk.ops.to_dense(out)
        self.assertEqual(out.dim(), 2)
        self.assertEqual(expected_out.size()[0], out.size()[0])
        self.assertEqual(expected_out.size()[1], out.size()[1])
        self.assertTrue(allclose(out, expected_out))

        # LHS gradient.
        grad = a.grad
        expected_grad = acp.grad
        self.assertEqual(grad.dim(), 2)
        self.assertEqual(expected_grad.size()[0], grad.size()[0])
        self.assertEqual(expected_grad.size()[1], grad.size()[1])
        self.assertTrue(allclose(grad, expected_grad))

        # RHS gradient.
        grad = b.grad
        expected_grad = bcp.grad
        self.assertEqual(grad.dim(), 2)
        self.assertEqual(expected_grad.size()[0], grad.size()[0])
        self.assertEqual(expected_grad.size()[1], grad.size()[1])
        self.assertTrue(allclose(grad, expected_grad))

if __name__ == '__main__':
    unittest.main()
    	
        torch-ext/megablocks/stk/ops/matrix_ops.py
    ADDED
    
@@ -0,0 +1,98 @@
from ..backend import sputnik
from ..matrix import Matrix
import torch
import numpy as np


@torch.no_grad()
def row_indices(shape, data, offsets, column_indices):
    return sputnik.row_indices(shape, data, offsets, column_indices)


# TODO(tgale): Replace this helper with a custom kernel. This operation
# is much simpler to do than how it's currently implemented.
@torch.no_grad()
def _expand_for_blocking(idxs, blocking):
    # Duplicate for block column dimension.
    idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1)

    # Update the column indices.
    idxs[:, :, 1] *= blocking
    idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking])

    # Duplicate for block row dimension.
    idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2])
    idxs = idxs.repeat(1, blocking, 1, 1)

    # Update the row indices.
    idxs[:, :, :, 0] *= blocking
    idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1])
    idxs = torch.reshape(idxs, [-1, 2])
    return idxs


# TODO(tgale): Add input type checking.
@torch.no_grad()
def to_dense(x):
    assert isinstance(x, Matrix)

    shape = (np.prod(x.shape[:-1]), x.shape[-1])
    row_idxs = x.row_indices.type(torch.int32)
    col_idxs = x.column_indices.type(torch.int32)
    indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking)
    indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64)

    out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device)
    out.scatter_(0, indices, x.data.flatten())
    return out.reshape(x.size())


@torch.no_grad()
def _mask(x, blocking=1):
    assert x.dim() == 2
    assert x.size()[0] % blocking == 0
    assert x.size()[1] % blocking == 0
    block_rows = x.size()[0] // blocking
    block_cols = x.size()[1] // blocking
    x = torch.reshape(x, [block_rows, blocking, block_cols, blocking])
    x = torch.sum(torch.abs(x), dim=(1, 3))
    return x != 0


# TODO(tgale): Add input type checking.
@torch.no_grad()
def to_sparse(x, blocking=1):
    m = _mask(x, blocking)

    # TODO(tgale): Set to appropriate type for input matrix.
    row_nnzs = torch.sum(m, dim=1).type(torch.int32)
    zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device)
    offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)])
    offsets = offsets.type(torch.int32)

    indices = torch.nonzero(m).type(torch.int16)
    row_indices = indices[:, 0]
    column_indices = indices[:, 1]

    # Nonzero indices in the dense matrix.
    nonzero_indices = torch.nonzero(m)
    nonzero_indices = _expand_for_blocking(nonzero_indices, blocking)
    nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1]

    # Gather the data and construct the sparse matrix.
    data = torch.gather(x.flatten(), dim=0, index=nonzero_indices)
    data = torch.reshape(data, [-1, blocking, blocking])
    return Matrix(x.size(), data, row_indices, column_indices, offsets)


@torch.no_grad()
def ones_like(x):
    return Matrix(x.size(),
                  torch.ones_like(x.data),
                  x.row_indices,
                  x.column_indices, x.offsets)


def sum(x):
    assert isinstance(x, Matrix)
    return x.data.sum()
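(Quick illustration, not part of the commit.) A dense/sparse round trip with these helpers, mirroring the format-conversion test below; assumes the package is importable as `stk`:

    import torch
    import stk

    mask = stk.random.dense_mask(256, 256, 0.5, blocking=64)
    x = (torch.randn(256, 256) * mask).type(torch.float16)

    sp = stk.ops.to_sparse(x, blocking=64)   # keeps only the nonzero 64x64 blocks
    # Holds in practice: randn values inside kept blocks are nonzero.
    assert sp.nnz == torch.count_nonzero(x).item()
    assert torch.equal(stk.ops.to_dense(sp), x)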
    	
        torch-ext/megablocks/stk/ops/matrix_ops_test.py
    ADDED
    
@@ -0,0 +1,62 @@
import unittest

from absl.testing import parameterized
import stk
import torch


@parameterized.parameters(
    (8, 16, 0.0, 1),
    (8, 16, 0.5, 1),
    (8, 16, .95, 1),
    (16, 8, 0.0, 1),
    (16, 8, 0.5, 1),
    (16, 8, .95, 1),
    (8, 16, 0.0, 8),
    (8, 16, 0.5, 8),
    (8, 16, 1.0, 8),
    (16, 8, 0.0, 8),
    (16, 8, 0.5, 8),
    (16, 8, 1.0, 8),
    (128, 256, 0.5, 16),
    (256, 128, 0.75, 32),
    (512, 512, .875, 128))
class MatrixOpsTest(parameterized.TestCase):

    def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking):
        mask = stk.random.dense_mask(rows, cols, sparsity, blocking)
        x = (torch.randn(rows, cols) * mask).type(torch.float16)

        # Convert the matrix to sparse format.
        sparse_x = stk.ops.to_sparse(x, blocking)

        # Validate the matrix.
        sparse_x.validate()

        # Validate the shape.
        self.assertEqual(sparse_x.dim(), 2)
        self.assertEqual(sparse_x.size()[0], rows)
        self.assertEqual(sparse_x.size()[1], cols)

        # Validate the sparsity.
        numblocks = rows // blocking * cols // blocking
        nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
        self.assertEqual(sparse_x.nnz, nnz)

        # Convert back to dense format.
        dense_x = stk.ops.to_dense(sparse_x)

        # Validate the shape.
        self.assertEqual(dense_x.dim(), 2)
        self.assertEqual(dense_x.size()[0], rows)
        self.assertEqual(dense_x.size()[1], cols)

        # Validate the sparsity.
        self.assertEqual(torch.count_nonzero(dense_x).item(), nnz)

        # Validate the output.
        self.assertTrue(torch.all(torch.eq(x, dense_x)))


if __name__ == '__main__':
    unittest.main()
    	
        torch-ext/megablocks/stk/random/__init__.py
    ADDED
    
@@ -0,0 +1,2 @@
# from stk.random.random_ops import dense_mask, mask, randn
from .random_ops import dense_mask, mask, randn
    	
        torch-ext/megablocks/stk/random/random_ops.py
    ADDED
    
@@ -0,0 +1,36 @@
import numpy as np
import torch
from ..ops import matrix_ops


@torch.no_grad()
def dense_mask(rows, cols, sparsity, blocking=1):
  assert sparsity >= 0.0 and sparsity <= 1.0
  assert rows % blocking == 0 and cols % blocking == 0

  block_rows, block_cols = (rows // blocking, cols // blocking)
  nnz = round(block_rows * block_cols * (1 - sparsity))

  out = np.ones(block_rows * block_cols)
  mask = np.random.choice(out.size, out.size - nnz, replace=False)
  out[mask] = 0.0

  out = np.tile(
    np.reshape(out, [block_rows, 1, block_cols, 1]),
    (1, blocking, 1, blocking))
  out = np.reshape(out, [rows, cols])
  return torch.from_numpy(out.astype(np.float32))


@torch.no_grad()
def mask(m, n, sparsity, blocking=1):
    out = dense_mask(m, n, sparsity, blocking).type(torch.float16)
    return matrix_ops.to_sparse(out, blocking=blocking)


@torch.no_grad()
def randn(shape, sparsity, blocking=1):
  shape_2d = (np.prod(shape[:-1]), shape[-1])
  out = mask(*shape_2d, sparsity, blocking)
  out.data.copy_(torch.randn(*out.data.shape))
  return out.view(*shape)
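(Illustrative only, not part of the commit.) The random helpers generate block-sparse topologies and data for tests and benchmarks; the sketch assumes the package is importable as `stk`:

    import stk

    # Dense 0/1 mask with 16x16 blocks, ~50% of blocks zeroed.
    m = stk.random.dense_mask(128, 256, 0.5, blocking=16)

    # The same pattern packed into a Matrix (float16 data of ones on kept blocks).
    topo = stk.random.mask(128, 256, 0.5, blocking=16)

    # Block-sparse Matrix with standard-normal values on the kept blocks.
    r = stk.random.randn((128, 256), 0.5, blocking=16)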
    	
        torch-ext/megablocks/stk/random/random_ops_test.py
    ADDED
    
@@ -0,0 +1,73 @@
import unittest

from absl.testing import parameterized
from .. import random
import torch


@parameterized.parameters(
    (8, 16, 0.0, 1),
    (8, 16, 0.5, 1),
    (8, 16, .95, 1),
    (16, 8, 0.0, 1),
    (16, 8, 0.5, 1),
    (16, 8, .95, 1),
    (8, 16, 0.0, 8),
    (8, 16, 0.5, 8),
    (8, 16, 1.0, 8),
    (16, 8, 0.0, 8),
    (16, 8, 0.5, 8),
    (16, 8, 1.0, 8),
    (128, 256, 0.5, 16),
    (256, 128, 0.75, 32),
    (512, 512, .875, 128))
class RandomOpsTest(parameterized.TestCase):

    def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking):
        mask = random.dense_mask(
            rows, cols, sparsity, blocking)

        # Validate the shape.
        self.assertEqual(mask.dim(), 2)
        self.assertEqual(mask.size()[0], rows)
        self.assertEqual(mask.size()[1], cols)

        # Validate the sparsity.
        numblocks = rows // blocking * cols // blocking
        nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
        self.assertEqual(
            torch.count_nonzero(mask).item(),
            nnz)

        # Check values are zero or one.
        self.assertTrue(
            torch.all(torch.logical_or(
                torch.eq(mask, 0),
                torch.eq(mask, 1))))

    def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking):
        mask = random.mask(
            rows, cols, sparsity, blocking)

        # Validate the matrix.
        mask.validate()

        # Validate the shape.
        self.assertEqual(mask.dim(), 2)
        self.assertEqual(mask.size()[0], rows)
        self.assertEqual(mask.size()[1], cols)

        # Validate the sparsity.
        numblocks = rows // blocking * cols // blocking
        nnz = round(numblocks * (1 - sparsity)) * blocking ** 2
        self.assertEqual(mask.nnz, nnz)

        # Check values are zero or one.
        self.assertTrue(
            torch.all(torch.logical_or(
                torch.eq(mask.data, 0),
                torch.eq(mask.data, 1))))


if __name__ == '__main__':
    unittest.main()
