danieldk (HF Staff) committed
Commit 905c6c3
1 Parent(s): f1b9f1b
This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py +14 -0
  2. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_bft6nicqkg6ni.abi3.so +3 -0
  3. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py +9 -0
  4. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
  5. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/distributed_utils.py +144 -0
  6. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +326 -0
  7. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
  8. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/config_mamba.py +18 -0
  9. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +338 -0
  10. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
  11. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/block.py +107 -0
  12. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2.py +502 -0
  13. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2_simple.py +229 -0
  14. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba_simple.py +339 -0
  15. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mha.py +294 -0
  16. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mlp.py +34 -0
  17. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/ssd_minimal.py +111 -0
  18. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/__init__.py +0 -0
  19. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +659 -0
  20. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/__init__.py +0 -0
  21. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/k_activations.py +169 -0
  22. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +1166 -0
  23. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layernorm_gated.py +437 -0
  24. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +389 -0
  25. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/softplus.py +15 -0
  26. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_bmm.py +262 -0
  27. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
  28. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +2012 -0
  29. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +1884 -0
  30. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_state_passing.py +348 -0
  31. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/__init__.py +0 -0
  32. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/generation.py +390 -0
  33. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/hf.py +23 -0
  34. build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/torch.py +21 -0
  35. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/__init__.py +14 -0
  36. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_nmrmresto7zfi.abi3.so +3 -0
  37. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_ops.py +9 -0
  38. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
  39. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/distributed_utils.py +144 -0
  40. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +326 -0
  41. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
  42. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/config_mamba.py +18 -0
  43. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +338 -0
  44. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
  45. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/block.py +107 -0
  46. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba2.py +502 -0
  47. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba2_simple.py +229 -0
  48. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba_simple.py +339 -0
  49. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mha.py +294 -0
  50. build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mlp.py +34 -0
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py ADDED
@@ -0,0 +1,14 @@
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
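For orientation, a minimal usage sketch of the public API that this __init__.py re-exports. The checkpoint name, device, and dtype below are illustrative assumptions, not part of this commit.

# Hypothetical quick-start (not part of the diff): load a Mamba LM through the
# names exported above and run a single forward pass.
import torch
from mamba_ssm import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
logits = model(input_ids).logits  # (batch, seqlen, vocab_size)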
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_bft6nicqkg6ni.abi3.so ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0e8bc801359703c8d092b7c8c9906bd59c083d94e6778b621ba709d79fff5a0
size 258973648
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py ADDED
@@ -0,0 +1,9 @@
import torch
from . import _mamba_ssm_bft6nicqkg6ni
ops = torch.ops._mamba_ssm_bft6nicqkg6ni

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_mamba_ssm_bft6nicqkg6ni::{op_name}"
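The helper above only builds a namespaced operator name for the bundled extension. A small illustration (the op name "selective_scan_fwd" is a made-up example, not necessarily an op exposed by this build):

# Hypothetical usage of add_op_namespace_prefix; the concrete op name is an assumption.
from mamba_ssm._ops import add_op_namespace_prefix

qualified = add_op_namespace_prefix("selective_scan_fwd")
# -> "_mamba_ssm_bft6nicqkg6ni::selective_scan_fwd"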
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/distributed_utils.py ADDED
@@ -0,0 +1,144 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
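As a sanity check on the uneven-split arithmetic in get_dim_for_local_rank, a small worked example (the numbers are illustrative): with dim=10, world_size=4, multiple_of=1 we get multiple=10, div=2, mod=2, so ranks 0-1 receive 3 columns and ranks 2-3 receive 2.

# Worked example for get_dim_for_local_rank (pure function, no distributed setup needed).
from mamba_ssm.distributed.distributed_utils import get_dim_for_local_rank

sizes = [get_dim_for_local_rank(10, 4, rank) for rank in range(4)]
assert sizes == [3, 3, 2, 2] and sum(sizes) == 10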
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py ADDED
@@ -0,0 +1,326 @@
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ..utils.torch import custom_bwd, custom_fwd

from einops import rearrange

from ..distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = (
                bias.to(dtype=torch.get_autocast_gpu_dtype())
                if bias is not None
                else None
            )
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(
                    grad_input, process_group, async_op=True
                )
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of {multiple_of}"
            )
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features,
            local_multiple * multiple_of,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of {multiple_of}"
            )
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings,
                embed_dim,
                process_group=process_group,
                **factory_kwargs,
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return (
            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
        )
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/__init__.py ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/config_mamba.py ADDED
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field


@dataclass
class MambaConfig:

    d_model: int = 2560
    d_intermediate: int = 0
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
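A hedged sketch of how MambaConfig combines with MambaLMHeadModel (defined in models/mixer_seq_simple.py below); the tiny hyperparameters and the CUDA/bfloat16 choices are illustrative assumptions.

# Hypothetical example: build a small Mamba-2 language model from a config.
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

cfg = MambaConfig(d_model=256, n_layer=4, vocab_size=1000, ssm_cfg={"layer": "Mamba2"})
model = MambaLMHeadModel(cfg, device="cuda", dtype=torch.bfloat16)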
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,338 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(
                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
            )
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=(
                    1 if d_intermediate == 0 else 2
                ),  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states,
                residual,
                inference_params=inference_params,
                **mixer_kwargs,
            )
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **mixer_kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(
            input_ids, inference_params=inference_params, **mixer_kwargs
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=4)
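A hedged sketch of the save_pretrained / from_pretrained pair implemented above and of the num_last_tokens argument; the checkpoint name and output path are placeholders.

# Hypothetical round trip (not part of the diff): load, save, and query only the last position.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda")
model.save_pretrained("/tmp/mamba-copy")  # writes pytorch_model.bin and config.json
input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")
last_logits = model(input_ids, num_last_tokens=1).logits  # (1, 1, vocab_size)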
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/__init__.py ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/block.py ADDED
@@ -0,0 +1,107 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
from typing import Optional

import torch
from torch import nn, Tensor

from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
        **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm),
            )
        hidden_states = self.mixer(
            hidden_states, inference_params=inference_params, **mixer_kwargs
        )

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
+ except ImportError:
14
+ causal_conv1d_fn, causal_conv1d_update = None, None
15
+
16
+ try:
17
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
+ except ImportError:
19
+ causal_conv1d_varlen_states = None
20
+
21
+ try:
22
+ from ..ops.triton.selective_state_update import selective_state_update
23
+ except ImportError:
24
+ selective_state_update = None
25
+
26
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
27
+
28
+ from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
+ from ..distributed.distributed_utils import all_reduce, reduce_scatter
30
+
31
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
32
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
+
34
+ from huggingface_hub import PyTorchModelHubMixin
35
+
36
+
37
+ class Mamba2(nn.Module, PyTorchModelHubMixin):
38
+ def __init__(
39
+ self,
40
+ d_model,
41
+ d_state=128,
42
+ d_conv=4,
43
+ conv_init=None,
44
+ expand=2,
45
+ headdim=64,
46
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
47
+ ngroups=1,
48
+ A_init_range=(1, 16),
49
+ D_has_hdim=False,
50
+ rmsnorm=True,
51
+ norm_before_gate=False,
52
+ dt_min=0.001,
53
+ dt_max=0.1,
54
+ dt_init_floor=1e-4,
55
+ dt_limit=(0.0, float("inf")),
56
+ bias=False,
57
+ conv_bias=True,
58
+ # Fused kernel and sharding options
59
+ chunk_size=256,
60
+ use_mem_eff_path=True,
61
+ layer_idx=None, # Absorb kwarg for general module
62
+ process_group=None,
63
+ sequence_parallel=True,
64
+ device=None,
65
+ dtype=None,
66
+ ):
67
+ factory_kwargs = {"device": device, "dtype": dtype}
68
+ super().__init__()
69
+ self.d_model = d_model
70
+ self.d_state = d_state
71
+ self.d_conv = d_conv
72
+ self.conv_init = conv_init
73
+ self.expand = expand
74
+ self.process_group = process_group
75
+ self.sequence_parallel = sequence_parallel
76
+ self.world_size = 1 if process_group is None else process_group.size()
77
+ self.local_rank = 0 if process_group is None else process_group.rank()
78
+ self.d_inner = (self.expand * self.d_model) // self.world_size
79
+ assert self.d_inner * self.world_size == self.expand * self.d_model
80
+ self.headdim = headdim
81
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
82
+ assert ngroups % self.world_size == 0
83
+ self.ngroups = ngroups // self.world_size
84
+ assert self.d_ssm % self.headdim == 0
85
+ self.nheads = self.d_ssm // self.headdim
86
+ self.D_has_hdim = D_has_hdim
87
+ self.rmsnorm = rmsnorm
88
+ self.norm_before_gate = norm_before_gate
89
+ self.dt_limit = dt_limit
90
+ self.activation = "silu"
91
+ self.chunk_size = chunk_size
92
+ self.use_mem_eff_path = use_mem_eff_path
93
+ self.layer_idx = layer_idx
94
+
95
+ # Order: [z, x, B, C, dt]
96
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
97
+ if self.process_group is None:
98
+ self.in_proj = nn.Linear(
99
+ self.d_model, d_in_proj, bias=bias, **factory_kwargs
100
+ )
101
+ else:
102
+ self.in_proj = ColumnParallelLinear(
103
+ self.d_model,
104
+ d_in_proj * self.world_size,
105
+ bias=bias,
106
+ process_group=self.process_group,
107
+ sequence_parallel=self.sequence_parallel,
108
+ **factory_kwargs,
109
+ )
110
+
111
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
112
+ self.conv1d = nn.Conv1d(
113
+ in_channels=conv_dim,
114
+ out_channels=conv_dim,
115
+ bias=conv_bias,
116
+ kernel_size=d_conv,
117
+ groups=conv_dim,
118
+ padding=d_conv - 1,
119
+ **factory_kwargs,
120
+ )
121
+ if self.conv_init is not None:
122
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
123
+
124
+ self.act = nn.SiLU()
125
+
126
+ # Initialize log dt bias
127
+ dt = torch.exp(
128
+ torch.rand(self.nheads, **factory_kwargs)
129
+ * (math.log(dt_max) - math.log(dt_min))
130
+ + math.log(dt_min)
131
+ )
132
+ dt = torch.clamp(dt, min=dt_init_floor)
133
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
134
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
135
+ self.dt_bias = nn.Parameter(inv_dt)
136
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
137
+ # name.endswith("bias") in param_grouping.py
138
+ self.dt_bias._no_weight_decay = True
139
+
140
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
141
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
142
+ *A_init_range
143
+ )
144
+ A_log = torch.log(A).to(dtype=dtype)
145
+ self.A_log = nn.Parameter(A_log)
146
+ self.A_log._no_weight_decay = True
147
+
148
+ # D "skip" parameter
149
+ self.D = nn.Parameter(
150
+ torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
151
+ )
152
+ self.D._no_weight_decay = True
153
+
154
+ if self.rmsnorm:
155
+ assert RMSNormGated is not None
156
+ self.norm = RMSNormGated(
157
+ self.d_ssm,
158
+ eps=1e-5,
159
+ norm_before_gate=self.norm_before_gate,
160
+ group_size=self.d_ssm // ngroups,
161
+ **factory_kwargs,
162
+ )
163
+
164
+ if self.process_group is None:
165
+ self.out_proj = nn.Linear(
166
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
167
+ )
168
+ else:
169
+ self.out_proj = RowParallelLinear(
170
+ self.d_inner * self.world_size,
171
+ self.d_model,
172
+ bias=bias,
173
+ process_group=self.process_group,
174
+ sequence_parallel=self.sequence_parallel,
175
+ **factory_kwargs,
176
+ )
177
+
178
+ def forward(
179
+ self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
180
+ ):
181
+ """
182
+ u: (batch, seqlen, hidden_dim) if seqlen=None.
183
+ If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
184
+ split u during sequence parallel, we split the batch * seqlen dimension
185
+ (in case batch is small).
186
+ Returns: same shape as u
187
+ """
188
+ seqlen_og = seqlen
189
+ if seqlen is None:
190
+ batch, seqlen, dim = u.shape
191
+ else:
192
+ batch_seqlen, dim = u.shape
193
+ batch = batch_seqlen // seqlen
194
+
195
+ conv_state, ssm_state = None, None
196
+ if inference_params is not None:
197
+ inference_batch = (
198
+ cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
199
+ )
200
+ conv_state, ssm_state = self._get_states_from_cache(
201
+ inference_params, inference_batch
202
+ )
203
+ if inference_params.seqlen_offset > 0:
204
+ # The states are updated inplace
205
+ out, _, _ = self.step(u, conv_state, ssm_state)
206
+ return out
207
+
208
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
209
+ if seqlen_og is not None:
210
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
211
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
212
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
213
+ dt_limit_kwargs = (
214
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
215
+ )
216
+ if self.use_mem_eff_path and inference_params is None:
217
+ out = mamba_split_conv1d_scan_combined(
218
+ zxbcdt,
219
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
220
+ self.conv1d.bias,
221
+ self.dt_bias,
222
+ A,
223
+ D=(
224
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
225
+ if self.D_has_hdim
226
+ else self.D
227
+ ),
228
+ chunk_size=self.chunk_size,
229
+ seq_idx=seq_idx,
230
+ activation=self.activation,
231
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
232
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
233
+ outproj_weight=self.out_proj.weight,
234
+ outproj_bias=self.out_proj.bias,
235
+ headdim=None if self.D_has_hdim else self.headdim,
236
+ ngroups=self.ngroups,
237
+ norm_before_gate=self.norm_before_gate,
238
+ **dt_limit_kwargs,
239
+ )
240
+ if seqlen_og is not None:
241
+ out = rearrange(out, "b l d -> (b l) d")
242
+ if self.process_group is not None:
243
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
244
+ out = reduce_fn(out, self.process_group)
245
+ else:
246
+ d_mlp = (
247
+ zxbcdt.shape[-1]
248
+ - 2 * self.d_ssm
249
+ - 2 * self.ngroups * self.d_state
250
+ - self.nheads
251
+ ) // 2
252
+ z0, x0, z, xBC, dt = torch.split(
253
+ zxbcdt,
254
+ [
255
+ d_mlp,
256
+ d_mlp,
257
+ self.d_ssm,
258
+ self.d_ssm + 2 * self.ngroups * self.d_state,
259
+ self.nheads,
260
+ ],
261
+ dim=-1,
262
+ )
263
+ if conv_state is not None:
264
+ if cu_seqlens is None:
265
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
266
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
267
+ xBC_t = rearrange(xBC, "b l d -> b d l")
268
+ conv_state.copy_(
269
+ F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
270
+ ) # Update state (B D W)
271
+ else:
272
+ assert (
273
+ causal_conv1d_varlen_states is not None
274
+ ), "varlen inference requires causal_conv1d package"
275
+ assert (
276
+ batch == 1
277
+ ), "varlen inference only supports batch dimension 1"
278
+ conv_varlen_states = causal_conv1d_varlen_states(
279
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
280
+ )
281
+ conv_state.copy_(conv_varlen_states)
282
+ assert self.activation in ["silu", "swish"]
283
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
284
+ assert (
285
+ seq_idx is None
286
+ ), "varlen conv1d requires the causal_conv1d package"
287
+ xBC = self.act(
288
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
289
+ :, : -(self.d_conv - 1)
290
+ ]
291
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
292
+ else:
293
+ xBC = causal_conv1d_fn(
294
+ xBC.transpose(1, 2),
295
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
296
+ bias=self.conv1d.bias,
297
+ activation=self.activation,
298
+ seq_idx=seq_idx,
299
+ ).transpose(1, 2)
300
+ x, B, C = torch.split(
301
+ xBC,
302
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
303
+ dim=-1,
304
+ )
305
+ y = mamba_chunk_scan_combined(
306
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
307
+ dt,
308
+ A,
309
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
310
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
311
+ chunk_size=self.chunk_size,
312
+ D=(
313
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
314
+ if self.D_has_hdim
315
+ else self.D
316
+ ),
317
+ z=(
318
+ rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
319
+ if not self.rmsnorm
320
+ else None
321
+ ),
322
+ dt_bias=self.dt_bias,
323
+ dt_softplus=True,
324
+ seq_idx=seq_idx,
325
+ cu_seqlens=cu_seqlens,
326
+ **dt_limit_kwargs,
327
+ return_final_states=ssm_state is not None,
328
+ return_varlen_states=cu_seqlens is not None
329
+ and inference_params is not None,
330
+ )
331
+ if ssm_state is not None:
332
+ y, last_state, *rest = y
333
+ if cu_seqlens is None:
334
+ ssm_state.copy_(last_state)
335
+ else:
336
+ varlen_states = rest[0]
337
+ ssm_state.copy_(varlen_states)
338
+ y = rearrange(y, "b l h p -> b l (h p)")
339
+ if self.rmsnorm:
340
+ y = self.norm(y, z)
341
+ if d_mlp > 0:
342
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
343
+ if seqlen_og is not None:
344
+ y = rearrange(y, "b l d -> (b l) d")
345
+ out = self.out_proj(y)
346
+ return out
347
+
348
+ def step(self, hidden_states, conv_state, ssm_state):
349
+ dtype = hidden_states.dtype
350
+ assert (
351
+ hidden_states.shape[1] == 1
352
+ ), "Only support decoding with 1 token at a time for now"
353
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
354
+ d_mlp = (
355
+ zxbcdt.shape[-1]
356
+ - 2 * self.d_ssm
357
+ - 2 * self.ngroups * self.d_state
358
+ - self.nheads
359
+ ) // 2
360
+ z0, x0, z, xBC, dt = torch.split(
361
+ zxbcdt,
362
+ [
363
+ d_mlp,
364
+ d_mlp,
365
+ self.d_ssm,
366
+ self.d_ssm + 2 * self.ngroups * self.d_state,
367
+ self.nheads,
368
+ ],
369
+ dim=-1,
370
+ )
371
+
372
+ # Conv step
373
+ if causal_conv1d_update is None:
374
+ conv_state.copy_(
375
+ torch.roll(conv_state, shifts=-1, dims=-1)
376
+ ) # Update state (B D W)
377
+ conv_state[:, :, -1] = xBC
378
+ xBC = torch.sum(
379
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
380
+ ) # (B D)
381
+ if self.conv1d.bias is not None:
382
+ xBC = xBC + self.conv1d.bias
383
+ xBC = self.act(xBC).to(dtype=dtype)
384
+ else:
385
+ xBC = causal_conv1d_update(
386
+ xBC,
387
+ conv_state,
388
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
389
+ self.conv1d.bias,
390
+ self.activation,
391
+ )
392
+
393
+ x, B, C = torch.split(
394
+ xBC,
395
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
396
+ dim=-1,
397
+ )
398
+ A = -torch.exp(self.A_log.float()) # (nheads,)
399
+
400
+ # SSM step
401
+ if selective_state_update is None:
402
+ assert (
403
+ self.ngroups == 1
404
+ ), "Only support ngroups=1 for this inference code path"
405
+ # Discretize A and B
406
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
407
+ dA = torch.exp(dt * A) # (batch, nheads)
408
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
409
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
410
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
411
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
412
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
413
+ y = rearrange(y, "b h p -> b (h p)")
414
+ if not self.rmsnorm:
415
+ y = y * self.act(z) # (B D)
416
+ else:
417
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
418
+ dtype=torch.float32
419
+ )
420
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
421
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
422
+ D = repeat(self.D, "h -> h p", p=self.headdim)
423
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
424
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
425
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
426
+ if not self.rmsnorm:
427
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
428
+ y = selective_state_update(
429
+ ssm_state,
430
+ x_reshaped,
431
+ dt,
432
+ A,
433
+ B,
434
+ C,
435
+ D,
436
+ z=z if not self.rmsnorm else None,
437
+ dt_bias=dt_bias,
438
+ dt_softplus=True,
439
+ )
440
+ y = rearrange(y, "b h p -> b (h p)")
441
+ if self.rmsnorm:
442
+ y = self.norm(y, z)
443
+ if d_mlp > 0:
444
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
445
+ out = self.out_proj(y)
446
+ return out.unsqueeze(1), conv_state, ssm_state
447
+
448
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
449
+ device = self.out_proj.weight.device
450
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
451
+ conv_state = torch.zeros(
452
+ batch_size,
453
+ self.d_conv,
454
+ self.conv1d.weight.shape[0],
455
+ device=device,
456
+ dtype=conv_dtype,
457
+ ).transpose(1, 2)
458
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
459
+ ssm_state = torch.zeros(
460
+ batch_size,
461
+ self.nheads,
462
+ self.headdim,
463
+ self.d_state,
464
+ device=device,
465
+ dtype=ssm_dtype,
466
+ )
467
+ return conv_state, ssm_state
468
+
469
+ def _get_states_from_cache(
470
+ self, inference_params, batch_size, initialize_states=False
471
+ ):
472
+ assert self.layer_idx is not None
473
+ if self.layer_idx not in inference_params.key_value_memory_dict:
474
+ batch_shape = (batch_size,)
475
+ conv_state = torch.zeros(
476
+ batch_size,
477
+ self.d_conv,
478
+ self.conv1d.weight.shape[0],
479
+ device=self.conv1d.weight.device,
480
+ dtype=self.conv1d.weight.dtype,
481
+ ).transpose(1, 2)
482
+ ssm_state = torch.zeros(
483
+ batch_size,
484
+ self.nheads,
485
+ self.headdim,
486
+ self.d_state,
487
+ device=self.in_proj.weight.device,
488
+ dtype=self.in_proj.weight.dtype,
489
+ )
490
+ inference_params.key_value_memory_dict[self.layer_idx] = (
491
+ conv_state,
492
+ ssm_state,
493
+ )
494
+ else:
495
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
496
+ self.layer_idx
497
+ ]
498
+ # TODO: What if batch size changes between generation, and we reuse the same states?
499
+ if initialize_states:
500
+ conv_state.zero_()
501
+ ssm_state.zero_()
502
+ return conv_state, ssm_state
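For orientation, a minimal shape sketch of the two decoding caches returned by allocate_inference_cache above (illustrative only, not part of the committed file; the sizes are hypothetical and assume d_ssm == d_inner, the default):

import torch

# Hypothetical Mamba2 hyperparameters for illustration.
batch, d_model, expand, headdim, d_state, ngroups, d_conv = 2, 256, 2, 64, 128, 1, 4
d_inner = expand * d_model                      # 512
nheads = d_inner // headdim                     # 8
conv_dim = d_inner + 2 * ngroups * d_state      # channels of the xBC stream fed to conv1d

# conv_state keeps the last d_conv inputs per channel and is rolled left each step;
# ssm_state keeps one (headdim x d_state) recurrent state per head, updated in place.
conv_state = torch.zeros(batch, conv_dim, d_conv)
ssm_state = torch.zeros(batch, nheads, headdim, d_state)
print(conv_state.shape, ssm_state.shape)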
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2_simple.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ try:
11
+ from causal_conv1d import causal_conv1d_fn
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+
15
+ try:
16
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
17
+ except ImportError:
18
+ RMSNormGated, LayerNorm = None, None
19
+
20
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
21
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
22
+
23
+
24
+ class Mamba2Simple(nn.Module):
25
+ def __init__(
26
+ self,
27
+ d_model,
28
+ d_state=64,
29
+ d_conv=4,
30
+ conv_init=None,
31
+ expand=2,
32
+ headdim=128,
33
+ ngroups=1,
34
+ A_init_range=(1, 16),
35
+ dt_min=0.001,
36
+ dt_max=0.1,
37
+ dt_init_floor=1e-4,
38
+ dt_limit=(0.0, float("inf")),
39
+ learnable_init_states=False,
40
+ activation="swish",
41
+ bias=False,
42
+ conv_bias=True,
43
+ # Fused kernel and sharding options
44
+ chunk_size=256,
45
+ use_mem_eff_path=True,
46
+ layer_idx=None, # Absorb kwarg for general module
47
+ device=None,
48
+ dtype=None,
49
+ ):
50
+ factory_kwargs = {"device": device, "dtype": dtype}
51
+ super().__init__()
52
+ self.d_model = d_model
53
+ self.d_state = d_state
54
+ self.d_conv = d_conv
55
+ self.conv_init = conv_init
56
+ self.expand = expand
57
+ self.d_inner = self.expand * self.d_model
58
+ self.headdim = headdim
59
+ self.ngroups = ngroups
60
+ assert self.d_inner % self.headdim == 0
61
+ self.nheads = self.d_inner // self.headdim
62
+ self.dt_limit = dt_limit
63
+ self.learnable_init_states = learnable_init_states
64
+ self.activation = activation
65
+ self.chunk_size = chunk_size
66
+ self.use_mem_eff_path = use_mem_eff_path
67
+ self.layer_idx = layer_idx
68
+
69
+ # Order: [z, x, B, C, dt]
70
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
71
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
72
+
73
+ conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
74
+ self.conv1d = nn.Conv1d(
75
+ in_channels=conv_dim,
76
+ out_channels=conv_dim,
77
+ bias=conv_bias,
78
+ kernel_size=d_conv,
79
+ groups=conv_dim,
80
+ padding=d_conv - 1,
81
+ **factory_kwargs,
82
+ )
83
+ if self.conv_init is not None:
84
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
85
+ # self.conv1d.weight._no_weight_decay = True
86
+
87
+ if self.learnable_init_states:
88
+ self.init_states = nn.Parameter(
89
+ torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
90
+ )
91
+ self.init_states._no_weight_decay = True
92
+
93
+ self.act = nn.SiLU()
94
+
95
+ # Initialize log dt bias
96
+ dt = torch.exp(
97
+ torch.rand(self.nheads, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ )
101
+ dt = torch.clamp(dt, min=dt_init_floor)
102
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
103
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
104
+ self.dt_bias = nn.Parameter(inv_dt)
105
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
106
+ # name.endswith("bias") in param_grouping.py
107
+ self.dt_bias._no_weight_decay = True
108
+
109
+ # A parameter
110
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
111
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
112
+ *A_init_range
113
+ )
114
+ A_log = torch.log(A).to(dtype=dtype)
115
+ self.A_log = nn.Parameter(A_log)
116
+ # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
117
+ self.A_log._no_weight_decay = True
118
+
119
+ # D "skip" parameter
120
+ self.D = nn.Parameter(torch.ones(self.nheads, device=device))
121
+ self.D._no_weight_decay = True
122
+
123
+ # Extra normalization layer right before output projection
124
+ assert RMSNormGated is not None
125
+ self.norm = RMSNormGated(
126
+ self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
127
+ )
128
+
129
+ self.out_proj = nn.Linear(
130
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
131
+ )
132
+
133
+ def forward(self, u, seq_idx=None):
134
+ """
135
+ u: (B, L, D)
136
+ Returns: same shape as u
137
+ """
138
+ batch, seqlen, dim = u.shape
139
+
140
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
141
+ A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
142
+ initial_states = (
143
+ repeat(self.init_states, "... -> b ...", b=batch)
144
+ if self.learnable_init_states
145
+ else None
146
+ )
147
+ dt_limit_kwargs = (
148
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
149
+ )
150
+
151
+ if self.use_mem_eff_path:
152
+ # Fully fused path
153
+ out = mamba_split_conv1d_scan_combined(
154
+ zxbcdt,
155
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
156
+ self.conv1d.bias,
157
+ self.dt_bias,
158
+ A,
159
+ D=self.D,
160
+ chunk_size=self.chunk_size,
161
+ seq_idx=seq_idx,
162
+ activation=self.activation,
163
+ rmsnorm_weight=self.norm.weight,
164
+ rmsnorm_eps=self.norm.eps,
165
+ outproj_weight=self.out_proj.weight,
166
+ outproj_bias=self.out_proj.bias,
167
+ headdim=self.headdim,
168
+ ngroups=self.ngroups,
169
+ norm_before_gate=False,
170
+ initial_states=initial_states,
171
+ **dt_limit_kwargs,
172
+ )
173
+ else:
174
+ z, xBC, dt = torch.split(
175
+ zxbcdt,
176
+ [
177
+ self.d_inner,
178
+ self.d_inner + 2 * self.ngroups * self.d_state,
179
+ self.nheads,
180
+ ],
181
+ dim=-1,
182
+ )
183
+ dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
184
+ assert self.activation in ["silu", "swish"]
185
+
186
+ # 1D Convolution
187
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
188
+ xBC = self.act(
189
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
190
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
191
+ xBC = xBC[:, :seqlen, :]
192
+ else:
193
+ xBC = causal_conv1d_fn(
194
+ x=xBC.transpose(1, 2),
195
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
196
+ bias=self.conv1d.bias,
197
+ activation=self.activation,
198
+ ).transpose(1, 2)
199
+
200
+ # Split into 3 main branches: X, B, C
201
+ # These correspond to V, K, Q respectively in the SSM/attention duality
202
+ x, B, C = torch.split(
203
+ xBC,
204
+ [
205
+ self.d_inner,
206
+ self.ngroups * self.d_state,
207
+ self.ngroups * self.d_state,
208
+ ],
209
+ dim=-1,
210
+ )
211
+ y = mamba_chunk_scan_combined(
212
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
213
+ dt,
214
+ A,
215
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
216
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
217
+ chunk_size=self.chunk_size,
218
+ D=self.D,
219
+ z=None,
220
+ seq_idx=seq_idx,
221
+ initial_states=initial_states,
222
+ **dt_limit_kwargs,
223
+ )
224
+ y = rearrange(y, "b l h p -> b l (h p)")
225
+
226
+ # Multiply "gate" branch and apply extra normalization layer
227
+ y = self.norm(y, z)
228
+ out = self.out_proj(y)
229
+ return out
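A hedged usage sketch for Mamba2Simple (not part of the committed file; it assumes the package is importable as mamba_ssm per this build's layout, a CUDA device, and the Triton/causal-conv1d kernels):

import torch
from mamba_ssm.modules.mamba2_simple import Mamba2Simple

device = "cuda"
layer = Mamba2Simple(d_model=256, d_state=64, headdim=64, expand=2).to(device)
u = torch.randn(2, 512, 256, device=device)   # (batch, seqlen, d_model); seqlen divisible by chunk_size
out = layer(u)                                # same shape as u
assert out.shape == u.shape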
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from ..ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(
63
+ self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
64
+ )
65
+
66
+ self.conv1d = nn.Conv1d(
67
+ in_channels=self.d_inner,
68
+ out_channels=self.d_inner,
69
+ bias=conv_bias,
70
+ kernel_size=d_conv,
71
+ groups=self.d_inner,
72
+ padding=d_conv - 1,
73
+ **factory_kwargs,
74
+ )
75
+
76
+ self.activation = "silu"
77
+ self.act = nn.SiLU()
78
+
79
+ self.x_proj = nn.Linear(
80
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
81
+ )
82
+ self.dt_proj = nn.Linear(
83
+ self.dt_rank, self.d_inner, bias=True, **factory_kwargs
84
+ )
85
+
86
+ # Initialize special dt projection to preserve variance at initialization
87
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
88
+ if dt_init == "constant":
89
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
90
+ elif dt_init == "random":
91
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
96
+ dt = torch.exp(
97
+ torch.rand(self.d_inner, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ ).clamp(min=dt_init_floor)
101
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
102
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
103
+ with torch.no_grad():
104
+ self.dt_proj.bias.copy_(inv_dt)
105
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
106
+ self.dt_proj.bias._no_reinit = True
107
+
108
+ # S4D real initialization
109
+ A = repeat(
110
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
111
+ "n -> d n",
112
+ d=self.d_inner,
113
+ ).contiguous()
114
+ A_log = torch.log(A) # Keep A_log in fp32
115
+ self.A_log = nn.Parameter(A_log)
116
+ self.A_log._no_weight_decay = True
117
+
118
+ # D "skip" parameter
119
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
120
+ self.D._no_weight_decay = True
121
+
122
+ self.out_proj = nn.Linear(
123
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
124
+ )
125
+
126
+ def forward(self, hidden_states, inference_params=None):
127
+ """
128
+ hidden_states: (B, L, D)
129
+ Returns: same shape as hidden_states
130
+ """
131
+ batch, seqlen, dim = hidden_states.shape
132
+
133
+ conv_state, ssm_state = None, None
134
+ if inference_params is not None:
135
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
136
+ if inference_params.seqlen_offset > 0:
137
+ # The states are updated inplace
138
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
139
+ return out
140
+
141
+ # We do matmul and transpose BLH -> HBL at the same time
142
+ xz = rearrange(
143
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
144
+ "d (b l) -> b d l",
145
+ l=seqlen,
146
+ )
147
+ if self.in_proj.bias is not None:
148
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
149
+
150
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
151
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
152
+ if (
153
+ self.use_fast_path
154
+ and causal_conv1d_fn is not None
155
+ and inference_params is None
156
+ ): # Doesn't support outputting the states
157
+ out = mamba_inner_fn(
158
+ xz,
159
+ self.conv1d.weight,
160
+ self.conv1d.bias,
161
+ self.x_proj.weight,
162
+ self.dt_proj.weight,
163
+ self.out_proj.weight,
164
+ self.out_proj.bias,
165
+ A,
166
+ None, # input-dependent B
167
+ None, # input-dependent C
168
+ self.D.float(),
169
+ delta_bias=self.dt_proj.bias.float(),
170
+ delta_softplus=True,
171
+ )
172
+ else:
173
+ x, z = xz.chunk(2, dim=1)
174
+ # Compute short convolution
175
+ if conv_state is not None:
176
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
177
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
178
+ conv_state.copy_(
179
+ F.pad(x, (self.d_conv - x.shape[-1], 0))
180
+ ) # Update state (B D W)
181
+ if causal_conv1d_fn is None:
182
+ x = self.act(self.conv1d(x)[..., :seqlen])
183
+ else:
184
+ assert self.activation in ["silu", "swish"]
185
+ x = causal_conv1d_fn(
186
+ x=x,
187
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
188
+ bias=self.conv1d.bias,
189
+ activation=self.activation,
190
+ )
191
+
192
+ # We're careful here about the layout, to avoid extra transposes.
193
+ # We want dt to have d as the slowest moving dimension
194
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
195
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
196
+ dt, B, C = torch.split(
197
+ x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
198
+ )
199
+ dt = self.dt_proj.weight @ dt.t()
200
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
201
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
202
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
203
+ assert self.activation in ["silu", "swish"]
204
+ y = selective_scan_fn(
205
+ x,
206
+ dt,
207
+ A,
208
+ B,
209
+ C,
210
+ self.D.float(),
211
+ z=z,
212
+ delta_bias=self.dt_proj.bias.float(),
213
+ delta_softplus=True,
214
+ return_last_state=ssm_state is not None,
215
+ )
216
+ if ssm_state is not None:
217
+ y, last_state = y
218
+ ssm_state.copy_(last_state)
219
+ y = rearrange(y, "b d l -> b l d")
220
+ out = self.out_proj(y)
221
+ return out
222
+
223
+ def step(self, hidden_states, conv_state, ssm_state):
224
+ dtype = hidden_states.dtype
225
+ assert (
226
+ hidden_states.shape[1] == 1
227
+ ), "Only support decoding with 1 token at a time for now"
228
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
229
+ x, z = xz.chunk(2, dim=-1) # (B D)
230
+
231
+ # Conv step
232
+ if causal_conv1d_update is None:
233
+ conv_state.copy_(
234
+ torch.roll(conv_state, shifts=-1, dims=-1)
235
+ ) # Update state (B D W)
236
+ conv_state[:, :, -1] = x
237
+ x = torch.sum(
238
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
239
+ ) # (B D)
240
+ if self.conv1d.bias is not None:
241
+ x = x + self.conv1d.bias
242
+ x = self.act(x).to(dtype=dtype)
243
+ else:
244
+ x = causal_conv1d_update(
245
+ x,
246
+ conv_state,
247
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
248
+ self.conv1d.bias,
249
+ self.activation,
250
+ )
251
+
252
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
253
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
254
+ # Don't add dt_bias here
255
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
256
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
257
+
258
+ # SSM step
259
+ if selective_state_update is None:
260
+ # Discretize A and B
261
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
262
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
263
+ dB = torch.einsum("bd,bn->bdn", dt, B)
264
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
265
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
266
+ y = y + self.D.to(dtype) * x
267
+ y = y * self.act(z) # (B D)
268
+ else:
269
+ y = selective_state_update(
270
+ ssm_state,
271
+ x,
272
+ dt,
273
+ A,
274
+ B,
275
+ C,
276
+ self.D,
277
+ z=z,
278
+ dt_bias=self.dt_proj.bias,
279
+ dt_softplus=True,
280
+ )
281
+
282
+ out = self.out_proj(y)
283
+ return out.unsqueeze(1), conv_state, ssm_state
284
+
285
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
286
+ device = self.out_proj.weight.device
287
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
288
+ conv_state = torch.zeros(
289
+ batch_size,
290
+ self.d_model * self.expand,
291
+ self.d_conv,
292
+ device=device,
293
+ dtype=conv_dtype,
294
+ )
295
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
296
+ # ssm_dtype = torch.float32
297
+ ssm_state = torch.zeros(
298
+ batch_size,
299
+ self.d_model * self.expand,
300
+ self.d_state,
301
+ device=device,
302
+ dtype=ssm_dtype,
303
+ )
304
+ return conv_state, ssm_state
305
+
306
+ def _get_states_from_cache(
307
+ self, inference_params, batch_size, initialize_states=False
308
+ ):
309
+ assert self.layer_idx is not None
310
+ if self.layer_idx not in inference_params.key_value_memory_dict:
311
+ batch_shape = (batch_size,)
312
+ conv_state = torch.zeros(
313
+ batch_size,
314
+ self.d_model * self.expand,
315
+ self.d_conv,
316
+ device=self.conv1d.weight.device,
317
+ dtype=self.conv1d.weight.dtype,
318
+ )
319
+ ssm_state = torch.zeros(
320
+ batch_size,
321
+ self.d_model * self.expand,
322
+ self.d_state,
323
+ device=self.dt_proj.weight.device,
324
+ dtype=self.dt_proj.weight.dtype,
325
+ # dtype=torch.float32,
326
+ )
327
+ inference_params.key_value_memory_dict[self.layer_idx] = (
328
+ conv_state,
329
+ ssm_state,
330
+ )
331
+ else:
332
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
333
+ self.layer_idx
334
+ ]
335
+ # TODO: What if batch size changes between generation, and we reuse the same states?
336
+ if initialize_states:
337
+ conv_state.zero_()
338
+ ssm_state.zero_()
339
+ return conv_state, ssm_state
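A hedged usage sketch for the Mamba (v1) block above (illustrative, not part of the commit; it assumes a CUDA device and, for the fast path, the causal-conv1d and selective-scan kernels):

import torch
from mamba_ssm.modules.mamba_simple import Mamba

device = "cuda"
block = Mamba(d_model=256, d_state=16, d_conv=4, expand=2, layer_idx=0).to(device)
x = torch.randn(2, 128, 256, device=device)   # (batch, seqlen, d_model)
y = block(x)                                  # same shape as x
assert y.shape == x.shape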
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mha.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ try:
11
+ from flash_attn import flash_attn_with_kvcache
12
+ except ImportError:
13
+ flash_attn_with_kvcache = None
14
+
15
+ try:
16
+ from flash_attn.layers.rotary import RotaryEmbedding
17
+ except ImportError:
18
+ RotaryEmbedding = None
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn, causal_conv1d_update = None, None
24
+
25
+
26
+ def _update_kv_cache(kv, inference_params, layer_idx):
27
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
28
+ # Pre-allocate memory for key-values for inference.
29
+ num_heads, head_dim = kv.shape[-2:]
30
+ assert layer_idx in inference_params.key_value_memory_dict
31
+ kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
32
+ # Adjust key and value for inference
33
+ batch_start = inference_params.batch_size_offset
34
+ batch_end = batch_start + kv.shape[0]
35
+ sequence_start = inference_params.seqlen_offset
36
+ sequence_end = sequence_start + kv.shape[1]
37
+ assert batch_end <= kv_cache.shape[0]
38
+ assert sequence_end <= kv_cache.shape[1]
39
+ assert kv_cache is not None
40
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
41
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
42
+
43
+
44
+ class MHA(nn.Module):
45
+ """Multi-head self-attention and cross-attention"""
46
+
47
+ def __init__(
48
+ self,
49
+ embed_dim,
50
+ num_heads,
51
+ num_heads_kv=None,
52
+ head_dim=None, # If None, use embed_dim // num_heads
53
+ mlp_dim=0,
54
+ qkv_proj_bias=True,
55
+ out_proj_bias=True,
56
+ softmax_scale=None,
57
+ causal=False,
58
+ layer_idx=None,
59
+ d_conv=0,
60
+ rotary_emb_dim=0,
61
+ rotary_emb_base=10000.0,
62
+ rotary_emb_interleaved=False,
63
+ device=None,
64
+ dtype=None,
65
+ ) -> None:
66
+ """
67
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
68
+ return_residual: whether to return the input x along with the output. This is for
69
+ performance reason: for post-norm architecture, returning the input allows us
70
+ to fuse the backward of nn.Linear with the residual connection.
71
+ """
72
+ factory_kwargs = {"device": device, "dtype": dtype}
73
+ super().__init__()
74
+ self.embed_dim = embed_dim
75
+ self.layer_idx = layer_idx
76
+ self.d_conv = d_conv
77
+ self.rotary_emb_dim = rotary_emb_dim
78
+ self.softmax_scale = softmax_scale
79
+ self.causal = causal
80
+
81
+ self.num_heads = num_heads
82
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
83
+ assert (
84
+ self.num_heads % self.num_heads_kv == 0
85
+ ), "num_heads must be divisible by num_heads_kv"
86
+ if head_dim is None:
87
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
88
+ self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
89
+ self.mlp_dim = math.ceil(mlp_dim / 256) * 256
90
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
91
+ out_dim = self.head_dim * self.num_heads
92
+
93
+ if self.rotary_emb_dim > 0:
94
+ assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
95
+ self.rotary_emb = RotaryEmbedding(
96
+ self.rotary_emb_dim,
97
+ base=rotary_emb_base,
98
+ interleaved=rotary_emb_interleaved,
99
+ device=device,
100
+ )
101
+
102
+ self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
103
+ if self.d_conv > 0:
104
+ self.conv1d = nn.Conv1d(
105
+ qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
106
+ **factory_kwargs
107
+ )
108
+ self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
109
+
110
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
111
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
112
+ device = self.out_proj.weight.device
113
+ if self.d_conv > 0:
114
+ conv_state = torch.zeros(
115
+ batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
116
+ )
117
+ else:
118
+ conv_state = None
119
+ kv_cache = torch.empty(
120
+ batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
121
+ )
122
+ return kv_cache, conv_state
123
+
124
+ def _update_kv_cache(self, kv, inference_params):
125
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
126
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
127
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
128
+
129
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
130
+ """
131
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
132
+ q: (batch_size, seqlen_q, nheads, head_dim)
133
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
134
+ """
135
+ assert inference_params is not None and inference_params.seqlen_offset > 0
136
+ if self.rotary_emb_dim > 0:
137
+ self.rotary_emb._update_cos_sin_cache(
138
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
139
+ )
140
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
141
+ else:
142
+ rotary_cos, rotary_sin = None, None
143
+ batch = q.shape[0]
144
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
145
+ kv_cache = kv_cache[:batch]
146
+ cache_seqlens = (
147
+ inference_params.lengths_per_sample[:batch]
148
+ if inference_params.lengths_per_sample is not None
149
+ else inference_params.seqlen_offset
150
+ )
151
+ assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
152
+ context = flash_attn_with_kvcache(
153
+ q,
154
+ kv_cache[:, :, 0],
155
+ kv_cache[:, :, 1],
156
+ kv[:, :, 0],
157
+ kv[:, :, 1],
158
+ rotary_cos=rotary_cos,
159
+ rotary_sin=rotary_sin,
160
+ cache_seqlens=cache_seqlens,
161
+ softmax_scale=self.softmax_scale,
162
+ causal=self.causal,
163
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
164
+ )
165
+ return context
166
+
167
+ def _update_kvcache_attention(self, q, kv, inference_params):
168
+ """Write kv to inference_params, then do attention"""
169
+ if (
170
+ inference_params.seqlen_offset == 0
171
+ or flash_attn_with_kvcache is None
172
+ ):
173
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
174
+ kv = self._update_kv_cache(kv, inference_params)
175
+ k, v = kv.unbind(dim=-3)
176
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
177
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
178
+ return F.scaled_dot_product_attention(
179
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
180
+ ).transpose(1, 2)
181
+ else:
182
+ batch = q.shape[0]
183
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184
+ kv_cache = kv_cache[:batch]
185
+ cache_seqlens = (
186
+ inference_params.lengths_per_sample[:batch]
187
+ if inference_params.lengths_per_sample is not None
188
+ else inference_params.seqlen_offset
189
+ )
190
+ return flash_attn_with_kvcache(
191
+ q,
192
+ kv_cache[:, :, 0],
193
+ kv_cache[:, :, 1],
194
+ kv[:, :, 0],
195
+ kv[:, :, 1],
196
+ cache_seqlens=cache_seqlens,
197
+ softmax_scale=self.softmax_scale,
198
+ causal=self.causal,
199
+ )
200
+
201
+ def forward(self, x, inference_params=None):
202
+ """
203
+ Arguments:
204
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
205
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
206
+ is the is the sum of the sequence lengths in the batch.
207
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
208
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
209
+ """
210
+ if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
211
+ inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
212
+ x.shape[0], inference_params.max_seqlen, dtype=x.dtype
213
+ )
214
+ seqlen_offset = (
215
+ 0
216
+ if inference_params is None
217
+ else (
218
+ inference_params.lengths_per_sample
219
+ if inference_params.lengths_per_sample is not None
220
+ else inference_params.seqlen_offset
221
+ )
222
+ )
223
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
224
+ qkv = self.in_proj(x)
225
+ if self.mlp_dim > 0:
226
+ qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
227
+ x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
228
+ x_mlp = x_mlp_up * F.silu(x_mlp_gate)
229
+ if self.d_conv > 0:
230
+ # The inference code for conv1d is pretty messy, should clean it up
231
+ if (inference_params is None or inference_params.seqlen_offset == 0):
232
+ if causal_conv1d_fn is None:
233
+ qkv = rearrange(
234
+ self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
235
+ ).contiguous()
236
+ else:
237
+ qkv = causal_conv1d_fn(
238
+ qkv.transpose(1, 2),
239
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
240
+ self.conv1d.bias
241
+ ).transpose(1, 2)
242
+ if inference_params is not None:
243
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
244
+ # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
245
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
246
+ qkv_t = rearrange(qkv, "b l d -> b d l")
247
+ conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
248
+ else:
249
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
250
+ assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
251
+ qkv = qkv.squeeze(1)
252
+ # Conv step
253
+ if causal_conv1d_update is None:
254
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
255
+ conv_state[:, :, -1] = qkv
256
+ qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
257
+ if self.conv1d.bias is not None:
258
+ qkv = qkv + self.conv1d.bias
259
+ else:
260
+ qkv = causal_conv1d_update(
261
+ qkv,
262
+ conv_state,
263
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
264
+ self.conv1d.bias
265
+ )
266
+ qkv = qkv.unsqueeze(1)
267
+ q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
268
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
269
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
270
+ if (
271
+ inference_params is None
272
+ or inference_params.seqlen_offset == 0
273
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
274
+ ):
275
+ if self.rotary_emb_dim > 0:
276
+ q, kv = self.rotary_emb(
277
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
278
+ )
279
+ if inference_params is None:
280
+ k, v = kv.unbind(dim=-3)
281
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
282
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
283
+ context = F.scaled_dot_product_attention(
284
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
285
+ ).transpose(1, 2)
286
+ else:
287
+ context = self._update_kvcache_attention(q, kv, inference_params)
288
+ else:
289
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
290
+ context = rearrange(context, "... h d -> ... (h d)")
291
+ if self.mlp_dim > 0:
292
+ context = torch.cat([context, x_mlp], dim=-1)
293
+ out = self.out_proj(context)
294
+ return out
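A hedged usage sketch for MHA (illustrative, not part of the commit). With no inference_params, no rotary embedding, and d_conv=0, the forward falls back to F.scaled_dot_product_attention, so flash-attn is not required for this path; the sizes below are hypothetical:

import torch
from mamba_ssm.modules.mha import MHA

attn = MHA(embed_dim=256, num_heads=8, num_heads_kv=2, causal=True, layer_idx=0)
x = torch.randn(2, 128, 256)     # (batch, seqlen, embed_dim)
y = attn(x)                      # grouped-query attention via repeat_interleave + SDPA
assert y.shape == x.shape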
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mlp.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class GatedMLP(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features,
10
+ hidden_features=None,
11
+ out_features=None,
12
+ activation=F.silu,
13
+ bias=False,
14
+ multiple_of=128,
15
+ device=None,
16
+ dtype=None,
17
+ ):
18
+ factory_kwargs = {"device": device, "dtype": dtype}
19
+ super().__init__()
20
+ out_features = out_features if out_features is not None else in_features
21
+ hidden_features = (
22
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
23
+ )
24
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
25
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
26
+ self.activation = activation
27
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
28
+
29
+ def forward(self, x):
30
+ y = self.fc1(x)
31
+ y, gate = y.chunk(2, dim=-1)
32
+ y = y * self.activation(gate)
33
+ y = self.fc2(y)
34
+ return y
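A small sketch of GatedMLP's sizing (illustrative, not part of the commit): the default hidden width is 8*in_features/3 rounded up to multiple_of, and fc1 emits both the value and the gate, so its output is twice that width.

import torch
from mamba_ssm.modules.mlp import GatedMLP

mlp = GatedMLP(in_features=768)            # hidden = round_up(8*768/3, 128) = 2048
assert mlp.fc1.out_features == 2 * 2048 and mlp.fc2.in_features == 2048
y = mlp(torch.randn(4, 768))
assert y.shape == (4, 768)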
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/ssd_minimal.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2024, Albert Gu and Tri Dao.
2
+ """Minimal implementation of SSD.
3
+
4
+ This is the same as Listing 1 from the paper.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
12
+
13
+
14
+ def segsum_unstable(x):
15
+ """Naive segment sum calculation."""
16
+ T = x.size(-1)
17
+ x_cumsum = torch.cumsum(x, dim=-1)
18
+ x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
19
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
20
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
21
+ return x_segsum
22
+
23
+
24
+ def segsum(x):
25
+ """More stable segment sum calculation."""
26
+ T = x.size(-1)
27
+ x = repeat(x, "... d -> ... d e", e=T)
28
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
29
+ x = x.masked_fill(~mask, 0)
30
+ x_segsum = torch.cumsum(x, dim=-2)
31
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
32
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
33
+ return x_segsum
34
+
35
+
36
+ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
37
+ """
38
+ Arguments:
39
+ X: (batch, length, n_heads, d_head)
40
+ A: (batch, length, n_heads)
41
+ B: (batch, length, n_heads, d_state)
42
+ C: (batch, length, n_heads, d_state)
43
+ Return:
44
+ Y: (batch, length, n_heads, d_head)
45
+ """
46
+ assert X.dtype == A.dtype == B.dtype == C.dtype
47
+ assert X.shape[1] % block_len == 0
48
+
49
+ # Rearrange into blocks/chunks
50
+ X, A, B, C = [
51
+ rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
52
+ ]
53
+
54
+ A = rearrange(A, "b c l h -> b h c l")
55
+ A_cumsum = torch.cumsum(A, dim=-1)
56
+
57
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
58
+ L = torch.exp(segsum(A))
59
+ Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
60
+
61
+ # 2. Compute the state for each intra-chunk
62
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
63
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
64
+ states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
65
+
66
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
67
+ # (middle term of factorization of off-diag blocks; A terms)
68
+ if initial_states is None:
69
+ initial_states = torch.zeros_like(states[:, :1])
70
+ states = torch.cat([initial_states, states], dim=1)
71
+ decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
72
+ new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
73
+ states, final_state = new_states[:, :-1], new_states[:, -1]
74
+
75
+ # 4. Compute state -> output conversion per chunk
76
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
77
+ state_decay_out = torch.exp(A_cumsum)
78
+ Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
79
+
80
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
81
+ Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
82
+ return Y, final_state
83
+
84
+
85
+ # Simple test
86
+ def test_correctness():
87
+ torch.manual_seed(42)
88
+
89
+ ## Dimensions
90
+ # Denoted (B, T, Q, D, P) in the paper
91
+ batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
92
+ nheads = dim // headdim # (H) in the paper
93
+ ngroups = 1 # (G) in the paper
94
+ dstate = 64 # (N) in the paper
95
+ dtype = torch.float32
96
+ device = "cuda"
97
+
98
+ x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
99
+ dt = F.softplus(
100
+ torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
101
+ ).requires_grad_()
102
+ A = (
103
+ -torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
104
+ ).requires_grad_()
105
+ B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
106
+ C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
107
+ D = torch.randn(nheads, dtype=dtype, device=device)
108
+
109
+ # Comparing fused version and minimal version
110
+ y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
111
+ y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)
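A hedged sketch exercising ssd_minimal_discrete directly (illustrative, not part of the commit; importing the module still pulls in the package's Triton ops, but the function itself uses only plain PyTorch). A full correctness check would typically end with a torch.allclose comparison against the fused kernel's output:

import torch
import torch.nn.functional as F
from mamba_ssm.modules.ssd_minimal import ssd_minimal_discrete

batch, seqlen, nheads, headdim, dstate, block_len = 1, 256, 4, 32, 16, 64
X = torch.randn(batch, seqlen, nheads, headdim)
A = -F.softplus(torch.randn(batch, seqlen, nheads))   # negative log-decay per step
B = torch.randn(batch, seqlen, nheads, dstate)
C = torch.randn(batch, seqlen, nheads, dstate)
Y, final_state = ssd_minimal_discrete(X, A, B, C, block_len)
assert Y.shape == X.shape and final_state.shape == (batch, nheads, headdim, dstate)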
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/__init__.py ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py ADDED
@@ -0,0 +1,659 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from ..utils.torch import custom_fwd, custom_bwd
6
+
7
+ from einops import rearrange, repeat
8
+
9
+ try:
10
+ from causal_conv1d import causal_conv1d_fn
11
+ import causal_conv1d_cuda
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
16
+ from .triton.layer_norm import _layer_norm_fwd
17
+
18
+ from .._ops import ops
19
+
20
+
21
+ class SelectiveScanFn(torch.autograd.Function):
22
+
23
+ @staticmethod
24
+ def forward(
25
+ ctx,
26
+ u,
27
+ delta,
28
+ A,
29
+ B,
30
+ C,
31
+ D=None,
32
+ z=None,
33
+ delta_bias=None,
34
+ delta_softplus=False,
35
+ return_last_state=False,
36
+ ):
37
+ if u.stride(-1) != 1:
38
+ u = u.contiguous()
39
+ if delta.stride(-1) != 1:
40
+ delta = delta.contiguous()
41
+ if D is not None:
42
+ D = D.contiguous()
43
+ if B.stride(-1) != 1:
44
+ B = B.contiguous()
45
+ if C.stride(-1) != 1:
46
+ C = C.contiguous()
47
+ if z is not None and z.stride(-1) != 1:
48
+ z = z.contiguous()
49
+ if B.dim() == 3:
50
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
51
+ ctx.squeeze_B = True
52
+ if C.dim() == 3:
53
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
54
+ ctx.squeeze_C = True
55
+ out, x, *rest = ops.selective_scan_fwd(
56
+ u, delta, A, B, C, D, z, delta_bias, delta_softplus
57
+ )
58
+ ctx.delta_softplus = delta_softplus
59
+ ctx.has_z = z is not None
60
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
61
+ if not ctx.has_z:
62
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
63
+ return out if not return_last_state else (out, last_state)
64
+ else:
65
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
66
+ out_z = rest[0]
67
+ return out_z if not return_last_state else (out_z, last_state)
68
+
69
+ @staticmethod
70
+ def backward(ctx, dout, *args):
71
+ if not ctx.has_z:
72
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
73
+ z = None
74
+ out = None
75
+ else:
76
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
77
+ if dout.stride(-1) != 1:
78
+ dout = dout.contiguous()
79
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
80
+ # backward of selective_scan_cuda with the backward of chunk).
81
+ # Here we just pass in None and dz will be allocated in the C++ code.
82
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
83
+ u,
84
+ delta,
85
+ A,
86
+ B,
87
+ C,
88
+ D,
89
+ z,
90
+ delta_bias,
91
+ dout,
92
+ x,
93
+ out,
94
+ None,
95
+ ctx.delta_softplus,
96
+ False, # option to recompute out_z, not used here
97
+ )
98
+ dz = rest[0] if ctx.has_z else None
99
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
100
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
101
+ return (
102
+ du,
103
+ ddelta,
104
+ dA,
105
+ dB,
106
+ dC,
107
+ dD if D is not None else None,
108
+ dz,
109
+ ddelta_bias if delta_bias is not None else None,
110
+ None,
111
+ None,
112
+ )
113
+
114
+
115
+ def rms_norm_forward(
116
+ x,
117
+ weight,
118
+ bias,
119
+ eps=1e-6,
120
+ is_rms_norm=True,
121
+ ):
122
+ # x (b l) d
123
+ if x.stride(-1) != 1:
124
+ x = x.contiguous()
125
+ weight = weight.contiguous()
126
+ if bias is not None:
127
+ bias = bias.contiguous()
128
+ y = _layer_norm_fwd(
129
+ x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
130
+ )[0]
131
+ # y (b l) d
132
+ return y
133
+
134
+
135
+ def selective_scan_fn(
136
+ u,
137
+ delta,
138
+ A,
139
+ B,
140
+ C,
141
+ D=None,
142
+ z=None,
143
+ delta_bias=None,
144
+ delta_softplus=False,
145
+ return_last_state=False,
146
+ ):
147
+ """if return_last_state is True, returns (out, last_state)
148
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
149
+ not considered in the backward pass.
150
+ """
151
+ return SelectiveScanFn.apply(
152
+ u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
153
+ )
154
+
155
+
156
+ def selective_scan_ref(
157
+ u,
158
+ delta,
159
+ A,
160
+ B,
161
+ C,
162
+ D=None,
163
+ z=None,
164
+ delta_bias=None,
165
+ delta_softplus=False,
166
+ return_last_state=False,
167
+ ):
168
+ """
169
+ u: r(B D L)
170
+ delta: r(B D L)
171
+ A: c(D N) or r(D N)
172
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
173
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
174
+ D: r(D)
175
+ z: r(B D L)
176
+ delta_bias: r(D), fp32
177
+
178
+ out: r(B D L)
179
+ last_state (optional): r(B D dstate) or c(B D dstate)
180
+ """
181
+ dtype_in = u.dtype
182
+ u = u.float()
183
+ delta = delta.float()
184
+ if delta_bias is not None:
185
+ delta = delta + delta_bias[..., None].float()
186
+ if delta_softplus:
187
+ delta = F.softplus(delta)
188
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
189
+ is_variable_B = B.dim() >= 3
190
+ is_variable_C = C.dim() >= 3
191
+ if A.is_complex():
192
+ if is_variable_B:
193
+ B = torch.view_as_complex(
194
+ rearrange(B.float(), "... (L two) -> ... L two", two=2)
195
+ )
196
+ if is_variable_C:
197
+ C = torch.view_as_complex(
198
+ rearrange(C.float(), "... (L two) -> ... L two", two=2)
199
+ )
200
+ else:
201
+ B = B.float()
202
+ C = C.float()
203
+ x = A.new_zeros((batch, dim, dstate))
204
+ ys = []
205
+ deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
206
+ if not is_variable_B:
207
+ deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
208
+ else:
209
+ if B.dim() == 3:
210
+ deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
211
+ else:
212
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
213
+ deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
214
+ if is_variable_C and C.dim() == 4:
215
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
216
+ last_state = None
217
+ for i in range(u.shape[2]):
218
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
219
+ if not is_variable_C:
220
+ y = torch.einsum("bdn,dn->bd", x, C)
221
+ else:
222
+ if C.dim() == 3:
223
+ y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
224
+ else:
225
+ y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
226
+ if i == u.shape[2] - 1:
227
+ last_state = x
228
+ if y.is_complex():
229
+ y = y.real * 2
230
+ ys.append(y)
231
+ y = torch.stack(ys, dim=2) # (batch dim L)
232
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
233
+ if z is not None:
234
+ out = out * F.silu(z)
235
+ out = out.to(dtype=dtype_in)
236
+ return out if not return_last_state else (out, last_state)
237
+
238
+
239
+ class MambaInnerFn(torch.autograd.Function):
240
+
241
+ @staticmethod
242
+ @custom_fwd
243
+ def forward(
244
+ ctx,
245
+ xz,
246
+ conv1d_weight,
247
+ conv1d_bias,
248
+ x_proj_weight,
249
+ delta_proj_weight,
250
+ out_proj_weight,
251
+ out_proj_bias,
252
+ A,
253
+ B=None,
254
+ C=None,
255
+ D=None,
256
+ delta_bias=None,
257
+ B_proj_bias=None,
258
+ C_proj_bias=None,
259
+ delta_softplus=True,
260
+ checkpoint_lvl=1,
261
+ b_rms_weight=None,
262
+ c_rms_weight=None,
263
+ dt_rms_weight=None,
264
+ b_c_dt_rms_eps=1e-6,
265
+ ):
266
+ """
267
+ xz: (batch, dim, seqlen)
268
+ """
269
+ assert (
270
+ causal_conv1d_cuda is not None
271
+ ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
272
+ assert checkpoint_lvl in [0, 1]
273
+ L = xz.shape[-1]
274
+ delta_rank = delta_proj_weight.shape[1]
275
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
276
+ if torch.is_autocast_enabled():
277
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
278
+ delta_proj_weight = delta_proj_weight.to(
279
+ dtype=torch.get_autocast_gpu_dtype()
280
+ )
281
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
282
+ out_proj_bias = (
283
+ out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
284
+ if out_proj_bias is not None
285
+ else None
286
+ )
287
+ if xz.stride(-1) != 1:
288
+ xz = xz.contiguous()
289
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
290
+ x, z = xz.chunk(2, dim=1)
291
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
292
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
293
+ x, conv1d_weight, conv1d_bias, None, None, None, True
294
+ )
295
+ # We're being very careful here about the layout, to avoid extra transposes.
296
+ # We want delta to have d as the slowest moving dimension
297
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
298
+ x_dbl = F.linear(
299
+ rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
300
+ ) # (bl d)
301
+ delta = rearrange(
302
+ delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
303
+ )
304
+ ctx.is_variable_B = B is None
305
+ ctx.is_variable_C = C is None
306
+ ctx.B_proj_bias_is_None = B_proj_bias is None
307
+ ctx.C_proj_bias_is_None = C_proj_bias is None
308
+ if B is None: # variable B
309
+ B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
310
+ if B_proj_bias is not None:
311
+ B = B + B_proj_bias.to(dtype=B.dtype)
312
+ if not A.is_complex():
313
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
314
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
315
+ else:
316
+ B = rearrange(
317
+ B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
318
+ ).contiguous()
319
+ else:
320
+ if B.stride(-1) != 1:
321
+ B = B.contiguous()
322
+ if C is None: # variable C
323
+ C = x_dbl[:, -d_state:] # (bl dstate)
324
+ if C_proj_bias is not None:
325
+ C = C + C_proj_bias.to(dtype=C.dtype)
326
+ if not A.is_complex():
327
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
328
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
329
+ else:
330
+ C = rearrange(
331
+ C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
332
+ ).contiguous()
333
+ else:
334
+ if C.stride(-1) != 1:
335
+ C = C.contiguous()
336
+ if D is not None:
337
+ D = D.contiguous()
338
+
339
+ if b_rms_weight is not None:
340
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
341
+ B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
342
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
343
+ if c_rms_weight is not None:
344
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
345
+ C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
346
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
347
+ if dt_rms_weight is not None:
348
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
349
+ delta = rms_norm_forward(
350
+ delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
351
+ )
352
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
353
+
354
+ out, scan_intermediates, out_z = ops.selective_scan_fwd(
355
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
356
+ )
357
+ ctx.delta_softplus = delta_softplus
358
+ ctx.out_proj_bias_is_None = out_proj_bias is None
359
+ ctx.checkpoint_lvl = checkpoint_lvl
360
+ ctx.b_rms_weight = b_rms_weight
361
+ ctx.c_rms_weight = c_rms_weight
362
+ ctx.dt_rms_weight = dt_rms_weight
363
+ ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
364
+ if (
365
+ checkpoint_lvl >= 1
366
+ ): # Will recompute conv1d_out and delta in the backward pass
367
+ conv1d_out, delta = None, None
368
+ ctx.save_for_backward(
369
+ xz,
370
+ conv1d_weight,
371
+ conv1d_bias,
372
+ x_dbl,
373
+ x_proj_weight,
374
+ delta_proj_weight,
375
+ out_proj_weight,
376
+ conv1d_out,
377
+ delta,
378
+ A,
379
+ B,
380
+ C,
381
+ D,
382
+ delta_bias,
383
+ scan_intermediates,
384
+ b_rms_weight,
385
+ c_rms_weight,
386
+ dt_rms_weight,
387
+ out,
388
+ )
389
+ return F.linear(
390
+ rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
391
+ )
392
+
393
+ @staticmethod
394
+ @custom_bwd
395
+ def backward(ctx, dout):
396
+ # dout: (batch, seqlen, dim)
397
+ assert (
398
+ causal_conv1d_cuda is not None
399
+ ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
400
+ (
401
+ xz,
402
+ conv1d_weight,
403
+ conv1d_bias,
404
+ x_dbl,
405
+ x_proj_weight,
406
+ delta_proj_weight,
407
+ out_proj_weight,
408
+ conv1d_out,
409
+ delta,
410
+ A,
411
+ B,
412
+ C,
413
+ D,
414
+ delta_bias,
415
+ scan_intermediates,
416
+ b_rms_weight,
417
+ c_rms_weight,
418
+ dt_rms_weight,
419
+ out,
420
+ ) = ctx.saved_tensors
421
+ L = xz.shape[-1]
422
+ delta_rank = delta_proj_weight.shape[1]
423
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
424
+ x, z = xz.chunk(2, dim=1)
425
+ if dout.stride(-1) != 1:
426
+ dout = dout.contiguous()
427
+ if ctx.checkpoint_lvl == 1:
428
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
429
+ x, conv1d_weight, conv1d_bias, None, None, None, True
430
+ )
431
+ delta = rearrange(
432
+ delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
433
+ )
434
+ if dt_rms_weight is not None:
435
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
436
+ delta = rms_norm_forward(
437
+ delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
438
+ )
439
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
440
+ if b_rms_weight is not None:
441
+ # Recompute & RMSNorm B
442
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
443
+ B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
444
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
445
+ if c_rms_weight is not None:
446
+ # Recompute & RMSNorm C
447
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
448
+ C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
449
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
450
+
451
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
452
+ # backward of selective_scan_cuda with the backward of chunk).
453
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
454
+ dx, dz = dxz.chunk(2, dim=1)
455
+ dout = rearrange(dout, "b l e -> e (b l)")
456
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
457
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
458
+ ops.selective_scan_bwd(
459
+ conv1d_out,
460
+ delta,
461
+ A,
462
+ B,
463
+ C,
464
+ D,
465
+ z,
466
+ delta_bias,
467
+ dout_y,
468
+ scan_intermediates,
469
+ out,
470
+ dz,
471
+ ctx.delta_softplus,
472
+ True, # option to recompute out_z
473
+ )
474
+ )
475
+ dout_proj_weight = torch.einsum(
476
+ "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
477
+ )
478
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
479
+ dD = dD if D is not None else None
480
+ dx_dbl = torch.empty_like(x_dbl)
481
+ dB_proj_bias = None
482
+ if ctx.is_variable_B:
483
+ if not A.is_complex():
484
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
485
+ else:
486
+ dB = rearrange(
487
+ dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
488
+ ).contiguous()
489
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
490
+ dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
491
+ dB = None
492
+ dC_proj_bias = None
493
+ if ctx.is_variable_C:
494
+ if not A.is_complex():
495
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
496
+ else:
497
+ dC = rearrange(
498
+ dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
499
+ ).contiguous()
500
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
501
+ dx_dbl[:, -d_state:] = dC # (bl d)
502
+ dC = None
503
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
504
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
505
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
506
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
507
+ dx_proj_weight = torch.einsum(
508
+ "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
509
+ )
510
+ dconv1d_out = torch.addmm(
511
+ dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
512
+ )
513
+ dconv1d_out = rearrange(
514
+ dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
515
+ )
516
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
517
+ # backward of conv1d with the backward of chunk).
518
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
519
+ x,
520
+ conv1d_weight,
521
+ conv1d_bias,
522
+ dconv1d_out,
523
+ None,
524
+ None,
525
+ None,
526
+ dx,
527
+ False,
528
+ True,
529
+ )
530
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
531
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
532
+ return (
533
+ dxz,
534
+ dconv1d_weight,
535
+ dconv1d_bias,
536
+ dx_proj_weight,
537
+ ddelta_proj_weight,
538
+ dout_proj_weight,
539
+ dout_proj_bias,
540
+ dA,
541
+ dB,
542
+ dC,
543
+ dD,
544
+ ddelta_bias if delta_bias is not None else None,
545
+ # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
546
+ dB_proj_bias,
547
+ dC_proj_bias,
548
+ None,
549
+ None,
550
+ None,
551
+ None,
552
+ None,
553
+ None,
554
+ )
555
+
556
+
557
+ def mamba_inner_fn(
558
+ xz,
559
+ conv1d_weight,
560
+ conv1d_bias,
561
+ x_proj_weight,
562
+ delta_proj_weight,
563
+ out_proj_weight,
564
+ out_proj_bias,
565
+ A,
566
+ B=None,
567
+ C=None,
568
+ D=None,
569
+ delta_bias=None,
570
+ B_proj_bias=None,
571
+ C_proj_bias=None,
572
+ delta_softplus=True,
573
+ checkpoint_lvl=1,
574
+ b_rms_weight=None,
575
+ c_rms_weight=None,
576
+ dt_rms_weight=None,
577
+ b_c_dt_rms_eps=1e-6,
578
+ ):
579
+ return MambaInnerFn.apply(
580
+ xz,
581
+ conv1d_weight,
582
+ conv1d_bias,
583
+ x_proj_weight,
584
+ delta_proj_weight,
585
+ out_proj_weight,
586
+ out_proj_bias,
587
+ A,
588
+ B,
589
+ C,
590
+ D,
591
+ delta_bias,
592
+ B_proj_bias,
593
+ C_proj_bias,
594
+ delta_softplus,
595
+ checkpoint_lvl,
596
+ b_rms_weight,
597
+ c_rms_weight,
598
+ dt_rms_weight,
599
+ b_c_dt_rms_eps,
600
+ )
601
+
602
+
603
+ def mamba_inner_ref(
604
+ xz,
605
+ conv1d_weight,
606
+ conv1d_bias,
607
+ x_proj_weight,
608
+ delta_proj_weight,
609
+ out_proj_weight,
610
+ out_proj_bias,
611
+ A,
612
+ B=None,
613
+ C=None,
614
+ D=None,
615
+ delta_bias=None,
616
+ B_proj_bias=None,
617
+ C_proj_bias=None,
618
+ delta_softplus=True,
619
+ ):
620
+ assert (
621
+ causal_conv1d_fn is not None
622
+ ), "causal_conv1d_fn is not available. Please install causal-conv1d."
623
+ L = xz.shape[-1]
624
+ delta_rank = delta_proj_weight.shape[1]
625
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
626
+ x, z = xz.chunk(2, dim=1)
627
+ x = causal_conv1d_fn(
628
+ x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
629
+ )
630
+ # We're being very careful here about the layout, to avoid extra transposes.
631
+ # We want delta to have d as the slowest moving dimension
632
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
633
+ x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d)
634
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
635
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
636
+ if B is None: # variable B
637
+ B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d)
638
+ if B_proj_bias is not None:
639
+ B = B + B_proj_bias.to(dtype=B.dtype)
640
+ if not A.is_complex():
641
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
642
+ else:
643
+ B = rearrange(
644
+ B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
645
+ ).contiguous()
646
+ if C is None: # variable C
647
+ C = x_dbl[:, -d_state:] # (bl d)
648
+ if C_proj_bias is not None:
649
+ C = C + C_proj_bias.to(dtype=C.dtype)
650
+ if not A.is_complex():
651
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
652
+ else:
653
+ C = rearrange(
654
+ C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
655
+ ).contiguous()
656
+ y = selective_scan_fn(
657
+ x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
658
+ )
659
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
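A minimal usage sketch of the fused path above, under the shape conventions implied by mamba_inner_ref (xz stacks the x and z branches along the channel dimension; x_proj_weight packs the dt, B and C projections). The tensor sizes, the import path, and the CUDA/fp16 setup here are illustrative assumptions, not part of this commit.

import torch
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn

b, l, d_model = 2, 64, 128                      # batch, sequence length, model width
d_inner, d_state, dt_rank, d_conv = 2 * d_model, 16, 8, 4
dev, wdt = "cuda", torch.float16

xz = torch.randn(b, 2 * d_inner, l, device=dev, dtype=wdt, requires_grad=True)
conv1d_weight = torch.randn(d_inner, 1, d_conv, device=dev, dtype=wdt, requires_grad=True)
conv1d_bias = torch.randn(d_inner, device=dev, dtype=wdt, requires_grad=True)
x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner, device=dev, dtype=wdt, requires_grad=True)
delta_proj_weight = torch.randn(d_inner, dt_rank, device=dev, dtype=wdt, requires_grad=True)
out_proj_weight = torch.randn(d_model, d_inner, device=dev, dtype=wdt, requires_grad=True)
A = -torch.rand(d_inner, d_state, device=dev, dtype=torch.float32)  # real negative A
A.requires_grad_()
D = torch.ones(d_inner, device=dev, dtype=torch.float32, requires_grad=True)
delta_bias = torch.zeros(d_inner, device=dev, dtype=torch.float32, requires_grad=True)

# B and C are left as None so they are read out of x_proj_weight (the variable-B/C path).
out = mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, None, A, None, None, D, delta_bias, delta_softplus=True,
)
out.sum().backward()                            # out has shape (b, l, d_model)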
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/__init__.py ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/k_activations.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
+ @triton.autotune(
10
+ configs=[
11
+ triton.Config({'BLOCK_N': 32}),
12
+ triton.Config({'BLOCK_N': 64}),
13
+ triton.Config({'BLOCK_N': 128}),
14
+ triton.Config({'BLOCK_N': 256}),
15
+ triton.Config({'BLOCK_N': 512}),
16
+ triton.Config({'BLOCK_N': 1024}),
17
+ ],
18
+ key=['ncols'],
19
+ )
20
+ @triton.jit
21
+ def _swiglu_fwd_kernel(
22
+ X,
23
+ Y,
24
+ OUT,
25
+ stride_x_row, # how much to increase the pointer when moving by 1 row
26
+ stride_y_row,
27
+ stride_out_row,
28
+ ncols,
29
+ BLOCK_N: tl.constexpr,
30
+ ):
31
+ # Map the program id to the row of X and Y it should compute.
32
+ row = tl.program_id(0)
33
+ start_col = tl.program_id(1) * BLOCK_N
34
+ X += row * stride_x_row
35
+ Y += row * stride_y_row
36
+ OUT += row * stride_out_row
37
+ cols = start_col + tl.arange(0, BLOCK_N)
38
+ x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
39
+ y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
40
+ out = x * tl.sigmoid(x) * y
41
+ tl.store(OUT + cols, out, mask=cols < ncols)
42
+
43
+
44
+ def _swiglu_fwd(xy, out=None):
45
+ if xy.stride(-1) != 1:
46
+ xy = xy.contiguous()
47
+ batch_shape = xy.shape[:-1]
48
+ xy = xy.reshape(-1, xy.shape[-1])
49
+ x, y = xy.chunk(2, dim=-1)
50
+ if out is None:
51
+ out = torch.empty_like(x)
52
+ else:
53
+ out = out.reshape(-1, out.shape[-1])
54
+ assert out.shape == x.shape
55
+ assert out.stride(-1) == 1
56
+ M, N = x.shape
57
+ grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
58
+ with torch.cuda.device(x.device.index):
59
+ _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
60
+ return out.reshape(*batch_shape, out.shape[-1])
61
+
62
+
63
+ @triton.autotune(
64
+ configs=[
65
+ triton.Config({'BLOCK_N': 32}),
66
+ triton.Config({'BLOCK_N': 64}),
67
+ triton.Config({'BLOCK_N': 128}),
68
+ triton.Config({'BLOCK_N': 256}),
69
+ triton.Config({'BLOCK_N': 512}),
70
+ triton.Config({'BLOCK_N': 1024}),
71
+ ],
72
+ key=['ncols'],
73
+ )
74
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
75
+ @triton.jit
76
+ def _swiglu_bwd_kernel(
77
+ X,
78
+ Y,
79
+ DOUT,
80
+ OUT,
81
+ DX,
82
+ DY,
83
+ stride_x_row, # how much to increase the pointer when moving by 1 row
84
+ stride_y_row,
85
+ stride_dout_row,
86
+ stride_out_row,
87
+ stride_dx_row,
88
+ stride_dy_row,
89
+ ncols,
90
+ BLOCK_N: tl.constexpr,
91
+ RECOMPUTE_OUTPUT: tl.constexpr,
92
+ ):
93
+ # Map the program id to the row of X and Y it should compute.
94
+ row = tl.program_id(0)
95
+ start_col = tl.program_id(1) * BLOCK_N
96
+ X += row * stride_x_row
97
+ Y += row * stride_y_row
98
+ DOUT += row * stride_dout_row
99
+ if RECOMPUTE_OUTPUT:
100
+ OUT += row * stride_out_row
101
+ DX += row * stride_dx_row
102
+ DY += row * stride_dy_row
103
+ cols = start_col + tl.arange(0, BLOCK_N)
104
+ x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
105
+ y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
106
+ dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
107
+ x_sigmoid = tl.sigmoid(x)
108
+ dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
109
+ dy = x * x_sigmoid * dout
110
+ tl.store(DX + cols, dx, mask=cols < ncols)
111
+ tl.store(DY + cols, dy, mask=cols < ncols)
112
+ if RECOMPUTE_OUTPUT:
113
+ out = x * x_sigmoid * y
114
+ tl.store(OUT + cols, out, mask=cols < ncols)
115
+
116
+
117
+ def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
118
+ if xy.stride(-1) != 1:
119
+ xy = xy.contiguous()
120
+ if dout.stride(-1) != 1:
121
+ dout = dout.contiguous()
122
+ batch_shape = xy.shape[:-1]
123
+ xy = xy.reshape(-1, xy.shape[-1])
124
+ x, y = xy.chunk(2, dim=-1)
125
+ dout = dout.reshape(-1, dout.shape[-1])
126
+ assert dout.shape == x.shape
127
+ if dxy is None:
128
+ dxy = torch.empty_like(xy)
129
+ else:
130
+ dxy = dxy.reshape(-1, dxy.shape[-1])
131
+ assert dxy.shape == xy.shape
132
+ dx, dy = dxy.chunk(2, dim=-1)
133
+ assert dx.stride(-1) == 1
134
+ assert dy.stride(-1) == 1
135
+ if recompute_output:
136
+ if out is None:
137
+ out = torch.empty_like(x)
138
+ else:
139
+ out = out.reshape(-1, out.shape[-1])
140
+ assert out.shape == x.shape
141
+ assert out.stride(-1) == 1
142
+ M, N = x.shape
143
+ grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
144
+ with torch.cuda.device(x.device.index):
145
+ _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
146
+ x.stride(0), y.stride(0), dout.stride(0),
147
+ out.stride(0) if recompute_output else 0,
148
+ dx.stride(0), dy.stride(0),
149
+ N)
150
+ if not recompute_output:
151
+ return dxy.reshape(*batch_shape, dxy.shape[-1])
152
+ else:
153
+ return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
154
+
155
+
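A quick, hedged sanity check (not part of this commit) that the closed-form gradients hard-coded in _swiglu_bwd_kernel, dx = sigmoid(x) * (1 + x * (1 - sigmoid(x))) * y * dout and dy = x * sigmoid(x) * dout, agree with autograd for f(x, y) = silu(x) * y; it runs on CPU in float64.

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
y = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
dout = torch.randn(4, 8, dtype=torch.float64)

(F.silu(x) * y).backward(dout)                      # autograd reference

sig = torch.sigmoid(x.detach())
dx_ref = sig * (1 + x.detach() * (1 - sig)) * y.detach() * dout
dy_ref = x.detach() * sig * dout
assert torch.allclose(x.grad, dx_ref) and torch.allclose(y.grad, dy_ref)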
156
+ class SwiGLU(torch.autograd.Function):
157
+
158
+ @staticmethod
159
+ def forward(ctx, xy):
160
+ ctx.save_for_backward(xy)
161
+ return _swiglu_fwd(xy)
162
+
163
+ @staticmethod
164
+ def backward(ctx, dout):
165
+ xy, = ctx.saved_tensors
166
+ return _swiglu_bwd(xy, dout)
167
+
168
+
169
+ swiglu = SwiGLU.apply
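Usage sketch for the swiglu op defined above (illustrative assumptions: the package is importable as mamba_ssm and a CUDA device is available, since both passes are Triton kernels). The gate x and value y live side by side in the last dimension and the output has half the width.

import torch
from mamba_ssm.ops.triton.k_activations import swiglu

xy = torch.randn(2, 64, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = swiglu(xy)        # silu(xy[..., :256]) * xy[..., 256:], shape (2, 64, 256)
out.sum().backward()    # _swiglu_bwd fills gradients for both halves of xy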
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py ADDED
@@ -0,0 +1,1166 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from ...utils.torch import custom_bwd, custom_fwd
15
+
16
+ import triton
17
+ import triton.language as tl
18
+
19
+
20
+ def layer_norm_ref(
21
+ x,
22
+ weight,
23
+ bias,
24
+ residual=None,
25
+ x1=None,
26
+ weight1=None,
27
+ bias1=None,
28
+ eps=1e-6,
29
+ dropout_p=0.0,
30
+ rowscale=None,
31
+ prenorm=False,
32
+ dropout_mask=None,
33
+ dropout_mask1=None,
34
+ upcast=False,
35
+ ):
36
+ dtype = x.dtype
37
+ if upcast:
38
+ x = x.float()
39
+ weight = weight.float()
40
+ bias = bias.float() if bias is not None else None
41
+ residual = residual.float() if residual is not None else residual
42
+ x1 = x1.float() if x1 is not None else None
43
+ weight1 = weight1.float() if weight1 is not None else None
44
+ bias1 = bias1.float() if bias1 is not None else None
45
+ if x1 is not None:
46
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
+ if rowscale is not None:
48
+ x = x * rowscale[..., None]
49
+ if dropout_p > 0.0:
50
+ if dropout_mask is not None:
51
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
+ else:
53
+ x = F.dropout(x, p=dropout_p)
54
+ if x1 is not None:
55
+ if dropout_mask1 is not None:
56
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
+ else:
58
+ x1 = F.dropout(x1, p=dropout_p)
59
+ if x1 is not None:
60
+ x = x + x1
61
+ if residual is not None:
62
+ x = (x + residual).to(x.dtype)
63
+ out = F.layer_norm(
64
+ x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
65
+ ).to(dtype)
66
+ if weight1 is None:
67
+ return out if not prenorm else (out, x)
68
+ else:
69
+ out1 = F.layer_norm(
70
+ x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
+ ).to(dtype)
72
+ return (out, out1) if not prenorm else (out, out1, x)
73
+
74
+
75
+ def rms_norm_ref(
76
+ x,
77
+ weight,
78
+ bias,
79
+ residual=None,
80
+ x1=None,
81
+ weight1=None,
82
+ bias1=None,
83
+ eps=1e-6,
84
+ dropout_p=0.0,
85
+ rowscale=None,
86
+ prenorm=False,
87
+ dropout_mask=None,
88
+ dropout_mask1=None,
89
+ upcast=False,
90
+ ):
91
+ dtype = x.dtype
92
+ if upcast:
93
+ x = x.float()
94
+ weight = weight.float()
95
+ bias = bias.float() if bias is not None else None
96
+ residual = residual.float() if residual is not None else residual
97
+ x1 = x1.float() if x1 is not None else None
98
+ weight1 = weight1.float() if weight1 is not None else None
99
+ bias1 = bias1.float() if bias1 is not None else None
100
+ if x1 is not None:
101
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
+ if rowscale is not None:
103
+ x = x * rowscale[..., None]
104
+ if dropout_p > 0.0:
105
+ if dropout_mask is not None:
106
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
+ else:
108
+ x = F.dropout(x, p=dropout_p)
109
+ if x1 is not None:
110
+ if dropout_mask1 is not None:
111
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
+ else:
113
+ x1 = F.dropout(x1, p=dropout_p)
114
+ if x1 is not None:
115
+ x = x + x1
116
+ if residual is not None:
117
+ x = (x + residual).to(x.dtype)
118
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
120
+ dtype
121
+ )
122
+ if weight1 is None:
123
+ return out if not prenorm else (out, x)
124
+ else:
125
+ out1 = (
126
+ (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
127
+ ).to(dtype)
128
+ return (out, out1) if not prenorm else (out, out1, x)
129
+
130
+
131
+ def config_prune(configs):
132
+
133
+ if torch.version.hip:
134
+ try:
135
+ # set warp size based on gcn architecure
136
+ gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
137
+ if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
138
+ # radeon
139
+ warp_size = 32
140
+ else:
141
+ # instinct
142
+ warp_size = 64
143
+ except AttributeError as e:
144
+ # fall back to crude method to set warp size
145
+ device_name = torch.cuda.get_device_properties(0).name
146
+ if "instinct" in device_name.lower():
147
+ warp_size = 64
148
+ else:
149
+ warp_size = 32
150
+ warnings.warn(
151
+ f"{e}, warp size set to {warp_size} based on device name: {device_name}",
152
+ UserWarning,
153
+ )
154
+
155
+ else:
156
+ # cuda
157
+ warp_size = 32
158
+
159
+ max_block_sz = 1024
160
+ max_num_warps = max_block_sz // warp_size
161
+ pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
162
+ return pruned_configs
163
+
164
+
165
+ configs_autotune = [
166
+ triton.Config({}, num_warps=1),
167
+ triton.Config({}, num_warps=2),
168
+ triton.Config({}, num_warps=4),
169
+ triton.Config({}, num_warps=8),
170
+ triton.Config({}, num_warps=16),
171
+ triton.Config({}, num_warps=32),
172
+ ]
173
+
174
+ pruned_configs_autotune = config_prune(configs_autotune)
175
+
176
+
177
+ @triton.autotune(
178
+ configs=pruned_configs_autotune,
179
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
180
+ )
181
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
182
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
183
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
184
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
185
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
186
+ @triton.jit
187
+ def _layer_norm_fwd_1pass_kernel(
188
+ X, # pointer to the input
189
+ Y, # pointer to the output
190
+ W, # pointer to the weights
191
+ B, # pointer to the biases
192
+ RESIDUAL, # pointer to the residual
193
+ X1,
194
+ W1,
195
+ B1,
196
+ Y1,
197
+ RESIDUAL_OUT, # pointer to the residual
198
+ ROWSCALE,
199
+ SEEDS, # Dropout seeds for each row
200
+ DROPOUT_MASK,
201
+ Mean, # pointer to the mean
202
+ Rstd, # pointer to the 1/std
203
+ stride_x_row, # how much to increase the pointer when moving by 1 row
204
+ stride_y_row,
205
+ stride_res_row,
206
+ stride_res_out_row,
207
+ stride_x1_row,
208
+ stride_y1_row,
209
+ M, # number of rows in X
210
+ N, # number of columns in X
211
+ eps, # epsilon to avoid division by zero
212
+ dropout_p, # Dropout probability
213
+ IS_RMS_NORM: tl.constexpr,
214
+ BLOCK_N: tl.constexpr,
215
+ HAS_RESIDUAL: tl.constexpr,
216
+ STORE_RESIDUAL_OUT: tl.constexpr,
217
+ HAS_BIAS: tl.constexpr,
218
+ HAS_DROPOUT: tl.constexpr,
219
+ STORE_DROPOUT_MASK: tl.constexpr,
220
+ HAS_ROWSCALE: tl.constexpr,
221
+ HAS_X1: tl.constexpr,
222
+ HAS_W1: tl.constexpr,
223
+ HAS_B1: tl.constexpr,
224
+ ):
225
+ # Map the program id to the row of X and Y it should compute.
226
+ row = tl.program_id(0)
227
+ X += row * stride_x_row
228
+ Y += row * stride_y_row
229
+ if HAS_RESIDUAL:
230
+ RESIDUAL += row * stride_res_row
231
+ if STORE_RESIDUAL_OUT:
232
+ RESIDUAL_OUT += row * stride_res_out_row
233
+ if HAS_X1:
234
+ X1 += row * stride_x1_row
235
+ if HAS_W1:
236
+ Y1 += row * stride_y1_row
237
+ # Compute mean and variance
238
+ cols = tl.arange(0, BLOCK_N)
239
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
240
+ if HAS_ROWSCALE:
241
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
242
+ x *= rowscale
243
+ if HAS_DROPOUT:
244
+ # Compute dropout mask
245
+ # 7 rounds is good enough, and reduces register pressure
246
+ keep_mask = (
247
+ tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
248
+ )
249
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
250
+ if STORE_DROPOUT_MASK:
251
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
252
+ if HAS_X1:
253
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
254
+ if HAS_ROWSCALE:
255
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
256
+ x1 *= rowscale
257
+ if HAS_DROPOUT:
258
+ # Compute dropout mask
259
+ # 7 rounds is good enough, and reduces register pressure
260
+ keep_mask = (
261
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
262
+ > dropout_p
263
+ )
264
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
265
+ if STORE_DROPOUT_MASK:
266
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
267
+ x += x1
268
+ if HAS_RESIDUAL:
269
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
270
+ x += residual
271
+ if STORE_RESIDUAL_OUT:
272
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
273
+ if not IS_RMS_NORM:
274
+ mean = tl.sum(x, axis=0) / N
275
+ tl.store(Mean + row, mean)
276
+ xbar = tl.where(cols < N, x - mean, 0.0)
277
+ var = tl.sum(xbar * xbar, axis=0) / N
278
+ else:
279
+ xbar = tl.where(cols < N, x, 0.0)
280
+ var = tl.sum(xbar * xbar, axis=0) / N
281
+ rstd = 1 / tl.sqrt(var + eps)
282
+ tl.store(Rstd + row, rstd)
283
+ # Normalize and apply linear transformation
284
+ mask = cols < N
285
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
286
+ if HAS_BIAS:
287
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
288
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
289
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
290
+ # Write output
291
+ tl.store(Y + cols, y, mask=mask)
292
+ if HAS_W1:
293
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
294
+ if HAS_B1:
295
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
296
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
297
+ tl.store(Y1 + cols, y1, mask=mask)
298
+
299
+
300
+ def _layer_norm_fwd(
301
+ x,
302
+ weight,
303
+ bias,
304
+ eps,
305
+ residual=None,
306
+ x1=None,
307
+ weight1=None,
308
+ bias1=None,
309
+ dropout_p=0.0,
310
+ rowscale=None,
311
+ out_dtype=None,
312
+ residual_dtype=None,
313
+ is_rms_norm=False,
314
+ return_dropout_mask=False,
315
+ ):
316
+ if residual is not None:
317
+ residual_dtype = residual.dtype
318
+ M, N = x.shape
319
+ assert x.stride(-1) == 1
320
+ if residual is not None:
321
+ assert residual.stride(-1) == 1
322
+ assert residual.shape == (M, N)
323
+ assert weight.shape == (N,)
324
+ assert weight.stride(-1) == 1
325
+ if bias is not None:
326
+ assert bias.stride(-1) == 1
327
+ assert bias.shape == (N,)
328
+ if x1 is not None:
329
+ assert x1.shape == x.shape
330
+ assert rowscale is None
331
+ assert x1.stride(-1) == 1
332
+ if weight1 is not None:
333
+ assert weight1.shape == (N,)
334
+ assert weight1.stride(-1) == 1
335
+ if bias1 is not None:
336
+ assert bias1.shape == (N,)
337
+ assert bias1.stride(-1) == 1
338
+ if rowscale is not None:
339
+ assert rowscale.is_contiguous()
340
+ assert rowscale.shape == (M,)
341
+ # allocate output
342
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
343
+ assert y.stride(-1) == 1
344
+ if weight1 is not None:
345
+ y1 = torch.empty_like(y)
346
+ assert y1.stride(-1) == 1
347
+ else:
348
+ y1 = None
349
+ if (
350
+ residual is not None
351
+ or (residual_dtype is not None and residual_dtype != x.dtype)
352
+ or dropout_p > 0.0
353
+ or rowscale is not None
354
+ or x1 is not None
355
+ ):
356
+ residual_out = torch.empty(
357
+ M,
358
+ N,
359
+ device=x.device,
360
+ dtype=residual_dtype if residual_dtype is not None else x.dtype,
361
+ )
362
+ assert residual_out.stride(-1) == 1
363
+ else:
364
+ residual_out = None
365
+ mean = (
366
+ torch.empty((M,), dtype=torch.float32, device=x.device)
367
+ if not is_rms_norm
368
+ else None
369
+ )
370
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
371
+ if dropout_p > 0.0:
372
+ seeds = torch.randint(
373
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
374
+ )
375
+ else:
376
+ seeds = None
377
+ if return_dropout_mask and dropout_p > 0.0:
378
+ dropout_mask = torch.empty(
379
+ M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
380
+ )
381
+ else:
382
+ dropout_mask = None
383
+ # Less than 64KB per feature: enqueue fused kernel
384
+ MAX_FUSED_SIZE = 65536 // x.element_size()
385
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
386
+ if N > BLOCK_N:
387
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
388
+ with torch.cuda.device(x.device.index):
389
+ _layer_norm_fwd_1pass_kernel[(M,)](
390
+ x,
391
+ y,
392
+ weight,
393
+ bias,
394
+ residual,
395
+ x1,
396
+ weight1,
397
+ bias1,
398
+ y1,
399
+ residual_out,
400
+ rowscale,
401
+ seeds,
402
+ dropout_mask,
403
+ mean,
404
+ rstd,
405
+ x.stride(0),
406
+ y.stride(0),
407
+ residual.stride(0) if residual is not None else 0,
408
+ residual_out.stride(0) if residual_out is not None else 0,
409
+ x1.stride(0) if x1 is not None else 0,
410
+ y1.stride(0) if y1 is not None else 0,
411
+ M,
412
+ N,
413
+ eps,
414
+ dropout_p,
415
+ is_rms_norm,
416
+ BLOCK_N,
417
+ residual is not None,
418
+ residual_out is not None,
419
+ bias is not None,
420
+ dropout_p > 0.0,
421
+ dropout_mask is not None,
422
+ rowscale is not None,
423
+ )
424
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
425
+ if dropout_mask is not None and x1 is not None:
426
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
427
+ else:
428
+ dropout_mask1 = None
429
+ return (
430
+ y,
431
+ y1,
432
+ mean,
433
+ rstd,
434
+ residual_out if residual_out is not None else x,
435
+ seeds,
436
+ dropout_mask,
437
+ dropout_mask1,
438
+ )
439
+
440
+
441
+ @triton.autotune(
442
+ configs=pruned_configs_autotune,
443
+ key=[
444
+ "N",
445
+ "HAS_DRESIDUAL",
446
+ "STORE_DRESIDUAL",
447
+ "IS_RMS_NORM",
448
+ "HAS_BIAS",
449
+ "HAS_DROPOUT",
450
+ ],
451
+ )
452
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
453
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
454
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
455
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
456
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
457
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
458
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
459
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
460
+ @triton.jit
461
+ def _layer_norm_bwd_kernel(
462
+ X, # pointer to the input
463
+ W, # pointer to the weights
464
+ B, # pointer to the biases
465
+ Y, # pointer to the output to be recomputed
466
+ DY, # pointer to the output gradient
467
+ DX, # pointer to the input gradient
468
+ DW, # pointer to the partial sum of weights gradient
469
+ DB, # pointer to the partial sum of biases gradient
470
+ DRESIDUAL,
471
+ W1,
472
+ DY1,
473
+ DX1,
474
+ DW1,
475
+ DB1,
476
+ DRESIDUAL_IN,
477
+ ROWSCALE,
478
+ SEEDS,
479
+ Mean, # pointer to the mean
480
+ Rstd, # pointer to the 1/std
481
+ stride_x_row, # how much to increase the pointer when moving by 1 row
482
+ stride_y_row,
483
+ stride_dy_row,
484
+ stride_dx_row,
485
+ stride_dres_row,
486
+ stride_dy1_row,
487
+ stride_dx1_row,
488
+ stride_dres_in_row,
489
+ M, # number of rows in X
490
+ N, # number of columns in X
491
+ eps, # epsilon to avoid division by zero
492
+ dropout_p,
493
+ rows_per_program,
494
+ IS_RMS_NORM: tl.constexpr,
495
+ BLOCK_N: tl.constexpr,
496
+ HAS_DRESIDUAL: tl.constexpr,
497
+ STORE_DRESIDUAL: tl.constexpr,
498
+ HAS_BIAS: tl.constexpr,
499
+ HAS_DROPOUT: tl.constexpr,
500
+ HAS_ROWSCALE: tl.constexpr,
501
+ HAS_DY1: tl.constexpr,
502
+ HAS_DX1: tl.constexpr,
503
+ HAS_B1: tl.constexpr,
504
+ RECOMPUTE_OUTPUT: tl.constexpr,
505
+ ):
506
+ # Map the program id to the elements of X, DX, and DY it should compute.
507
+ row_block_id = tl.program_id(0)
508
+ row_start = row_block_id * rows_per_program
509
+ # Do not early exit if row_start >= M, because we need to write DW and DB
510
+ cols = tl.arange(0, BLOCK_N)
511
+ mask = cols < N
512
+ X += row_start * stride_x_row
513
+ if HAS_DRESIDUAL:
514
+ DRESIDUAL += row_start * stride_dres_row
515
+ if STORE_DRESIDUAL:
516
+ DRESIDUAL_IN += row_start * stride_dres_in_row
517
+ DY += row_start * stride_dy_row
518
+ DX += row_start * stride_dx_row
519
+ if HAS_DY1:
520
+ DY1 += row_start * stride_dy1_row
521
+ if HAS_DX1:
522
+ DX1 += row_start * stride_dx1_row
523
+ if RECOMPUTE_OUTPUT:
524
+ Y += row_start * stride_y_row
525
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
526
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
527
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
528
+ if HAS_DY1:
529
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
530
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
531
+ if HAS_BIAS:
532
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
533
+ if HAS_DY1:
534
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
535
+ if HAS_B1:
536
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
537
+ row_end = min((row_block_id + 1) * rows_per_program, M)
538
+ for row in range(row_start, row_end):
539
+ # Load data to SRAM
540
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
541
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
542
+ if HAS_DY1:
543
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
544
+ if not IS_RMS_NORM:
545
+ mean = tl.load(Mean + row)
546
+ rstd = tl.load(Rstd + row)
547
+ # Compute dx
548
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
549
+ xhat = tl.where(mask, xhat, 0.0)
550
+ if RECOMPUTE_OUTPUT:
551
+ y = xhat * w + b if HAS_BIAS else xhat * w
552
+ tl.store(Y + cols, y, mask=mask)
553
+ wdy = w * dy
554
+ dw += dy * xhat
555
+ if HAS_BIAS:
556
+ db += dy
557
+ if HAS_DY1:
558
+ wdy += w1 * dy1
559
+ dw1 += dy1 * xhat
560
+ if HAS_B1:
561
+ db1 += dy1
562
+ if not IS_RMS_NORM:
563
+ c1 = tl.sum(xhat * wdy, axis=0) / N
564
+ c2 = tl.sum(wdy, axis=0) / N
565
+ dx = (wdy - (xhat * c1 + c2)) * rstd
566
+ else:
567
+ c1 = tl.sum(xhat * wdy, axis=0) / N
568
+ dx = (wdy - xhat * c1) * rstd
569
+ if HAS_DRESIDUAL:
570
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
571
+ dx += dres
572
+ # Write dx
573
+ if STORE_DRESIDUAL:
574
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
575
+ if HAS_DX1:
576
+ if HAS_DROPOUT:
577
+ keep_mask = (
578
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
579
+ > dropout_p
580
+ )
581
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
582
+ else:
583
+ dx1 = dx
584
+ tl.store(DX1 + cols, dx1, mask=mask)
585
+ if HAS_DROPOUT:
586
+ keep_mask = (
587
+ tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
588
+ > dropout_p
589
+ )
590
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
591
+ if HAS_ROWSCALE:
592
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
593
+ dx *= rowscale
594
+ tl.store(DX + cols, dx, mask=mask)
595
+
596
+ X += stride_x_row
597
+ if HAS_DRESIDUAL:
598
+ DRESIDUAL += stride_dres_row
599
+ if STORE_DRESIDUAL:
600
+ DRESIDUAL_IN += stride_dres_in_row
601
+ if RECOMPUTE_OUTPUT:
602
+ Y += stride_y_row
603
+ DY += stride_dy_row
604
+ DX += stride_dx_row
605
+ if HAS_DY1:
606
+ DY1 += stride_dy1_row
607
+ if HAS_DX1:
608
+ DX1 += stride_dx1_row
609
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
610
+ if HAS_BIAS:
611
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
612
+ if HAS_DY1:
613
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
614
+ if HAS_B1:
615
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
616
+
617
+
618
+ def _layer_norm_bwd(
619
+ dy,
620
+ x,
621
+ weight,
622
+ bias,
623
+ eps,
624
+ mean,
625
+ rstd,
626
+ dresidual=None,
627
+ dy1=None,
628
+ weight1=None,
629
+ bias1=None,
630
+ seeds=None,
631
+ dropout_p=0.0,
632
+ rowscale=None,
633
+ has_residual=False,
634
+ has_x1=False,
635
+ is_rms_norm=False,
636
+ x_dtype=None,
637
+ recompute_output=False,
638
+ ):
639
+ M, N = x.shape
640
+ assert x.stride(-1) == 1
641
+ assert dy.stride(-1) == 1
642
+ assert dy.shape == (M, N)
643
+ if dresidual is not None:
644
+ assert dresidual.stride(-1) == 1
645
+ assert dresidual.shape == (M, N)
646
+ assert weight.shape == (N,)
647
+ assert weight.stride(-1) == 1
648
+ if bias is not None:
649
+ assert bias.stride(-1) == 1
650
+ assert bias.shape == (N,)
651
+ if dy1 is not None:
652
+ assert weight1 is not None
653
+ assert dy1.shape == dy.shape
654
+ assert dy1.stride(-1) == 1
655
+ if weight1 is not None:
656
+ assert weight1.shape == (N,)
657
+ assert weight1.stride(-1) == 1
658
+ if bias1 is not None:
659
+ assert bias1.shape == (N,)
660
+ assert bias1.stride(-1) == 1
661
+ if seeds is not None:
662
+ assert seeds.is_contiguous()
663
+ assert seeds.shape == (M if not has_x1 else M * 2,)
664
+ if rowscale is not None:
665
+ assert rowscale.is_contiguous()
666
+ assert rowscale.shape == (M,)
667
+ # allocate output
668
+ dx = (
669
+ torch.empty_like(x)
670
+ if x_dtype is None
671
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
672
+ )
673
+ dresidual_in = (
674
+ torch.empty_like(x)
675
+ if has_residual
676
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
677
+ else None
678
+ )
679
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
680
+ y = (
681
+ torch.empty(M, N, dtype=dy.dtype, device=dy.device)
682
+ if recompute_output
683
+ else None
684
+ )
685
+ if recompute_output:
686
+ assert (
687
+ weight1 is None
688
+ ), "recompute_output is not supported with parallel LayerNorm"
689
+
690
+ # Less than 64KB per feature: enqueue fused kernel
691
+ MAX_FUSED_SIZE = 65536 // x.element_size()
692
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
693
+ if N > BLOCK_N:
694
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
695
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
696
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
697
+ _db = (
698
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
699
+ if bias is not None
700
+ else None
701
+ )
702
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
703
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
704
+ rows_per_program = math.ceil(M / sm_count)
705
+ grid = (sm_count,)
706
+ with torch.cuda.device(x.device.index):
707
+ _layer_norm_bwd_kernel[grid](
708
+ x,
709
+ weight,
710
+ bias,
711
+ y,
712
+ dy,
713
+ dx,
714
+ _dw,
715
+ _db,
716
+ dresidual,
717
+ weight1,
718
+ dy1,
719
+ dx1,
720
+ _dw1,
721
+ _db1,
722
+ dresidual_in,
723
+ rowscale,
724
+ seeds,
725
+ mean,
726
+ rstd,
727
+ x.stride(0),
728
+ 0 if not recompute_output else y.stride(0),
729
+ dy.stride(0),
730
+ dx.stride(0),
731
+ dresidual.stride(0) if dresidual is not None else 0,
732
+ dy1.stride(0) if dy1 is not None else 0,
733
+ dx1.stride(0) if dx1 is not None else 0,
734
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
735
+ M,
736
+ N,
737
+ eps,
738
+ dropout_p,
739
+ rows_per_program,
740
+ is_rms_norm,
741
+ BLOCK_N,
742
+ dresidual is not None,
743
+ dresidual_in is not None,
744
+ bias is not None,
745
+ dropout_p > 0.0,
746
+ )
747
+ dw = _dw.sum(0).to(weight.dtype)
748
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
749
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
750
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
751
+ # Don't need to compute dresidual_in separately in this case
752
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
753
+ dresidual_in = dx
754
+ if has_x1 and dropout_p == 0.0:
755
+ dx1 = dx
756
+ return (
757
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
758
+ if not recompute_output
759
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
760
+ )
761
+
762
+
763
+ class LayerNormFn(torch.autograd.Function):
764
+ @staticmethod
765
+ def forward(
766
+ ctx,
767
+ x,
768
+ weight,
769
+ bias,
770
+ residual=None,
771
+ x1=None,
772
+ weight1=None,
773
+ bias1=None,
774
+ eps=1e-6,
775
+ dropout_p=0.0,
776
+ rowscale=None,
777
+ prenorm=False,
778
+ residual_in_fp32=False,
779
+ is_rms_norm=False,
780
+ return_dropout_mask=False,
781
+ ):
782
+ x_shape_og = x.shape
783
+ # reshape input data into 2D tensor
784
+ x = x.reshape(-1, x.shape[-1])
785
+ if x.stride(-1) != 1:
786
+ x = x.contiguous()
787
+ if residual is not None:
788
+ assert residual.shape == x_shape_og
789
+ residual = residual.reshape(-1, residual.shape[-1])
790
+ if residual.stride(-1) != 1:
791
+ residual = residual.contiguous()
792
+ if x1 is not None:
793
+ assert x1.shape == x_shape_og
794
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
795
+ x1 = x1.reshape(-1, x1.shape[-1])
796
+ if x1.stride(-1) != 1:
797
+ x1 = x1.contiguous()
798
+ weight = weight.contiguous()
799
+ if bias is not None:
800
+ bias = bias.contiguous()
801
+ if weight1 is not None:
802
+ weight1 = weight1.contiguous()
803
+ if bias1 is not None:
804
+ bias1 = bias1.contiguous()
805
+ if rowscale is not None:
806
+ rowscale = rowscale.reshape(-1).contiguous()
807
+ residual_dtype = (
808
+ residual.dtype
809
+ if residual is not None
810
+ else (torch.float32 if residual_in_fp32 else None)
811
+ )
812
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
813
+ _layer_norm_fwd(
814
+ x,
815
+ weight,
816
+ bias,
817
+ eps,
818
+ residual,
819
+ x1,
820
+ weight1,
821
+ bias1,
822
+ dropout_p=dropout_p,
823
+ rowscale=rowscale,
824
+ residual_dtype=residual_dtype,
825
+ is_rms_norm=is_rms_norm,
826
+ return_dropout_mask=return_dropout_mask,
827
+ )
828
+ )
829
+ ctx.save_for_backward(
830
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
831
+ )
832
+ ctx.x_shape_og = x_shape_og
833
+ ctx.eps = eps
834
+ ctx.dropout_p = dropout_p
835
+ ctx.is_rms_norm = is_rms_norm
836
+ ctx.has_residual = residual is not None
837
+ ctx.has_x1 = x1 is not None
838
+ ctx.prenorm = prenorm
839
+ ctx.x_dtype = x.dtype
840
+ y = y.reshape(x_shape_og)
841
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
842
+ residual_out = (
843
+ residual_out.reshape(x_shape_og) if residual_out is not None else None
844
+ )
845
+ dropout_mask = (
846
+ dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
847
+ )
848
+ dropout_mask1 = (
849
+ dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
850
+ )
851
+ if not return_dropout_mask:
852
+ if weight1 is None:
853
+ return y if not prenorm else (y, residual_out)
854
+ else:
855
+ return (y, y1) if not prenorm else (y, y1, residual_out)
856
+ else:
857
+ if weight1 is None:
858
+ return (
859
+ (y, dropout_mask, dropout_mask1)
860
+ if not prenorm
861
+ else (y, residual_out, dropout_mask, dropout_mask1)
862
+ )
863
+ else:
864
+ return (
865
+ (y, y1, dropout_mask, dropout_mask1)
866
+ if not prenorm
867
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
868
+ )
869
+
870
+ @staticmethod
871
+ def backward(ctx, dy, *args):
872
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
873
+ dy = dy.reshape(-1, dy.shape[-1])
874
+ if dy.stride(-1) != 1:
875
+ dy = dy.contiguous()
876
+ assert dy.shape == x.shape
877
+ if weight1 is not None:
878
+ dy1, args = args[0], args[1:]
879
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
880
+ if dy1.stride(-1) != 1:
881
+ dy1 = dy1.contiguous()
882
+ assert dy1.shape == x.shape
883
+ else:
884
+ dy1 = None
885
+ if ctx.prenorm:
886
+ dresidual = args[0]
887
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
888
+ if dresidual.stride(-1) != 1:
889
+ dresidual = dresidual.contiguous()
890
+ assert dresidual.shape == x.shape
891
+ else:
892
+ dresidual = None
893
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
894
+ dy,
895
+ x,
896
+ weight,
897
+ bias,
898
+ ctx.eps,
899
+ mean,
900
+ rstd,
901
+ dresidual,
902
+ dy1,
903
+ weight1,
904
+ bias1,
905
+ seeds,
906
+ ctx.dropout_p,
907
+ rowscale,
908
+ ctx.has_residual,
909
+ ctx.has_x1,
910
+ ctx.is_rms_norm,
911
+ x_dtype=ctx.x_dtype,
912
+ )
913
+ return (
914
+ dx.reshape(ctx.x_shape_og),
915
+ dw,
916
+ db,
917
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
918
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
919
+ dw1,
920
+ db1,
921
+ None,
922
+ None,
923
+ None,
924
+ None,
925
+ None,
926
+ None,
927
+ None,
928
+ )
929
+
930
+
931
+ def layer_norm_fn(
932
+ x,
933
+ weight,
934
+ bias,
935
+ residual=None,
936
+ x1=None,
937
+ weight1=None,
938
+ bias1=None,
939
+ eps=1e-6,
940
+ dropout_p=0.0,
941
+ rowscale=None,
942
+ prenorm=False,
943
+ residual_in_fp32=False,
944
+ is_rms_norm=False,
945
+ return_dropout_mask=False,
946
+ ):
947
+ return LayerNormFn.apply(
948
+ x,
949
+ weight,
950
+ bias,
951
+ residual,
952
+ x1,
953
+ weight1,
954
+ bias1,
955
+ eps,
956
+ dropout_p,
957
+ rowscale,
958
+ prenorm,
959
+ residual_in_fp32,
960
+ is_rms_norm,
961
+ return_dropout_mask,
962
+ )
963
+
964
+
965
+ def rms_norm_fn(
966
+ x,
967
+ weight,
968
+ bias,
969
+ residual=None,
970
+ x1=None,
971
+ weight1=None,
972
+ bias1=None,
973
+ eps=1e-6,
974
+ dropout_p=0.0,
975
+ rowscale=None,
976
+ prenorm=False,
977
+ residual_in_fp32=False,
978
+ return_dropout_mask=False,
979
+ ):
980
+ return LayerNormFn.apply(
981
+ x,
982
+ weight,
983
+ bias,
984
+ residual,
985
+ x1,
986
+ weight1,
987
+ bias1,
988
+ eps,
989
+ dropout_p,
990
+ rowscale,
991
+ prenorm,
992
+ residual_in_fp32,
993
+ True,
994
+ return_dropout_mask,
995
+ )
996
+
997
+
998
+ class RMSNorm(torch.nn.Module):
999
+
1000
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
1001
+ factory_kwargs = {"device": device, "dtype": dtype}
1002
+ super().__init__()
1003
+ self.eps = eps
1004
+ if dropout_p > 0.0:
1005
+ self.drop = torch.nn.Dropout(dropout_p)
1006
+ else:
1007
+ self.drop = None
1008
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1009
+ self.register_parameter("bias", None)
1010
+ self.reset_parameters()
1011
+
1012
+ def reset_parameters(self):
1013
+ torch.nn.init.ones_(self.weight)
1014
+
1015
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1016
+ return rms_norm_fn(
1017
+ x,
1018
+ self.weight,
1019
+ self.bias,
1020
+ residual=residual,
1021
+ eps=self.eps,
1022
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1023
+ prenorm=prenorm,
1024
+ residual_in_fp32=residual_in_fp32,
1025
+ )
1026
+
1027
+
1028
+ class LayerNormLinearFn(torch.autograd.Function):
1029
+ @staticmethod
1030
+ @custom_fwd
1031
+ def forward(
1032
+ ctx,
1033
+ x,
1034
+ norm_weight,
1035
+ norm_bias,
1036
+ linear_weight,
1037
+ linear_bias,
1038
+ residual=None,
1039
+ eps=1e-6,
1040
+ prenorm=False,
1041
+ residual_in_fp32=False,
1042
+ is_rms_norm=False,
1043
+ ):
1044
+ x_shape_og = x.shape
1045
+ # reshape input data into 2D tensor
1046
+ x = x.reshape(-1, x.shape[-1])
1047
+ if x.stride(-1) != 1:
1048
+ x = x.contiguous()
1049
+ if residual is not None:
1050
+ assert residual.shape == x_shape_og
1051
+ residual = residual.reshape(-1, residual.shape[-1])
1052
+ if residual.stride(-1) != 1:
1053
+ residual = residual.contiguous()
1054
+ norm_weight = norm_weight.contiguous()
1055
+ if norm_bias is not None:
1056
+ norm_bias = norm_bias.contiguous()
1057
+ residual_dtype = (
1058
+ residual.dtype
1059
+ if residual is not None
1060
+ else (torch.float32 if residual_in_fp32 else None)
1061
+ )
1062
+ y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1063
+ x,
1064
+ norm_weight,
1065
+ norm_bias,
1066
+ eps,
1067
+ residual,
1068
+ out_dtype=(
1069
+ None
1070
+ if not torch.is_autocast_enabled()
1071
+ else torch.get_autocast_gpu_dtype()
1072
+ ),
1073
+ residual_dtype=residual_dtype,
1074
+ is_rms_norm=is_rms_norm,
1075
+ )
1076
+ y = y.reshape(x_shape_og)
1077
+ dtype = (
1078
+ torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1079
+ )
1080
+ linear_weight = linear_weight.to(dtype)
1081
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1082
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1083
+ # We don't store y, will be recomputed in the backward pass to save memory
1084
+ ctx.save_for_backward(
1085
+ residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1086
+ )
1087
+ ctx.x_shape_og = x_shape_og
1088
+ ctx.eps = eps
1089
+ ctx.is_rms_norm = is_rms_norm
1090
+ ctx.has_residual = residual is not None
1091
+ ctx.prenorm = prenorm
1092
+ ctx.x_dtype = x.dtype
1093
+ ctx.linear_bias_is_none = linear_bias is None
1094
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1095
+
1096
+ @staticmethod
1097
+ @custom_bwd
1098
+ def backward(ctx, dout, *args):
1099
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1100
+ dout = dout.reshape(-1, dout.shape[-1])
1101
+ dy = F.linear(dout, linear_weight.t())
1102
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1103
+ if dy.stride(-1) != 1:
1104
+ dy = dy.contiguous()
1105
+ assert dy.shape == x.shape
1106
+ if ctx.prenorm:
1107
+ dresidual = args[0]
1108
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1109
+ if dresidual.stride(-1) != 1:
1110
+ dresidual = dresidual.contiguous()
1111
+ assert dresidual.shape == x.shape
1112
+ else:
1113
+ dresidual = None
1114
+ dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1115
+ dy,
1116
+ x,
1117
+ norm_weight,
1118
+ norm_bias,
1119
+ ctx.eps,
1120
+ mean,
1121
+ rstd,
1122
+ dresidual=dresidual,
1123
+ has_residual=ctx.has_residual,
1124
+ is_rms_norm=ctx.is_rms_norm,
1125
+ x_dtype=ctx.x_dtype,
1126
+ recompute_output=True,
1127
+ )
1128
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1129
+ return (
1130
+ dx.reshape(ctx.x_shape_og),
1131
+ dnorm_weight,
1132
+ dnorm_bias,
1133
+ dlinear_weight,
1134
+ dlinear_bias,
1135
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1136
+ None,
1137
+ None,
1138
+ None,
1139
+ None,
1140
+ )
1141
+
1142
+
1143
+ def layer_norm_linear_fn(
1144
+ x,
1145
+ norm_weight,
1146
+ norm_bias,
1147
+ linear_weight,
1148
+ linear_bias,
1149
+ residual=None,
1150
+ eps=1e-6,
1151
+ prenorm=False,
1152
+ residual_in_fp32=False,
1153
+ is_rms_norm=False,
1154
+ ):
1155
+ return LayerNormLinearFn.apply(
1156
+ x,
1157
+ norm_weight,
1158
+ norm_bias,
1159
+ linear_weight,
1160
+ linear_bias,
1161
+ residual,
1162
+ eps,
1163
+ prenorm,
1164
+ residual_in_fp32,
1165
+ is_rms_norm,
1166
+ )
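A hedged sketch of the fused norm-then-linear path defined above, using RMS normalization (is_rms_norm=True); the shapes, dtype, and import path are illustrative assumptions. The point of LayerNormLinearFn is that the normalized activations are not saved for backward but recomputed from the stored pre-norm input.

import torch
from mamba_ssm.ops.triton.layer_norm import layer_norm_linear_fn

d_model, d_out = 1024, 4096
x = torch.randn(2, 128, d_model, device="cuda", dtype=torch.bfloat16, requires_grad=True)
norm_weight = torch.ones(d_model, device="cuda", dtype=torch.bfloat16, requires_grad=True)
linear_weight = torch.randn(d_out, d_model, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = layer_norm_linear_fn(
    x, norm_weight, None, linear_weight, None,
    residual=None, eps=1e-6, prenorm=False, is_rms_norm=True,
)                                   # shape (2, 128, d_out)
out.sum().backward()                # dnorm_weight / dlinear_weight come from _layer_norm_bwd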
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layernorm_gated.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
3
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
4
+ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
5
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ from einops import rearrange
16
+
17
+
18
+ def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
19
+ dtype = x.dtype
20
+ N = x.shape[-1]
21
+ weight = weight.float()
22
+ bias = bias.float() if bias is not None else None
23
+ if upcast:
24
+ x = x.float()
25
+ z = z.float() if z is not None else z
26
+ if z is not None and not norm_before_gate:
27
+ x = x * F.silu(z)
28
+ if group_size is None:
29
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
30
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
31
+ else:
32
+ x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
33
+ rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
34
+ out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
35
+ if bias is not None:
36
+ out = out + bias
37
+ if z is not None and norm_before_gate:
38
+ out *= F.silu(z)
39
+ return out.to(dtype)
40
+
41
+
42
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
43
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
44
+ @triton.jit
45
+ def _layer_norm_fwd_1pass_kernel(
46
+ X, # pointer to the input
47
+ Y, # pointer to the output
48
+ W, # pointer to the weights
49
+ B, # pointer to the biases
50
+ Z, # pointer to the other branch
51
+ Mean, # pointer to the mean
52
+ Rstd, # pointer to the 1/std
53
+ stride_x_row, # how much to increase the pointer when moving by 1 row
54
+ stride_y_row,
55
+ stride_z_row,
56
+ M, # number of rows in X
57
+ N, # number of columns in X
58
+ eps, # epsilon to avoid division by zero
59
+ BLOCK_N: tl.constexpr,
60
+ HAS_BIAS: tl.constexpr,
61
+ HAS_Z: tl.constexpr,
62
+ NORM_BEFORE_GATE: tl.constexpr,
63
+ IS_RMS_NORM: tl.constexpr,
64
+ ):
65
+ # Map the program id to the row of X and Y it should compute.
66
+ row = tl.program_id(0)
67
+ group = tl.program_id(1)
68
+ X += row * stride_x_row + group * N
69
+ Y += row * stride_y_row + group * N
70
+ if HAS_Z:
71
+ Z += row * stride_z_row + group * N
72
+ if not IS_RMS_NORM:
73
+ Mean += group * M
74
+ Rstd += group * M
75
+ W += group * N
76
+ if HAS_BIAS:
77
+ B += group * N
78
+ # Compute mean and variance
79
+ cols = tl.arange(0, BLOCK_N)
80
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
81
+ if HAS_Z and not NORM_BEFORE_GATE:
82
+ z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
83
+ x *= z * tl.sigmoid(z)
84
+ if not IS_RMS_NORM:
85
+ mean = tl.sum(x, axis=0) / N
86
+ tl.store(Mean + row, mean)
87
+ xbar = tl.where(cols < N, x - mean, 0.)
88
+ var = tl.sum(xbar * xbar, axis=0) / N
89
+ else:
90
+ xbar = tl.where(cols < N, x, 0.)
91
+ var = tl.sum(xbar * xbar, axis=0) / N
92
+ rstd = 1 / tl.sqrt(var + eps)
93
+ tl.store(Rstd + row, rstd)
94
+ # Normalize and apply linear transformation
95
+ mask = cols < N
96
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
97
+ if HAS_BIAS:
98
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
99
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
100
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
101
+ if HAS_Z and NORM_BEFORE_GATE:
102
+ z = tl.load(Z + cols, mask=mask).to(tl.float32)
103
+ y *= z * tl.sigmoid(z)
104
+ # Write output
105
+ tl.store(Y + cols, y, mask=mask)
106
+
107
+
108
+ def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
109
+ M, N = x.shape
110
+ if group_size is None:
111
+ group_size = N
112
+ assert N % group_size == 0
113
+ ngroups = N // group_size
114
+ assert x.stride(-1) == 1
115
+ if z is not None:
116
+ assert z.stride(-1) == 1
117
+ assert z.shape == (M, N)
118
+ assert weight.shape == (N,)
119
+ assert weight.stride(-1) == 1
120
+ if bias is not None:
121
+ assert bias.stride(-1) == 1
122
+ assert bias.shape == (N,)
123
+ # allocate output
124
+ if out is not None:
125
+ assert out.shape == x.shape
126
+ else:
127
+ out = torch.empty_like(x)
128
+ assert out.stride(-1) == 1
129
+ mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
130
+ rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
131
+ # Less than 64KB per feature: enqueue fused kernel
132
+ MAX_FUSED_SIZE = 65536 // x.element_size()
133
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
134
+ if group_size > BLOCK_N:
135
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
136
+ # heuristics for number of warps
137
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
138
+ grid = (M, ngroups)
139
+ with torch.cuda.device(x.device.index):
140
+ _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
141
+ x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
142
+ M, group_size, eps,
143
+ BLOCK_N=BLOCK_N,
144
+ NORM_BEFORE_GATE=norm_before_gate,
145
+ IS_RMS_NORM=is_rms_norm,
146
+ num_warps=num_warps)
147
+ return out, mean, rstd
148
+
149
+
150
+
151
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
152
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
153
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
154
+ @triton.jit
155
+ def _layer_norm_bwd_kernel(
156
+ X, # pointer to the input
157
+ W, # pointer to the weights
158
+ B, # pointer to the biases
159
+ Z, # pointer to the other branch
160
+ Y, # pointer to the output to be recomputed
161
+ DY, # pointer to the output gradient
162
+ DX, # pointer to the input gradient
163
+ DW, # pointer to the partial sum of weights gradient
164
+ DB, # pointer to the partial sum of biases gradient
165
+ DZ, # pointer to the other branch
166
+ Mean, # pointer to the mean
167
+ Rstd, # pointer to the 1/std
168
+ stride_x_row, # how much to increase the pointer when moving by 1 row
169
+ stride_z_row,
170
+ stride_y_row,
171
+ stride_dy_row,
172
+ stride_dx_row,
173
+ stride_dz_row,
174
+ stride_dw_row,
175
+ stride_db_row,
176
+ M, # number of rows in X
177
+ N, # number of columns in X
178
+ eps, # epsilon to avoid division by zero
179
+ rows_per_program,
180
+ NORM_BEFORE_GATE: tl.constexpr,
181
+ IS_RMS_NORM: tl.constexpr,
182
+ HAS_BIAS: tl.constexpr,
183
+ HAS_Z: tl.constexpr,
184
+ RECOMPUTE_OUTPUT: tl.constexpr,
185
+ BLOCK_N: tl.constexpr,
186
+ ):
187
+ # Map the program id to the elements of X, DX, and DY it should compute.
188
+ row_block_id = tl.program_id(0)
189
+ group = tl.program_id(1)
190
+ row_start = row_block_id * rows_per_program
191
+ cols = tl.arange(0, BLOCK_N)
192
+ mask = cols < N
193
+ X += row_start * stride_x_row + group * N
194
+ if HAS_Z:
195
+ Z += row_start * stride_z_row + group * N
196
+ DZ += row_start * stride_dz_row + group * N
197
+ DY += row_start * stride_dy_row + group * N
198
+ DX += row_start * stride_dx_row + group * N
199
+ if RECOMPUTE_OUTPUT:
200
+ Y += row_start * stride_y_row + group * N
201
+ if not IS_RMS_NORM:
202
+ Mean += group * M
203
+ Rstd += group * M
204
+ W += group * N
205
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
206
+ if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
207
+ B += group * N
208
+ b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
209
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
210
+ if HAS_BIAS:
211
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
212
+ row_end = min((row_block_id + 1) * rows_per_program, M)
213
+ for row in range(row_start, row_end):
214
+ # Load data to SRAM
215
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
216
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
217
+ if not IS_RMS_NORM:
218
+ mean = tl.load(Mean + row)
219
+ if HAS_Z and not NORM_BEFORE_GATE:
220
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
221
+ x_og = x
222
+ x = x_og * z * tl.sigmoid(z)
223
+ rstd = tl.load(Rstd + row)
224
+ # Compute dx
225
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
226
+ xhat = tl.where(mask, xhat, 0.)
227
+ if HAS_Z and NORM_BEFORE_GATE:
228
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
229
+ z_sigmoid = tl.sigmoid(z)
230
+ y = xhat * w + b if HAS_BIAS else xhat * w
231
+ if RECOMPUTE_OUTPUT:
232
+ tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
233
+ dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
234
+ tl.store(DZ + cols, dz, mask=mask)
235
+ dy *= z * z_sigmoid
236
+ else:
237
+ if RECOMPUTE_OUTPUT:
238
+ y = xhat * w + b if HAS_BIAS else xhat * w
239
+ tl.store(Y + cols, y, mask=mask)
240
+ wdy = w * dy
241
+ c1 = tl.sum(xhat * wdy, axis=0) / N
242
+ if not IS_RMS_NORM:
243
+ c2 = tl.sum(wdy, axis=0) / N
244
+ dx = (wdy - (xhat * c1 + c2)) * rstd
245
+ else:
246
+ dx = (wdy - xhat * c1) * rstd
247
+ dw += dy * xhat
248
+ if HAS_BIAS:
249
+ db += dy
250
+ if HAS_Z and not NORM_BEFORE_GATE:
251
+ z_sigmoid = tl.sigmoid(z)
252
+ dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
253
+ tl.store(DZ + cols, dz, mask=mask)
254
+ dx *= z * z_sigmoid
255
+ # Write dx
256
+ tl.store(DX + cols, dx, mask=mask)
257
+
258
+ X += stride_x_row
259
+ if HAS_Z:
260
+ Z += stride_z_row
261
+ DZ += stride_dz_row
262
+ if RECOMPUTE_OUTPUT:
263
+ Y += stride_y_row
264
+ DY += stride_dy_row
265
+ DX += stride_dx_row
266
+ tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
267
+ if HAS_BIAS:
268
+ tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
269
+
270
+
271
+ def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
272
+ norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
273
+ M, N = x.shape
274
+ if group_size is None:
275
+ group_size = N
276
+ assert N % group_size == 0
277
+ ngroups = N // group_size
278
+ assert x.stride(-1) == 1
279
+ assert dy.stride(-1) == 1
280
+ assert dy.shape == (M, N)
281
+ if z is not None:
282
+ assert z.stride(-1) == 1
283
+ assert z.shape == (M, N)
284
+ assert weight.shape == (N,)
285
+ assert weight.stride(-1) == 1
286
+ if bias is not None:
287
+ assert bias.stride(-1) == 1
288
+ assert bias.shape == (N,)
289
+ # allocate output
290
+ dx = torch.empty_like(x)
291
+ if dz is not None:
292
+ assert z is not None
293
+ assert dz.shape == z.shape
294
+ assert dz.stride(-1) == 1
295
+ else:
296
+ dz = torch.empty_like(z) if z is not None else None
297
+ if recompute_output:
298
+ if out is None:
299
+ out = torch.empty_like(x)
300
+ assert out.shape == x.shape
301
+
302
+ # Less than 64KB per feature: enqueue fused kernel
303
+ MAX_FUSED_SIZE = 65536 // x.element_size()
304
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
305
+ if group_size > BLOCK_N:
306
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
307
+ # heuristics for number of warps
308
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
309
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
310
+ # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
311
+ # would limit the occupancy.
312
+ nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
313
+ _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
314
+ _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
315
+ rows_per_program = math.ceil(M / nrow_groups)
316
+ grid = (nrow_groups, ngroups)
317
+ with torch.cuda.device(x.device.index):
318
+ _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
319
+ dy, dx, _dw, _db, dz, mean, rstd,
320
+ x.stride(0),
321
+ z.stride(0) if z is not None else 0,
322
+ 0 if not recompute_output else out.stride(0),
323
+ dy.stride(0), dx.stride(0),
324
+ dz.stride(0) if dz is not None else 0,
325
+ _dw.stride(0),
326
+ _db.stride(0) if _db is not None else 0,
327
+ M, group_size, eps,
328
+ rows_per_program,
329
+ BLOCK_N=BLOCK_N,
330
+ NORM_BEFORE_GATE=norm_before_gate,
331
+ IS_RMS_NORM=is_rms_norm,
332
+ num_warps=num_warps)
333
+ dw = _dw.sum(0).to(weight.dtype)
334
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
335
+ return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
336
+
337
+
338
+ class LayerNormFn(torch.autograd.Function):
339
+
340
+ @staticmethod
341
+ def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
342
+ is_rms_norm=False):
343
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
344
+ """
345
+
346
+ x_shape_og = x.shape
347
+ # reshape input data into 2D tensor
348
+ x = x.reshape(-1, x.shape[-1])
349
+ if x.stride(-1) != 1:
350
+ x = x.contiguous()
351
+ if z is not None:
352
+ assert z.shape == x_shape_og
353
+ z = z.reshape(-1, z.shape[-1])
354
+ if z.stride(-1) != 1:
355
+ z = z.contiguous()
356
+ weight = weight.contiguous()
357
+ if bias is not None:
358
+ bias = bias.contiguous()
359
+ y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
360
+ ctx.save_for_backward(x, weight, bias, mean, rstd, z)
361
+ ctx.x_shape_og = x_shape_og
362
+ ctx.eps = eps
363
+ ctx.group_size = group_size
364
+ ctx.norm_before_gate = norm_before_gate
365
+ ctx.is_rms_norm = is_rms_norm
366
+ return y.reshape(x_shape_og)
367
+
368
+ @staticmethod
369
+ def backward(ctx, dy):
370
+ x, weight, bias, mean, rstd, z = ctx.saved_tensors
371
+ dy = dy.reshape(-1, dy.shape[-1])
372
+ if dy.stride(-1) != 1:
373
+ dy = dy.contiguous()
374
+ assert dy.shape == x.shape
375
+ dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
376
+ ctx.norm_before_gate, ctx.is_rms_norm)
377
+ return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
378
+
379
+
380
+ def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
381
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
382
+
383
+
384
+ def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
385
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
386
+
387
+
388
+ class LayerNorm(torch.nn.Module):
389
+
390
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
391
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
392
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
393
+ """
394
+
395
+ factory_kwargs = {"device": device, "dtype": dtype}
396
+ super().__init__()
397
+ self.eps = eps
398
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
399
+ self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
400
+ self.group_size = group_size
401
+ self.norm_before_gate = norm_before_gate
402
+ self.reset_parameters()
403
+
404
+ def reset_parameters(self):
405
+ torch.nn.init.ones_(self.weight)
406
+ torch.nn.init.zeros_(self.bias)
407
+
408
+ def forward(self, x, z=None):
409
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
410
+ """
411
+ return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
412
+ norm_before_gate=self.norm_before_gate)
413
+
414
+
415
+ class RMSNorm(torch.nn.Module):
416
+
417
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
418
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
419
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
420
+ """
421
+ factory_kwargs = {"device": device, "dtype": dtype}
422
+ super().__init__()
423
+ self.eps = eps
424
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
425
+ self.register_parameter("bias", None)
426
+ self.group_size = group_size
427
+ self.norm_before_gate = norm_before_gate
428
+ self.reset_parameters()
429
+
430
+ def reset_parameters(self):
431
+ torch.nn.init.ones_(self.weight)
432
+
433
+ def forward(self, x, z=None):
434
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
435
+ """
436
+ return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
437
+ norm_before_gate=self.norm_before_gate)
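The docstrings above describe the gated-norm semantics: with norm_before_gate=True the output is norm(x) * silu(z), applied per group. A minimal usage sketch, assuming the wheel from this build imports as mamba_ssm, the module path matches the file path above, and a CUDA device is available:

import torch
import torch.nn.functional as F
from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn

torch.manual_seed(0)
x = torch.randn(2, 64, 256, device="cuda")
z = torch.randn_like(x)
weight = torch.ones(256, device="cuda")

# norm_before_gate=True: out = rmsnorm(x) * silu(z), single group, no bias
out = rmsnorm_fn(x, weight, None, z=z, eps=1e-5, norm_before_gate=True)

# Hand-written reference for the same formula
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5) * weight * F.silu(z)
print(torch.allclose(out, ref, atol=1e-4, rtol=1e-4))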
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from .softplus import softplus
16
+
17
+
18
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
+ @triton.heuristics(
22
+ {
23
+ "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
24
+ is not None
25
+ }
26
+ )
27
+ @triton.heuristics(
28
+ {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
29
+ )
30
+ @triton.jit
31
+ def _selective_scan_update_kernel(
32
+ # Pointers to matrices
33
+ state_ptr,
34
+ x_ptr,
35
+ dt_ptr,
36
+ dt_bias_ptr,
37
+ A_ptr,
38
+ B_ptr,
39
+ C_ptr,
40
+ D_ptr,
41
+ z_ptr,
42
+ out_ptr,
43
+ state_batch_indices_ptr,
44
+ # Matrix dimensions
45
+ batch,
46
+ nheads,
47
+ dim,
48
+ dstate,
49
+ nheads_ngroups_ratio,
50
+ # Strides
51
+ stride_state_batch,
52
+ stride_state_head,
53
+ stride_state_dim,
54
+ stride_state_dstate,
55
+ stride_x_batch,
56
+ stride_x_head,
57
+ stride_x_dim,
58
+ stride_dt_batch,
59
+ stride_dt_head,
60
+ stride_dt_dim,
61
+ stride_dt_bias_head,
62
+ stride_dt_bias_dim,
63
+ stride_A_head,
64
+ stride_A_dim,
65
+ stride_A_dstate,
66
+ stride_B_batch,
67
+ stride_B_group,
68
+ stride_B_dstate,
69
+ stride_C_batch,
70
+ stride_C_group,
71
+ stride_C_dstate,
72
+ stride_D_head,
73
+ stride_D_dim,
74
+ stride_z_batch,
75
+ stride_z_head,
76
+ stride_z_dim,
77
+ stride_out_batch,
78
+ stride_out_head,
79
+ stride_out_dim,
80
+ # Meta-parameters
81
+ DT_SOFTPLUS: tl.constexpr,
82
+ TIE_HDIM: tl.constexpr,
83
+ BLOCK_SIZE_M: tl.constexpr,
84
+ HAS_DT_BIAS: tl.constexpr,
85
+ HAS_D: tl.constexpr,
86
+ HAS_Z: tl.constexpr,
87
+ HAS_STATE_BATCH_INDICES: tl.constexpr,
88
+ BLOCK_SIZE_DSTATE: tl.constexpr,
89
+ ):
90
+ pid_m = tl.program_id(axis=0)
91
+ pid_b = tl.program_id(axis=1)
92
+ pid_h = tl.program_id(axis=2)
93
+
94
+ if HAS_STATE_BATCH_INDICES:
95
+ state_batch_indices_ptr += pid_b
96
+ state_batch_idx = tl.load(state_batch_indices_ptr)
97
+ state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
98
+ else:
99
+ state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
100
+
101
+ x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
102
+ dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
103
+ if HAS_DT_BIAS:
104
+ dt_bias_ptr += pid_h * stride_dt_bias_head
105
+ A_ptr += pid_h * stride_A_head
106
+ B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
107
+ C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
108
+ if HAS_Z:
109
+ z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
110
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
111
+
112
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
113
+ offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
114
+ state_ptrs = state_ptr + (
115
+ offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
116
+ )
117
+ x_ptrs = x_ptr + offs_m * stride_x_dim
118
+ dt_ptrs = dt_ptr + offs_m * stride_dt_dim
119
+ if HAS_DT_BIAS:
120
+ dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
121
+ if HAS_D:
122
+ D_ptr += pid_h * stride_D_head
123
+ A_ptrs = A_ptr + (
124
+ offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
125
+ )
126
+ B_ptrs = B_ptr + offs_n * stride_B_dstate
127
+ C_ptrs = C_ptr + offs_n * stride_C_dstate
128
+ if HAS_D:
129
+ D_ptrs = D_ptr + offs_m * stride_D_dim
130
+ if HAS_Z:
131
+ z_ptrs = z_ptr + offs_m * stride_z_dim
132
+ out_ptrs = out_ptr + offs_m * stride_out_dim
133
+
134
+ state = tl.load(
135
+ state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
136
+ )
137
+ x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
138
+ if not TIE_HDIM:
139
+ dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
140
+ if HAS_DT_BIAS:
141
+ dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
142
+ if DT_SOFTPLUS:
143
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
144
+ A = tl.load(
145
+ A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
146
+ ).to(tl.float32)
147
+ dA = tl.exp(A * dt[:, None])
148
+ else:
149
+ dt = tl.load(dt_ptr).to(tl.float32)
150
+ if HAS_DT_BIAS:
151
+ dt += tl.load(dt_bias_ptr).to(tl.float32)
152
+ if DT_SOFTPLUS:
153
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
154
+ A = tl.load(A_ptr).to(tl.float32)
155
+ dA = tl.exp(A * dt) # scalar, not a matrix
156
+
157
+ B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
158
+ C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
159
+ if HAS_D:
160
+ D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
161
+ if HAS_Z:
162
+ z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
163
+
164
+ if not TIE_HDIM:
165
+ dB = B[None, :] * dt[:, None]
166
+ else:
167
+ dB = B * dt # vector of size (dstate,)
168
+ state = state * dA + dB * x[:, None]
169
+ tl.store(
170
+ state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
171
+ )
172
+ out = tl.sum(state * C[None, :], axis=1)
173
+ if HAS_D:
174
+ out += x * D
175
+ if HAS_Z:
176
+ out *= z * tl.sigmoid(z)
177
+ tl.store(out_ptrs, out, mask=offs_m < dim)
178
+
179
+
180
+ def selective_state_update(
181
+ state,
182
+ x,
183
+ dt,
184
+ A,
185
+ B,
186
+ C,
187
+ D=None,
188
+ z=None,
189
+ dt_bias=None,
190
+ dt_softplus=False,
191
+ state_batch_indices=None,
192
+ ):
193
+ """
194
+ Argument:
195
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
196
+ x: (batch, dim) or (batch, nheads, dim)
197
+ dt: (batch, dim) or (batch, nheads, dim)
198
+ A: (dim, dstate) or (nheads, dim, dstate)
199
+ B: (batch, dstate) or (batch, ngroups, dstate)
200
+ C: (batch, dstate) or (batch, ngroups, dstate)
201
+ D: (dim,) or (nheads, dim)
202
+ z: (batch, dim) or (batch, nheads, dim)
203
+ dt_bias: (dim,) or (nheads, dim)
204
+ Return:
205
+ out: (batch, dim) or (batch, nheads, dim)
206
+ """
207
+ has_heads = state.dim() > 3
208
+ if state.dim() == 3:
209
+ state = state.unsqueeze(1)
210
+ if x.dim() == 2:
211
+ x = x.unsqueeze(1)
212
+ if dt.dim() == 2:
213
+ dt = dt.unsqueeze(1)
214
+ if A.dim() == 2:
215
+ A = A.unsqueeze(0)
216
+ if B.dim() == 2:
217
+ B = B.unsqueeze(1)
218
+ if C.dim() == 2:
219
+ C = C.unsqueeze(1)
220
+ if D is not None and D.dim() == 1:
221
+ D = D.unsqueeze(0)
222
+ if z is not None and z.dim() == 2:
223
+ z = z.unsqueeze(1)
224
+ if dt_bias is not None and dt_bias.dim() == 1:
225
+ dt_bias = dt_bias.unsqueeze(0)
226
+ _, nheads, dim, dstate = state.shape
227
+ batch = x.shape[0]
228
+ if x.shape != (batch, nheads, dim):
229
+ print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
230
+ assert x.shape == (batch, nheads, dim)
231
+ assert dt.shape == x.shape
232
+ assert A.shape == (nheads, dim, dstate)
233
+ ngroups = B.shape[1]
234
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
235
+ assert B.shape == (batch, ngroups, dstate)
236
+ assert C.shape == B.shape
237
+ if D is not None:
238
+ assert D.shape == (nheads, dim)
239
+ if z is not None:
240
+ assert z.shape == x.shape
241
+ if dt_bias is not None:
242
+ assert dt_bias.shape == (nheads, dim)
243
+ if state_batch_indices is not None:
244
+ assert state_batch_indices.shape == (batch,)
245
+ out = torch.empty_like(x)
246
+ grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
247
+ z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
248
+ # We don't want autotune since it will overwrite the state
249
+ # We instead tune by hand.
250
+ BLOCK_SIZE_M, num_warps = (
251
+ (32, 4)
252
+ if dstate <= 16
253
+ else (
254
+ (16, 4)
255
+ if dstate <= 32
256
+ else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
257
+ )
258
+ )
259
+ tie_hdim = (
260
+ A.stride(-1) == 0
261
+ and A.stride(-2) == 0
262
+ and dt.stride(-1) == 0
263
+ and dt_bias.stride(-1) == 0
264
+ )
265
+ with torch.cuda.device(x.device.index):
266
+ _selective_scan_update_kernel[grid](
267
+ state,
268
+ x,
269
+ dt,
270
+ dt_bias,
271
+ A,
272
+ B,
273
+ C,
274
+ D,
275
+ z,
276
+ out,
277
+ state_batch_indices,
278
+ batch,
279
+ nheads,
280
+ dim,
281
+ dstate,
282
+ nheads // ngroups,
283
+ state.stride(0),
284
+ state.stride(1),
285
+ state.stride(2),
286
+ state.stride(3),
287
+ x.stride(0),
288
+ x.stride(1),
289
+ x.stride(2),
290
+ dt.stride(0),
291
+ dt.stride(1),
292
+ dt.stride(2),
293
+ *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
294
+ A.stride(0),
295
+ A.stride(1),
296
+ A.stride(2),
297
+ B.stride(0),
298
+ B.stride(1),
299
+ B.stride(2),
300
+ C.stride(0),
301
+ C.stride(1),
302
+ C.stride(2),
303
+ *(D.stride(0), D.stride(1)) if D is not None else 0,
304
+ z_strides[0],
305
+ z_strides[1],
306
+ z_strides[2],
307
+ out.stride(0),
308
+ out.stride(1),
309
+ out.stride(2),
310
+ dt_softplus,
311
+ tie_hdim,
312
+ BLOCK_SIZE_M,
313
+ num_warps=num_warps,
314
+ )
315
+ if not has_heads:
316
+ out = out.squeeze(1)
317
+ return out
318
+
319
+
320
+ def selective_state_update_ref(
321
+ state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
322
+ ):
323
+ """
324
+ Argument:
325
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
326
+ x: (batch, dim) or (batch, nheads, dim)
327
+ dt: (batch, dim) or (batch, nheads, dim)
328
+ A: (dim, dstate) or (nheads, dim, dstate)
329
+ B: (batch, dstate) or (batch, ngroups, dstate)
330
+ C: (batch, dstate) or (batch, ngroups, dstate)
331
+ D: (dim,) or (nheads, dim)
332
+ z: (batch, dim) or (batch, nheads, dim)
333
+ dt_bias: (dim,) or (nheads, dim)
334
+ Return:
335
+ out: (batch, dim) or (batch, nheads, dim)
336
+ """
337
+ has_heads = state.dim() > 3
338
+ if state.dim() == 3:
339
+ state = state.unsqueeze(1)
340
+ if x.dim() == 2:
341
+ x = x.unsqueeze(1)
342
+ if dt.dim() == 2:
343
+ dt = dt.unsqueeze(1)
344
+ if A.dim() == 2:
345
+ A = A.unsqueeze(0)
346
+ if B.dim() == 2:
347
+ B = B.unsqueeze(1)
348
+ if C.dim() == 2:
349
+ C = C.unsqueeze(1)
350
+ if D is not None and D.dim() == 1:
351
+ D = D.unsqueeze(0)
352
+ if z is not None and z.dim() == 2:
353
+ z = z.unsqueeze(1)
354
+ if dt_bias is not None and dt_bias.dim() == 1:
355
+ dt_bias = dt_bias.unsqueeze(0)
356
+ batch, nheads, dim, dstate = state.shape
357
+ assert x.shape == (batch, nheads, dim)
358
+ assert dt.shape == x.shape
359
+ assert A.shape == (nheads, dim, dstate)
360
+ ngroups = B.shape[1]
361
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
362
+ assert B.shape == (batch, ngroups, dstate)
363
+ assert C.shape == B.shape
364
+ if D is not None:
365
+ assert D.shape == (nheads, dim)
366
+ if z is not None:
367
+ assert z.shape == x.shape
368
+ if dt_bias is not None:
369
+ assert dt_bias.shape == (nheads, dim)
370
+ dt = dt + dt_bias if dt_bias is not None else dt
371
+ dt = F.softplus(dt) if dt_softplus else dt
372
+ dA = torch.exp(
373
+ rearrange(dt, "b h d -> b h d 1") * A
374
+ ) # (batch, nheads, dim, dstate)
375
+ B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
376
+ C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
377
+ dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
378
+ B, "b h n -> b h 1 n"
379
+ ) # (batch, nheads, dim, dstate)
380
+ state.copy_(
381
+ state * dA + dB * rearrange(x, "b h d -> b h d 1")
382
+ ) # (batch, dim, dstate
383
+ out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
384
+ if D is not None:
385
+ out += (x * D).to(out.dtype)
386
+ out = (out if z is None else out * F.silu(z)).to(x.dtype)
387
+ if not has_heads:
388
+ out = out.squeeze(1)
389
+ return out
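selective_state_update and its PyTorch reference above implement one decoding step of the SSM recurrence, updating the state in place and returning the step output. A small sketch comparing the Triton kernel against selective_state_update_ref, assuming CUDA is available and the package from this build imports as mamba_ssm:

import torch
from mamba_ssm.ops.triton.selective_state_update import (
    selective_state_update, selective_state_update_ref,
)

torch.manual_seed(0)
batch, nheads, dim, dstate = 2, 4, 64, 16
state = torch.randn(batch, nheads, dim, dstate, device="cuda")
x = torch.randn(batch, nheads, dim, device="cuda")
dt = torch.randn(batch, nheads, dim, device="cuda")
A = -torch.rand(nheads, dim, dstate, device="cuda")
B = torch.randn(batch, 1, dstate, device="cuda")  # ngroups = 1
C = torch.randn(batch, 1, dstate, device="cuda")
D = torch.randn(nheads, dim, device="cuda")
z = torch.randn(batch, nheads, dim, device="cuda")
dt_bias = torch.rand(nheads, dim, device="cuda")

state_ref = state.clone()  # both calls update their state argument in place
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
                             dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z,
                                     dt_bias=dt_bias, dt_softplus=True)
print(torch.allclose(out, out_ref, atol=1e-3, rtol=1e-3),
      torch.allclose(state, state_ref, atol=1e-3, rtol=1e-3))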
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/softplus.py ADDED
@@ -0,0 +1,15 @@
1
+ import triton
2
+ import triton.language as tl
3
+ from packaging import version
4
+
5
+ TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
6
+
7
+
8
+ if TRITON3:
9
+ @triton.jit
10
+ def softplus(dt):
11
+ return tl.math.log(tl.math.exp(dt) + 1)
12
+ else:
13
+ @triton.jit
14
+ def softplus(dt):
15
+ return tl.math.log1p(tl.exp(dt))
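The two branches above select an equivalent softplus formulation depending on the Triton version. A quick PyTorch-side check of the identity log(exp(x) + 1) = log1p(exp(x)) = softplus(x), valid for moderate x before exp overflows:

import torch
import torch.nn.functional as F

x = torch.linspace(-10, 10, steps=101)
a = torch.log(torch.exp(x) + 1)
b = torch.log1p(torch.exp(x))
print(torch.allclose(a, b), torch.allclose(a, F.softplus(x)))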
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_bmm.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ def init_to_zero(names):
17
+ return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
18
+
19
+
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
23
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
24
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
25
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
26
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
27
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
28
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
29
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
30
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
31
+ ],
32
+ key=['chunk_size', 'K', 'IS_CAUSAL'],
33
+ )
34
+ @triton.jit
35
+ def _bmm_chunk_fwd_kernel(
36
+ # Pointers to matrices
37
+ a_ptr, b_ptr, out_ptr, seq_idx_ptr,
38
+ # Matrix dimensions
39
+ seqlen, chunk_size, K, ngroups,
40
+ stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
41
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
42
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
43
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
44
+ # Meta-parameters
45
+ IS_CAUSAL: tl.constexpr,
46
+ dot_dtype: tl.constexpr,
47
+ HAS_SEQ_IDX: tl.constexpr,
48
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
49
+ ):
50
+ pid_b = tl.program_id(axis=1)
51
+ pid_ch = tl.program_id(axis=2)
52
+ pid_c = pid_ch // ngroups
53
+ pid_h = pid_ch - pid_c * ngroups
54
+ num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
55
+ pid_m = tl.program_id(axis=0) // num_pid_n
56
+ pid_n = tl.program_id(axis=0) % num_pid_n
57
+ if IS_CAUSAL:
58
+ if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
59
+ return
60
+ a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
61
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
62
+ if HAS_SEQ_IDX:
63
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
64
+
65
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
66
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
67
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
68
+ a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
69
+ b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
70
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
71
+
72
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
73
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
74
+ a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
75
+ b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
76
+ acc += tl.dot(a, b)
77
+ a_ptrs += BLOCK_SIZE_K * stride_ak
78
+ b_ptrs += BLOCK_SIZE_K * stride_bk
79
+
80
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
81
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
82
+ if HAS_SEQ_IDX:
83
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
84
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
85
+ seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
86
+ acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
87
+ out = acc.to(out_ptr.dtype.element_ty)
88
+
89
+ out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
90
+ out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
91
+ tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
92
+
93
+
94
+ @triton.autotune(
95
+ configs=[
96
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
97
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
98
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
99
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
100
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
101
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
102
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
103
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
104
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
105
+ ],
106
+ key=['chunk_size', 'K'],
107
+ )
108
+ @triton.jit
109
+ def _bmm_chunk_bwd_kernel(
110
+ # Pointers to matrices
111
+ a_ptr, dout_ptr, db_ptr, res_ptr,
112
+ # Matrix dimensions
113
+ seqlen, chunk_size, K, ngroups,
114
+ stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
115
+ stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
116
+ stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
117
+ stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
118
+ # Meta-parameters
119
+ dot_dtype: tl.constexpr,
120
+ HAS_RESIDUAL: tl.constexpr,
121
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
122
+ ):
123
+ pid_b = tl.program_id(axis=1)
124
+ pid_ch = tl.program_id(axis=2)
125
+ pid_c = pid_ch // ngroups
126
+ pid_h = pid_ch - pid_c * ngroups
127
+ num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
128
+ pid_m = tl.program_id(axis=0) // num_pid_n
129
+ pid_n = tl.program_id(axis=0) % num_pid_n
130
+
131
+ a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
132
+ dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
133
+
134
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
135
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
136
+ offs_cs = tl.arange(0, BLOCK_SIZE_CS)
137
+ dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
138
+ a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
139
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
140
+
141
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
142
+ for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
143
+ dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
144
+ a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
145
+ acc += tl.dot(dout, a)
146
+ dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
147
+ a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
148
+
149
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
150
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
151
+ if HAS_RESIDUAL:
152
+ res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
153
+ res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
154
+ res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
155
+ acc += res
156
+ db = acc.to(db_ptr.dtype.element_ty)
157
+
158
+ db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
159
+ db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
160
+ tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
161
+
162
+
163
+ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
164
+ """
165
+ Argument:
166
+ a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
167
+ b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
168
+ seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
169
+ causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
170
+ guaranteed to be correct.
171
+ Return:
172
+ out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
173
+ """
174
+ # Check constraints.
175
+ has_groups = a.dim() == 4
176
+ if not has_groups:
177
+ batch, seqlen, k = a.shape
178
+ else:
179
+ batch, seqlen, ngroups, k = a.shape
180
+ assert b.shape == a.shape
181
+ if seq_idx is not None:
182
+ assert seq_idx.shape == (batch, seqlen)
183
+ if a.stride(-1) != 1 and a.stride(1) != 1:
184
+ a = a.contiguous()
185
+ if b.stride(-1) != 1 and b.stride(1) != 1:
186
+ b = b.contiguous()
187
+ nchunks = math.ceil(seqlen / chunk_size)
188
+ # Allocates output.
189
+ out_dtype = a.dtype if output_dtype is None else output_dtype
190
+ out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
191
+ device=a.device, dtype=out_dtype)
192
+ dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
193
+ (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
194
+ grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
195
+ batch, nchunks if not has_groups else nchunks * ngroups)
196
+ with torch.cuda.device(a.device.index):
197
+ _bmm_chunk_fwd_kernel[grid](
198
+ a, b, out, seq_idx,
199
+ seqlen, chunk_size, k, ngroups if has_groups else 1,
200
+ a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
201
+ b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
202
+ out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
203
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
204
+ causal,
205
+ dot_dtype,
206
+ HAS_SEQ_IDX=seq_idx is not None,
207
+ )
208
+ return out
209
+
210
+
211
+ def _bmm_chunk_bwd(a, dout, residual=None, out=None):
212
+ """
213
+ Argument:
214
+ a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
215
+ dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
216
+ residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
217
+ Return:
218
+ out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
219
+
220
+ If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
221
+ zeroed out before calling this function.
222
+ """
223
+ # Check constraints.
224
+ has_groups = a.dim() == 4
225
+ if not has_groups:
226
+ batch, seqlen, k = a.shape
227
+ else:
228
+ batch, seqlen, ngroups, k = a.shape
229
+ nchunks, chunk_size = dout.shape[1], dout.shape[-1]
230
+ if a.stride(-1) != 1 and a.stride(-2) != 1:
231
+ a = a.contiguous()
232
+ if dout.stride(-1) != 1 and dout.stride(-2) != 1:
233
+ dout = dout.contiguous()
234
+ if residual is not None:
235
+ assert residual.shape == ((batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k))
236
+ if residual.stride(-1) != 1 and residual.stride(1) != 1:
237
+ residual = residual.contiguous()
238
+ # Allocates output.
239
+ if out is not None:
240
+ assert out.shape == a.shape
241
+ assert out.stride(-1) == 1 or out.stride(1) == 1
242
+ else:
243
+ out = torch.empty_like(a)
244
+ dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
245
+ (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
246
+ grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
247
+ nchunks if not has_groups else nchunks * ngroups)
248
+ residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
249
+ residual.stride(-1))
250
+ if residual is not None else (0, 0, 0, 0))
251
+ with torch.cuda.device(a.device.index):
252
+ _bmm_chunk_bwd_kernel[grid](
253
+ a, dout, out, residual,
254
+ seqlen, chunk_size, k, ngroups if has_groups else 1,
255
+ a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
256
+ dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
257
+ out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
258
+ residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
259
+ dot_dtype,
260
+ HAS_RESIDUAL=residual is not None,
261
+ )
262
+ return out
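Per the docstring of _bmm_chunk_fwd above, the forward computes, within each chunk of length L, the Gram-style matrix out[b, c, i, j] = sum_k a[b, c*L + i, k] * b[b, c*L + j, k]. A pure-PyTorch sketch of the ungrouped case, assuming no seq_idx and seqlen divisible by chunk_size (bmm_chunk_ref is a hypothetical helper for illustration, not part of this module):

import torch
from einops import rearrange

def bmm_chunk_ref(a, b, chunk_size):
    # a, b: (batch, seqlen, k) -> out: (batch, nchunks, chunk_size, chunk_size)
    a = rearrange(a, "b (c l) k -> b c l k", l=chunk_size)
    b = rearrange(b, "b (c l) k -> b c l k", l=chunk_size)
    return torch.einsum("bclk,bcsk->bcls", a, b)

a = torch.randn(2, 128, 32)
b = torch.randn(2, 128, 32)
print(bmm_chunk_ref(a, b, chunk_size=64).shape)  # torch.Size([2, 2, 64, 64])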
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py ADDED
@@ -0,0 +1,2012 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from .softplus import softplus
16
+
17
+
18
+ def init_to_zero(names):
19
+ return lambda nargs: [
20
+ nargs[name].zero_() for name in names if nargs[name] is not None
21
+ ]
22
+
23
+
24
+ @triton.autotune(
25
+ configs=[
26
+ triton.Config({"BLOCK_SIZE_H": 1}),
27
+ triton.Config({"BLOCK_SIZE_H": 2}),
28
+ triton.Config({"BLOCK_SIZE_H": 4}),
29
+ triton.Config({"BLOCK_SIZE_H": 8}),
30
+ triton.Config({"BLOCK_SIZE_H": 16}),
31
+ triton.Config({"BLOCK_SIZE_H": 32}),
32
+ triton.Config({"BLOCK_SIZE_H": 64}),
33
+ ],
34
+ key=["chunk_size", "nheads"],
35
+ )
36
+ @triton.jit
37
+ def _chunk_cumsum_fwd_kernel(
38
+ # Pointers to matrices
39
+ dt_ptr,
40
+ A_ptr,
41
+ dt_bias_ptr,
42
+ dt_out_ptr,
43
+ dA_cumsum_ptr,
44
+ # Matrix dimension
45
+ batch,
46
+ seqlen,
47
+ nheads,
48
+ chunk_size,
49
+ dt_min,
50
+ dt_max,
51
+ # Strides
52
+ stride_dt_batch,
53
+ stride_dt_seqlen,
54
+ stride_dt_head,
55
+ stride_A_head,
56
+ stride_dt_bias_head,
57
+ stride_dt_out_batch,
58
+ stride_dt_out_chunk,
59
+ stride_dt_out_head,
60
+ stride_dt_out_csize,
61
+ stride_dA_cs_batch,
62
+ stride_dA_cs_chunk,
63
+ stride_dA_cs_head,
64
+ stride_dA_cs_csize,
65
+ # Meta-parameters
66
+ DT_SOFTPLUS: tl.constexpr,
67
+ HAS_DT_BIAS: tl.constexpr,
68
+ BLOCK_SIZE_H: tl.constexpr,
69
+ BLOCK_SIZE_CHUNK: tl.constexpr,
70
+ ):
71
+ pid_b = tl.program_id(axis=0)
72
+ pid_c = tl.program_id(axis=1)
73
+ pid_h = tl.program_id(axis=2)
74
+ dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
75
+ dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
76
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
77
+
78
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
79
+ offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
80
+ dt_ptrs = dt_ptr + (
81
+ offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
82
+ )
83
+ A_ptrs = A_ptr + offs_h * stride_A_head
84
+ dt_out_ptrs = dt_out_ptr + (
85
+ offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
86
+ )
87
+ dA_cs_ptrs = dA_cumsum_ptr + (
88
+ offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
89
+ )
90
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
91
+
92
+ dt = tl.load(
93
+ dt_ptrs,
94
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
95
+ other=0.0,
96
+ ).to(tl.float32)
97
+ if HAS_DT_BIAS:
98
+ dt_bias = tl.load(
99
+ dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
100
+ ).to(tl.float32)
101
+ dt += dt_bias[:, None]
102
+ if DT_SOFTPLUS:
103
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
104
+ # As of Triton 2.2.0, tl.clamp is not available yet
105
+ # dt = tl.clamp(dt, dt_min, dt_max)
106
+ dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
107
+ dt = tl.where(
108
+ (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
109
+ )
110
+ tl.store(
111
+ dt_out_ptrs,
112
+ dt,
113
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
114
+ )
115
+ A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
116
+ dA = dt * A[:, None]
117
+ dA_cs = tl.cumsum(dA, axis=1)
118
+ tl.store(
119
+ dA_cs_ptrs,
120
+ dA_cs,
121
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
122
+ )
123
+
124
+
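# Illustrative pure-PyTorch sketch of what _chunk_cumsum_fwd_kernel above computes per
# chunk (chunk_cumsum_ref is a hypothetical helper, not part of this module): dt is
# biased, passed through softplus (with the same linear passthrough above 20 that
# F.softplus applies by default), clamped to [dt_min, dt_max], and dA = dt * A is
# cumulatively summed within each chunk. seqlen is assumed divisible by chunk_size.
import torch
import torch.nn.functional as F
from einops import rearrange

def chunk_cumsum_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=True,
                     dt_min=0.0, dt_max=float("inf")):
    # dt: (batch, seqlen, nheads), A: (nheads,), dt_bias: (nheads,) or None
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)  # linear for dt > 20, matching the kernel's tl.where
    dt = dt.clamp(dt_min, dt_max)
    dt = rearrange(dt, "b (c l) h -> b c h l", l=chunk_size)
    dA_cumsum = torch.cumsum(dt * A[None, None, :, None], dim=-1)
    return dt, dA_cumsum  # both (batch, nchunks, nheads, chunk_size)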
125
+ @triton.autotune(
126
+ configs=[
127
+ triton.Config(
128
+ {"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
129
+ ),
130
+ triton.Config(
131
+ {"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
132
+ ),
133
+ triton.Config(
134
+ {"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
135
+ ),
136
+ triton.Config(
137
+ {"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
138
+ ),
139
+ triton.Config(
140
+ {"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
141
+ ),
142
+ triton.Config(
143
+ {"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
144
+ ),
145
+ triton.Config(
146
+ {"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
147
+ ),
148
+ ],
149
+ key=["chunk_size", "nheads"],
150
+ )
151
+ @triton.jit
152
+ def _chunk_cumsum_bwd_kernel(
153
+ # Pointers to matrices
154
+ ddA_ptr,
155
+ ddt_out_ptr,
156
+ dt_ptr,
157
+ A_ptr,
158
+ dt_bias_ptr,
159
+ ddt_ptr,
160
+ dA_ptr,
161
+ ddt_bias_ptr,
162
+ # Matrix dimensions
163
+ batch,
164
+ seqlen,
165
+ nheads,
166
+ chunk_size,
167
+ dt_min,
168
+ dt_max,
169
+ # Strides
170
+ stride_ddA_batch,
171
+ stride_ddA_chunk,
172
+ stride_ddA_head,
173
+ stride_ddA_csize,
174
+ stride_ddt_out_batch,
175
+ stride_ddt_out_chunk,
176
+ stride_ddt_out_head,
177
+ stride_ddt_out_csize,
178
+ stride_dt_batch,
179
+ stride_dt_seqlen,
180
+ stride_dt_head,
181
+ stride_A_head,
182
+ stride_dt_bias_head,
183
+ stride_ddt_batch,
184
+ stride_ddt_seqlen,
185
+ stride_ddt_head,
186
+ stride_dA_head,
187
+ stride_ddt_bias_head,
188
+ # Meta-parameters
189
+ DT_SOFTPLUS: tl.constexpr,
190
+ HAS_DT_BIAS: tl.constexpr,
191
+ BLOCK_SIZE_H: tl.constexpr,
192
+ BLOCK_SIZE_CHUNK: tl.constexpr,
193
+ ):
194
+ pid_b = tl.program_id(axis=0)
195
+ pid_c = tl.program_id(axis=1)
196
+ pid_h = tl.program_id(axis=2)
197
+ ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
198
+ ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
199
+ dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
200
+ ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
201
+
202
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
203
+ offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
204
+ ddt_out_ptrs = ddt_out_ptr + (
205
+ offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
206
+ )
207
+ ddA_ptrs = ddA_ptr + (
208
+ offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
209
+ )
210
+ dt_ptrs = dt_ptr + (
211
+ offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
212
+ )
213
+ ddt_ptrs = ddt_ptr + (
214
+ offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
215
+ )
216
+ A_ptrs = A_ptr + offs_h * stride_A_head
217
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
218
+
219
+ ddA = tl.load(
220
+ ddA_ptrs,
221
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
222
+ other=0.0,
223
+ ).to(tl.float32)
224
+ ddt_out = tl.load(
225
+ ddt_out_ptrs,
226
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
227
+ other=0.0,
228
+ ).to(tl.float32)
229
+ A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
230
+ ddt = ddA * A[:, None] + ddt_out
231
+ dt = tl.load(
232
+ dt_ptrs,
233
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
234
+ other=0.0,
235
+ ).to(tl.float32)
236
+ if HAS_DT_BIAS:
237
+ dt_bias = tl.load(
238
+ dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
239
+ ).to(tl.float32)
240
+ dt += dt_bias[:, None]
241
+ if DT_SOFTPLUS:
242
+ dt_presoftplus = dt
243
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
244
+ clamp_mask = (dt < dt_min) | (dt > dt_max)
245
+ # As of Triton 2.2.0, tl.clamp is not available yet
246
+ # dt = tl.clamp(dt, dt_min, dt_max)
247
+ dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
248
+ dt = tl.where(
249
+ (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
250
+ )
251
+ ddt = tl.where(
252
+ (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
253
+ )
254
+ ddt = tl.where(clamp_mask, 0.0, ddt)
255
+ if DT_SOFTPLUS:
256
+ ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
257
+ tl.store(
258
+ ddt_ptrs,
259
+ ddt,
260
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
261
+ )
262
+ dA = tl.sum(ddA * dt, axis=1)
263
+ tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
264
+ if HAS_DT_BIAS:
265
+ ddt_bias = tl.sum(ddt, axis=1)
266
+ tl.atomic_add(
267
+ ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
268
+ )
269
+
270
+
271
+ @triton.autotune(
272
+ configs=[
273
+ triton.Config(
274
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
275
+ num_stages=3,
276
+ num_warps=8,
277
+ ),
278
+ triton.Config(
279
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
280
+ num_stages=4,
281
+ num_warps=4,
282
+ ),
283
+ triton.Config(
284
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
285
+ num_stages=4,
286
+ num_warps=4,
287
+ ),
288
+ triton.Config(
289
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
290
+ num_stages=4,
291
+ num_warps=4,
292
+ ),
293
+ triton.Config(
294
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
295
+ num_stages=4,
296
+ num_warps=4,
297
+ ),
298
+ triton.Config(
299
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
300
+ num_stages=4,
301
+ num_warps=4,
302
+ ),
303
+ triton.Config(
304
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
305
+ num_stages=5,
306
+ num_warps=2,
307
+ ),
308
+ triton.Config(
309
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
310
+ num_stages=5,
311
+ num_warps=2,
312
+ ),
313
+ triton.Config(
314
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
315
+ num_stages=4,
316
+ num_warps=2,
317
+ ),
318
+ ],
319
+ key=["hdim", "dstate", "chunk_size"],
320
+ )
321
+ @triton.jit
322
+ def _chunk_state_fwd_kernel(
323
+ # Pointers to matrices
324
+ x_ptr,
325
+ b_ptr,
326
+ states_ptr,
327
+ dt_ptr,
328
+ dA_cumsum_ptr,
329
+ seq_idx_ptr,
330
+ # Matrix dimensions
331
+ hdim,
332
+ dstate,
333
+ chunk_size,
334
+ batch,
335
+ seqlen,
336
+ nheads_ngroups_ratio,
337
+ # Strides
338
+ stride_x_batch,
339
+ stride_x_seqlen,
340
+ stride_x_head,
341
+ stride_x_hdim,
342
+ stride_b_batch,
343
+ stride_b_seqlen,
344
+ stride_b_head,
345
+ stride_b_dstate,
346
+ stride_states_batch,
347
+ stride_states_chunk,
348
+ stride_states_head,
349
+ stride_states_hdim,
350
+ stride_states_dstate,
351
+ stride_dt_batch,
352
+ stride_dt_chunk,
353
+ stride_dt_head,
354
+ stride_dt_csize,
355
+ stride_dA_cs_batch,
356
+ stride_dA_cs_chunk,
357
+ stride_dA_cs_head,
358
+ stride_dA_cs_csize,
359
+ stride_seq_idx_batch,
360
+ stride_seq_idx_seqlen,
361
+ # Meta-parameters
362
+ HAS_SEQ_IDX: tl.constexpr,
363
+ BLOCK_SIZE_M: tl.constexpr,
364
+ BLOCK_SIZE_N: tl.constexpr,
365
+ BLOCK_SIZE_K: tl.constexpr,
366
+ ):
367
+ pid_bc = tl.program_id(axis=1)
368
+ pid_c = pid_bc // batch
369
+ pid_b = pid_bc - pid_c * batch
370
+ pid_h = tl.program_id(axis=2)
371
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
372
+ pid_m = tl.program_id(axis=0) // num_pid_n
373
+ pid_n = tl.program_id(axis=0) % num_pid_n
374
+ b_ptr += (
375
+ pid_b * stride_b_batch
376
+ + pid_c * chunk_size * stride_b_seqlen
377
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
378
+ )
379
+ x_ptr += (
380
+ pid_b * stride_x_batch
381
+ + pid_c * chunk_size * stride_x_seqlen
382
+ + pid_h * stride_x_head
383
+ )
384
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
385
+ dA_cumsum_ptr += (
386
+ pid_b * stride_dA_cs_batch
387
+ + pid_c * stride_dA_cs_chunk
388
+ + pid_h * stride_dA_cs_head
389
+ )
390
+ if HAS_SEQ_IDX:
391
+ seq_idx_ptr += (
392
+ pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
393
+ )
394
+
395
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
396
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
397
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
398
+ x_ptrs = x_ptr + (
399
+ offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
400
+ )
401
+ b_ptrs = b_ptr + (
402
+ offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
403
+ )
404
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
405
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
406
+ tl.float32
407
+ )
408
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
409
+ if HAS_SEQ_IDX:
410
+ seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
411
+
412
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
413
+ if HAS_SEQ_IDX:
414
+ seq_idx_last = tl.load(
415
+ seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
416
+ )
417
+
418
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
419
+ for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
420
+ x = tl.load(
421
+ x_ptrs,
422
+ mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
423
+ other=0.0,
424
+ )
425
+ b = tl.load(
426
+ b_ptrs,
427
+ mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
428
+ other=0.0,
429
+ ).to(tl.float32)
430
+ dA_cs_k = tl.load(
431
+ dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
432
+ ).to(tl.float32)
433
+ if HAS_SEQ_IDX:
434
+ seq_idx_k = tl.load(
435
+ seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
436
+ )
437
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
438
+ tl.float32
439
+ )
440
+ if not HAS_SEQ_IDX:
441
+ scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
442
+ else:
443
+ scale = tl.where(
444
+ seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
445
+ )
446
+ b *= scale[:, None]
447
+ b = b.to(x_ptr.dtype.element_ty)
448
+ acc += tl.dot(x, b)
449
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
450
+ b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
451
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
452
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
453
+ if HAS_SEQ_IDX:
454
+ seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
455
+ states = acc.to(states_ptr.dtype.element_ty)
456
+
457
+ states_ptr += (
458
+ pid_b * stride_states_batch
459
+ + pid_c * stride_states_chunk
460
+ + pid_h * stride_states_head
461
+ )
462
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
463
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
464
+ states_ptrs = states_ptr + (
465
+ offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
466
+ )
467
+ c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
468
+ tl.store(states_ptrs, states, mask=c_mask)
469
+
470
+
471
+ @triton.autotune(
472
+ configs=[
473
+ triton.Config(
474
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
475
+ num_stages=3,
476
+ num_warps=8,
477
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
478
+ ),
479
+ triton.Config(
480
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
481
+ num_stages=4,
482
+ num_warps=4,
483
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
484
+ ),
485
+ triton.Config(
486
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
487
+ num_stages=4,
488
+ num_warps=4,
489
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
490
+ ),
491
+ triton.Config(
492
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
493
+ num_stages=4,
494
+ num_warps=4,
495
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
496
+ ),
497
+ triton.Config(
498
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
499
+ num_stages=4,
500
+ num_warps=4,
501
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
502
+ ),
503
+ triton.Config(
504
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
505
+ num_stages=4,
506
+ num_warps=4,
507
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
508
+ ),
509
+ triton.Config(
510
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
511
+ num_stages=5,
512
+ num_warps=4,
513
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
514
+ ),
515
+ triton.Config(
516
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
517
+ num_stages=5,
518
+ num_warps=4,
519
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
520
+ ),
521
+ triton.Config(
522
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
523
+ num_stages=4,
524
+ num_warps=4,
525
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
526
+ ),
527
+ ],
528
+ key=["chunk_size", "hdim", "dstate"],
529
+ )
530
+ @triton.jit
531
+ def _chunk_state_bwd_dx_kernel(
532
+ # Pointers to matrices
533
+ x_ptr,
534
+ b_ptr,
535
+ dstates_ptr,
536
+ dt_ptr,
537
+ dA_cumsum_ptr,
538
+ dx_ptr,
539
+ ddt_ptr,
540
+ ddA_cumsum_ptr,
541
+ # Matrix dimensions
542
+ chunk_size,
543
+ hdim,
544
+ dstate,
545
+ batch,
546
+ seqlen,
547
+ nheads_ngroups_ratio,
548
+ # Strides
549
+ stride_x_batch,
550
+ stride_x_seqlen,
551
+ stride_x_head,
552
+ stride_x_hdim,
553
+ stride_b_batch,
554
+ stride_b_seqlen,
555
+ stride_b_head,
556
+ stride_b_dstate,
557
+ stride_dstates_batch,
558
+ stride_dstates_chunk,
559
+ stride_states_head,
560
+ stride_states_hdim,
561
+ stride_states_dstate,
562
+ stride_dt_batch,
563
+ stride_dt_chunk,
564
+ stride_dt_head,
565
+ stride_dt_csize,
566
+ stride_dA_cs_batch,
567
+ stride_dA_cs_chunk,
568
+ stride_dA_cs_head,
569
+ stride_dA_cs_csize,
570
+ stride_dx_batch,
571
+ stride_dx_seqlen,
572
+ stride_dx_head,
573
+ stride_dx_hdim,
574
+ stride_ddt_batch,
575
+ stride_ddt_chunk,
576
+ stride_ddt_head,
577
+ stride_ddt_csize,
578
+ stride_ddA_cs_batch,
579
+ stride_ddA_cs_chunk,
580
+ stride_ddA_cs_head,
581
+ stride_ddA_cs_csize,
582
+ # Meta-parameters
583
+ BLOCK_SIZE_M: tl.constexpr,
584
+ BLOCK_SIZE_N: tl.constexpr,
585
+ BLOCK_SIZE_K: tl.constexpr,
586
+ BLOCK_SIZE_DSTATE: tl.constexpr,
587
+ ):
588
+ pid_bc = tl.program_id(axis=1)
589
+ pid_c = pid_bc // batch
590
+ pid_b = pid_bc - pid_c * batch
591
+ pid_h = tl.program_id(axis=2)
592
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
593
+ pid_m = tl.program_id(axis=0) // num_pid_n
594
+ pid_n = tl.program_id(axis=0) % num_pid_n
595
+ x_ptr += (
596
+ pid_b * stride_x_batch
597
+ + pid_c * chunk_size * stride_x_seqlen
598
+ + pid_h * stride_x_head
599
+ )
600
+ b_ptr += (
601
+ pid_b * stride_b_batch
602
+ + pid_c * chunk_size * stride_b_seqlen
603
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
604
+ )
605
+ dstates_ptr += (
606
+ pid_b * stride_dstates_batch
607
+ + pid_c * stride_dstates_chunk
608
+ + pid_h * stride_states_head
609
+ )
610
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
611
+ ddt_ptr += (
612
+ pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
613
+ )
614
+ ddA_cumsum_ptr += (
615
+ pid_b * stride_ddA_cs_batch
616
+ + pid_c * stride_ddA_cs_chunk
617
+ + pid_h * stride_ddA_cs_head
618
+ )
619
+ dA_cumsum_ptr += (
620
+ pid_b * stride_dA_cs_batch
621
+ + pid_c * stride_dA_cs_chunk
622
+ + pid_h * stride_dA_cs_head
623
+ )
624
+
625
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
626
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
627
+
628
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
629
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
630
+ offs_k = tl.arange(
631
+ 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
632
+ )
633
+ b_ptrs = b_ptr + (
634
+ offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
635
+ )
636
+ dstates_ptrs = dstates_ptr + (
637
+ offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
638
+ )
639
+ if BLOCK_SIZE_DSTATE <= 128:
640
+ b = tl.load(
641
+ b_ptrs,
642
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
643
+ other=0.0,
644
+ )
645
+ dstates = tl.load(
646
+ dstates_ptrs,
647
+ mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
648
+ other=0.0,
649
+ )
650
+ dstates = dstates.to(b_ptr.dtype.element_ty)
651
+ acc = tl.dot(b, dstates)
652
+ else:
653
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
654
+ for k in range(0, dstate, BLOCK_SIZE_K):
655
+ b = tl.load(
656
+ b_ptrs,
657
+ mask=(offs_m[:, None] < chunk_size_limit)
658
+ & (offs_k[None, :] < dstate - k),
659
+ other=0.0,
660
+ )
661
+ dstates = tl.load(
662
+ dstates_ptrs,
663
+ mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
664
+ other=0.0,
665
+ )
666
+ dstates = dstates.to(b_ptr.dtype.element_ty)
667
+ acc += tl.dot(b, dstates)
668
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
669
+ dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
670
+
671
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
672
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
673
+
674
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
675
+ tl.float32
676
+ )
677
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
678
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
679
+ dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
680
+ tl.float32
681
+ )
682
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
683
+ acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
684
+
685
+ x_ptrs = x_ptr + (
686
+ offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
687
+ )
688
+ x = tl.load(
689
+ x_ptrs,
690
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
691
+ other=0.0,
692
+ ).to(tl.float32)
693
+ ddt = tl.sum(acc * x, axis=1)
694
+ ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
695
+ tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
696
+ ddA_cs = -(ddt * dt_m)
697
+ ddA_cs_last = -tl.sum(ddA_cs)
698
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
699
+ tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
700
+ tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
701
+
702
+ dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
703
+ dx_ptr += (
704
+ pid_b * stride_dx_batch
705
+ + pid_c * chunk_size * stride_dx_seqlen
706
+ + pid_h * stride_dx_head
707
+ )
708
+ dx_ptrs = dx_ptr + (
709
+ offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
710
+ )
711
+ tl.store(
712
+ dx_ptrs,
713
+ dx,
714
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
715
+ )
716
+
717
+
718
+ @triton.autotune(
719
+ configs=[
720
+ triton.Config(
721
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
722
+ num_stages=3,
723
+ num_warps=4,
724
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
725
+ ),
726
+ triton.Config(
727
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
728
+ num_stages=3,
729
+ num_warps=4,
730
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
731
+ ),
732
+ triton.Config(
733
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
734
+ num_stages=3,
735
+ num_warps=4,
736
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
737
+ ),
738
+ triton.Config(
739
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
740
+ num_stages=3,
741
+ num_warps=4,
742
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
743
+ ),
744
+ triton.Config(
745
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
746
+ num_stages=3,
747
+ num_warps=4,
748
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
749
+ ),
750
+ triton.Config(
751
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
752
+ num_stages=3,
753
+ num_warps=4,
754
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
755
+ ),
756
+ triton.Config(
757
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
758
+ num_stages=3,
759
+ num_warps=4,
760
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
761
+ ),
762
+ triton.Config(
763
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
764
+ num_stages=3,
765
+ num_warps=4,
766
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
767
+ ),
768
+ ],
769
+ key=["chunk_size", "dstate", "hdim"],
770
+ )
771
+ @triton.jit
772
+ def _chunk_state_bwd_db_kernel(
773
+ # Pointers to matrices
774
+ x_ptr,
775
+ dstates_ptr,
776
+ b_ptr,
777
+ dt_ptr,
778
+ dA_cumsum_ptr,
779
+ seq_idx_ptr,
780
+ db_ptr,
781
+ ddA_cumsum_ptr,
782
+ # Matrix dimensions
783
+ chunk_size,
784
+ dstate,
785
+ hdim,
786
+ batch,
787
+ seqlen,
788
+ nheads,
789
+ nheads_per_program,
790
+ ngroups,
791
+ # Strides
792
+ stride_x_batch,
793
+ stride_x_seqlen,
794
+ stride_x_head,
795
+ stride_x_hdim,
796
+ stride_dstates_batch,
797
+ stride_dstates_chunk,
798
+ stride_states_head,
799
+ stride_states_hdim,
800
+ stride_states_dstate,
801
+ stride_b_batch,
802
+ stride_b_seqlen,
803
+ stride_b_head,
804
+ stride_b_dstate,
805
+ stride_dt_batch,
806
+ stride_dt_chunk,
807
+ stride_dt_head,
808
+ stride_dt_csize,
809
+ stride_dA_cs_batch,
810
+ stride_dA_cs_chunk,
811
+ stride_dA_cs_head,
812
+ stride_dA_cs_csize,
813
+ stride_seq_idx_batch,
814
+ stride_seq_idx_seqlen,
815
+ stride_db_batch,
816
+ stride_db_seqlen,
817
+ stride_db_split,
818
+ stride_db_group,
819
+ stride_db_dstate,
820
+ stride_ddA_cs_batch,
821
+ stride_ddA_cs_chunk,
822
+ stride_ddA_cs_head,
823
+ stride_ddA_cs_csize,
824
+ # Meta-parameters
825
+ HAS_DDA_CS: tl.constexpr,
826
+ HAS_SEQ_IDX: tl.constexpr,
827
+ BLOCK_SIZE_M: tl.constexpr,
828
+ BLOCK_SIZE_N: tl.constexpr,
829
+ BLOCK_SIZE_K: tl.constexpr,
830
+ ):
831
+ pid_bc = tl.program_id(axis=1)
832
+ pid_c = pid_bc // batch
833
+ pid_b = pid_bc - pid_c * batch
834
+ pid_sg = tl.program_id(axis=2)
835
+ pid_s = pid_sg // ngroups
836
+ pid_g = pid_sg - pid_s * ngroups
837
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
838
+ pid_m = tl.program_id(axis=0) // num_pid_n
839
+ pid_n = tl.program_id(axis=0) % num_pid_n
840
+ x_ptr += (
841
+ pid_b * stride_x_batch
842
+ + pid_c * chunk_size * stride_x_seqlen
843
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
844
+ )
845
+ db_ptr += (
846
+ pid_b * stride_db_batch
847
+ + pid_c * chunk_size * stride_db_seqlen
848
+ + pid_g * stride_db_group
849
+ + pid_s * stride_db_split
850
+ )
851
+ dstates_ptr += (
852
+ pid_b * stride_dstates_batch
853
+ + pid_c * stride_dstates_chunk
854
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
855
+ * stride_states_head
856
+ )
857
+ dt_ptr += (
858
+ pid_b * stride_dt_batch
859
+ + pid_c * stride_dt_chunk
860
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
861
+ )
862
+ dA_cumsum_ptr += (
863
+ pid_b * stride_dA_cs_batch
864
+ + pid_c * stride_dA_cs_chunk
865
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
866
+ )
867
+ if HAS_DDA_CS:
868
+ b_ptr += (
869
+ pid_b * stride_b_batch
870
+ + pid_c * chunk_size * stride_b_seqlen
871
+ + pid_g * stride_b_head
872
+ )
873
+ ddA_cumsum_ptr += (
874
+ pid_b * stride_ddA_cs_batch
875
+ + pid_c * stride_ddA_cs_chunk
876
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
877
+ * stride_ddA_cs_head
878
+ )
879
+ if HAS_SEQ_IDX:
880
+ seq_idx_ptr += (
881
+ pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
882
+ )
883
+
884
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
885
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
886
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
887
+ x_ptrs = x_ptr + (
888
+ offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
889
+ )
890
+ dstates_ptrs = dstates_ptr + (
891
+ offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
892
+ )
893
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
894
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
895
+ if HAS_DDA_CS:
896
+ b_ptrs = b_ptr + (
897
+ offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
898
+ )
899
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
900
+
901
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
902
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
903
+ if HAS_DDA_CS:
904
+ b = tl.load(
905
+ b_ptrs,
906
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
907
+ other=0.0,
908
+ ).to(tl.float32)
909
+ if HAS_SEQ_IDX:
910
+ seq_idx_m = tl.load(
911
+ seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
912
+ mask=offs_m < chunk_size_limit,
913
+ other=-1,
914
+ )
915
+ seq_idx_last = tl.load(
916
+ seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
917
+ )
918
+ nheads_iter = min(
919
+ nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
920
+ )
921
+ for h in range(nheads_iter):
922
+ x = tl.load(
923
+ x_ptrs,
924
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
925
+ other=0.0,
926
+ )
927
+ dstates = tl.load(
928
+ dstates_ptrs,
929
+ mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
930
+ other=0.0,
931
+ )
932
+ dstates = dstates.to(x_ptrs.dtype.element_ty)
933
+ db = tl.dot(x, dstates)
934
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
935
+ tl.float32
936
+ )
937
+ dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
938
+ tl.float32
939
+ )
940
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
941
+ if not HAS_SEQ_IDX:
942
+ scale = tl.exp(dA_cs_last - dA_cs_m)
943
+ else:
944
+ scale = tl.where(
945
+ seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
946
+ )
947
+ db *= (scale * dt_m)[:, None]
948
+ if HAS_DDA_CS:
949
+ # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
950
+ ddA_cs = tl.sum(db * b, axis=1)
951
+ tl.atomic_add(
952
+ ddA_cumsum_ptrs + stride_ddA_cs_csize,
953
+ ddA_cs,
954
+ mask=offs_m < chunk_size - 1,
955
+ )
956
+ acc += db
957
+ x_ptrs += stride_x_head
958
+ dstates_ptrs += stride_states_head
959
+ dt_ptrs += stride_dt_head
960
+ dA_cumsum_ptr += stride_dA_cs_head
961
+ dA_cumsum_ptrs += stride_dA_cs_head
962
+ if HAS_DDA_CS:
963
+ ddA_cumsum_ptrs += stride_ddA_cs_head
964
+
965
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
966
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
967
+ # if HAS_SEQ_IDX:
968
+ # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
969
+ # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
970
+ # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
971
+ db_ptrs = db_ptr + (
972
+ offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
973
+ )
974
+ tl.store(
975
+ db_ptrs,
976
+ acc,
977
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
978
+ )
979
+
980
+
981
+ @triton.autotune(
982
+ configs=[
983
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
984
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
985
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
986
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
987
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
988
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
989
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
990
+ # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
991
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
992
+ triton.Config(
993
+ {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
994
+ num_stages=3,
995
+ num_warps=4,
996
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
997
+ ),
998
+ triton.Config(
999
+ {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1000
+ num_stages=3,
1001
+ num_warps=4,
1002
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1003
+ ),
1004
+ triton.Config(
1005
+ {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1006
+ num_stages=3,
1007
+ num_warps=4,
1008
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1009
+ ),
1010
+ triton.Config(
1011
+ {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1012
+ num_stages=3,
1013
+ num_warps=4,
1014
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1015
+ ),
1016
+ triton.Config(
1017
+ {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
1018
+ num_stages=4,
1019
+ num_warps=8,
1020
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1021
+ ),
1022
+ triton.Config(
1023
+ {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1024
+ num_stages=4,
1025
+ num_warps=8,
1026
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1027
+ ),
1028
+ triton.Config(
1029
+ {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1030
+ num_stages=4,
1031
+ num_warps=8,
1032
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1033
+ ),
1034
+ triton.Config(
1035
+ {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1036
+ num_stages=4,
1037
+ num_warps=8,
1038
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1039
+ ),
1040
+ ],
1041
+ key=["chunk_size", "hdim", "dstate"],
1042
+ )
1043
+ @triton.jit
1044
+ def _chunk_state_bwd_ddAcs_stable_kernel(
1045
+ # Pointers to matrices
1046
+ x_ptr,
1047
+ b_ptr,
1048
+ dstates_ptr,
1049
+ dt_ptr,
1050
+ dA_cumsum_ptr,
1051
+ seq_idx_ptr,
1052
+ ddA_cumsum_ptr,
1053
+ # Matrix dimensions
1054
+ chunk_size,
1055
+ hdim,
1056
+ dstate,
1057
+ batch,
1058
+ seqlen,
1059
+ nheads_ngroups_ratio,
1060
+ # Strides
1061
+ stride_x_batch,
1062
+ stride_x_seqlen,
1063
+ stride_x_head,
1064
+ stride_x_hdim,
1065
+ stride_b_batch,
1066
+ stride_b_seqlen,
1067
+ stride_b_head,
1068
+ stride_b_dstate,
1069
+ stride_dstates_batch,
1070
+ stride_dstates_chunk,
1071
+ stride_states_head,
1072
+ stride_states_hdim,
1073
+ stride_states_dstate,
1074
+ stride_dt_batch,
1075
+ stride_dt_chunk,
1076
+ stride_dt_head,
1077
+ stride_dt_csize,
1078
+ stride_dA_cs_batch,
1079
+ stride_dA_cs_chunk,
1080
+ stride_dA_cs_head,
1081
+ stride_dA_cs_csize,
1082
+ stride_seq_idx_batch,
1083
+ stride_seq_idx_seqlen,
1084
+ stride_ddA_cs_batch,
1085
+ stride_ddA_cs_chunk,
1086
+ stride_ddA_cs_head,
1087
+ stride_ddA_cs_csize,
1088
+ # Meta-parameters
1089
+ HAS_SEQ_IDX: tl.constexpr,
1090
+ BLOCK_SIZE_M: tl.constexpr,
1091
+ BLOCK_SIZE_N: tl.constexpr,
1092
+ BLOCK_SIZE_K: tl.constexpr,
1093
+ BLOCK_SIZE_DSTATE: tl.constexpr,
1094
+ ):
1095
+ pid_bc = tl.program_id(axis=1)
1096
+ pid_c = pid_bc // batch
1097
+ pid_b = pid_bc - pid_c * batch
1098
+ pid_h = tl.program_id(axis=2)
1099
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
1100
+ pid_m = tl.program_id(axis=0) // num_pid_n
1101
+ pid_n = tl.program_id(axis=0) % num_pid_n
1102
+ x_ptr += (
1103
+ pid_b * stride_x_batch
1104
+ + pid_c * chunk_size * stride_x_seqlen
1105
+ + pid_h * stride_x_head
1106
+ )
1107
+ b_ptr += (
1108
+ pid_b * stride_b_batch
1109
+ + pid_c * chunk_size * stride_b_seqlen
1110
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
1111
+ )
1112
+ dstates_ptr += (
1113
+ pid_b * stride_dstates_batch
1114
+ + pid_c * stride_dstates_chunk
1115
+ + pid_h * stride_states_head
1116
+ )
1117
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
1118
+ ddA_cumsum_ptr += (
1119
+ pid_b * stride_ddA_cs_batch
1120
+ + pid_c * stride_ddA_cs_chunk
1121
+ + pid_h * stride_ddA_cs_head
1122
+ )
1123
+ dA_cumsum_ptr += (
1124
+ pid_b * stride_dA_cs_batch
1125
+ + pid_c * stride_dA_cs_chunk
1126
+ + pid_h * stride_dA_cs_head
1127
+ )
1128
+ if HAS_SEQ_IDX:
1129
+ seq_idx_ptr += (
1130
+ pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
1131
+ )
1132
+
1133
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1134
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1135
+
1136
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
1137
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
1138
+ offs_k = tl.arange(
1139
+ 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
1140
+ )
1141
+ b_ptrs = b_ptr + (
1142
+ offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
1143
+ )
1144
+ dstates_ptrs = dstates_ptr + (
1145
+ offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
1146
+ )
1147
+ if BLOCK_SIZE_DSTATE <= 128:
1148
+ b = tl.load(
1149
+ b_ptrs,
1150
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
1151
+ other=0.0,
1152
+ )
1153
+ dstates = tl.load(
1154
+ dstates_ptrs,
1155
+ mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
1156
+ other=0.0,
1157
+ )
1158
+ dstates = dstates.to(b_ptr.dtype.element_ty)
1159
+ acc = tl.dot(b, dstates)
1160
+ else:
1161
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1162
+ for k in range(0, dstate, BLOCK_SIZE_K):
1163
+ b = tl.load(
1164
+ b_ptrs,
1165
+ mask=(offs_m[:, None] < chunk_size_limit)
1166
+ & (offs_k[None, :] < dstate - k),
1167
+ other=0.0,
1168
+ )
1169
+ dstates = tl.load(
1170
+ dstates_ptrs,
1171
+ mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
1172
+ other=0.0,
1173
+ )
1174
+ dstates = dstates.to(b_ptr.dtype.element_ty)
1175
+ acc += tl.dot(b, dstates)
1176
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
1177
+ dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
1178
+
1179
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1180
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1181
+
1182
+ dA_cs_m = tl.load(
1183
+ dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
1184
+ ).to(tl.float32)
1185
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
1186
+ tl.float32
1187
+ )
1188
+ if not HAS_SEQ_IDX:
1189
+ scale = tl.exp(dA_cs_last - dA_cs_m)
1190
+ else:
1191
+ seq_idx_m = tl.load(
1192
+ seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
1193
+ mask=offs_m < chunk_size_limit,
1194
+ other=-1,
1195
+ )
1196
+ seq_idx_last = tl.load(
1197
+ seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
1198
+ )
1199
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
1200
+ acc *= scale[:, None]
1201
+
1202
+ x_ptrs = x_ptr + (
1203
+ offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
1204
+ )
1205
+ x = tl.load(
1206
+ x_ptrs,
1207
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
1208
+ other=0.0,
1209
+ ).to(tl.float32)
1210
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
1211
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
1212
+ ddt = tl.sum(acc * x, axis=1)
1213
+ # ddA_cs = -(ddt * dt_m)
1214
+ # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
1215
+ # then call torch.cumsum outside this kernel.
1216
+ # ddA_cs = tl.cumsum(ddt * dt_m)
1217
+ ddA_cs = ddt * dt_m
1218
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
1219
+ # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
1220
+ tl.atomic_add(
1221
+ ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
1222
+ )
1223
+
1224
+
1225
+ @triton.autotune(
1226
+ configs=[
1227
+ triton.Config(
1228
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
1229
+ num_stages=3,
1230
+ num_warps=8,
1231
+ ),
1232
+ triton.Config(
1233
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
1234
+ num_stages=4,
1235
+ num_warps=4,
1236
+ ),
1237
+ triton.Config(
1238
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1239
+ num_stages=4,
1240
+ num_warps=4,
1241
+ ),
1242
+ triton.Config(
1243
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1244
+ num_stages=4,
1245
+ num_warps=4,
1246
+ ),
1247
+ triton.Config(
1248
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1249
+ num_stages=4,
1250
+ num_warps=4,
1251
+ ),
1252
+ triton.Config(
1253
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1254
+ num_stages=4,
1255
+ num_warps=4,
1256
+ ),
1257
+ triton.Config(
1258
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1259
+ num_stages=5,
1260
+ num_warps=2,
1261
+ ),
1262
+ triton.Config(
1263
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1264
+ num_stages=5,
1265
+ num_warps=2,
1266
+ ),
1267
+ triton.Config(
1268
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1269
+ num_stages=4,
1270
+ num_warps=2,
1271
+ ),
1272
+ ],
1273
+ key=["hdim", "dstate", "chunk_size"],
1274
+ )
1275
+ @triton.jit
1276
+ def _chunk_state_varlen_kernel(
1277
+ # Pointers to matrices
1278
+ x_ptr,
1279
+ b_ptr,
1280
+ dt_ptr,
1281
+ dA_cumsum_ptr,
1282
+ chunk_states_ptr,
1283
+ cu_seqlens_ptr,
1284
+ states_ptr,
1285
+ # Matrix dimensions
1286
+ hdim,
1287
+ dstate,
1288
+ chunk_size,
1289
+ seqlen,
1290
+ nheads_ngroups_ratio,
1291
+ # Strides
1292
+ stride_x_seqlen,
1293
+ stride_x_head,
1294
+ stride_x_hdim,
1295
+ stride_b_seqlen,
1296
+ stride_b_head,
1297
+ stride_b_dstate,
1298
+ stride_dt_chunk,
1299
+ stride_dt_head,
1300
+ stride_dt_csize,
1301
+ stride_dA_cs_chunk,
1302
+ stride_dA_cs_head,
1303
+ stride_dA_cs_csize,
1304
+ stride_chunk_states_chunk,
1305
+ stride_chunk_states_head,
1306
+ stride_chunk_states_hdim,
1307
+ stride_chunk_states_dstate,
1308
+ stride_states_batch,
1309
+ stride_states_head,
1310
+ stride_states_hdim,
1311
+ stride_states_dstate,
1312
+ # Meta-parameters
1313
+ BLOCK_SIZE_M: tl.constexpr,
1314
+ BLOCK_SIZE_N: tl.constexpr,
1315
+ BLOCK_SIZE_K: tl.constexpr,
1316
+ ):
1317
+ pid_b = tl.program_id(axis=1)
1318
+ pid_h = tl.program_id(axis=2)
1319
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
1320
+ pid_m = tl.program_id(axis=0) // num_pid_n
1321
+ pid_n = tl.program_id(axis=0) % num_pid_n
1322
+ end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
1323
+ pid_c = (end_idx - 1) // chunk_size
1324
+ b_ptr += (
1325
+ pid_c * chunk_size * stride_b_seqlen
1326
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
1327
+ )
1328
+ x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
1329
+ dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
1330
+ dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
1331
+ chunk_states_ptr += (
1332
+ pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
1333
+ )
1334
+
1335
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1336
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1337
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
1338
+ x_ptrs = x_ptr + (
1339
+ offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
1340
+ )
1341
+ b_ptrs = b_ptr + (
1342
+ offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
1343
+ )
1344
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
1345
+ dA_cs_last = tl.load(
1346
+ dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
1347
+ ).to(tl.float32)
1348
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
1349
+
1350
+ chunk_size_limit = end_idx - pid_c * chunk_size
1351
+ start_idx = tl.load(cu_seqlens_ptr + pid_b)
1352
+ start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
1353
+
1354
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1355
+ for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
1356
+ x = tl.load(
1357
+ x_ptrs,
1358
+ mask=(offs_m[:, None] < hdim)
1359
+ & (offs_k[None, :] < chunk_size_limit - k)
1360
+ & (offs_k[None, :] >= start_idx_cur - k),
1361
+ other=0.0,
1362
+ )
1363
+ b = tl.load(
1364
+ b_ptrs,
1365
+ mask=(offs_k[:, None] < chunk_size_limit - k)
1366
+ & (offs_n[None, :] < dstate)
1367
+ & (offs_k[:, None] >= start_idx_cur - k),
1368
+ other=0.0,
1369
+ ).to(tl.float32)
1370
+ dA_cs_k = tl.load(
1371
+ dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
1372
+ ).to(tl.float32)
1373
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
1374
+ tl.float32
1375
+ )
1376
+ scale = tl.where(
1377
+ (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
1378
+ tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
1379
+ 0.0,
1380
+ )
1381
+ b *= scale[:, None]
1382
+ b = b.to(x_ptr.dtype.element_ty)
1383
+ acc += tl.dot(x, b)
1384
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
1385
+ b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
1386
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
1387
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
1388
+
1389
+ # Only add the state carried over from previous chunks if this sequence started before the current (last) chunk; if it starts inside this chunk there is no earlier contribution to pick up
1390
+ if start_idx < pid_c * chunk_size:
1391
+ chunk_states_ptrs = chunk_states_ptr + (
1392
+ offs_m[:, None] * stride_chunk_states_hdim
1393
+ + offs_n[None, :] * stride_chunk_states_dstate
1394
+ )
1395
+ chunk_states = tl.load(
1396
+ chunk_states_ptrs,
1397
+ mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
1398
+ other=0.0,
1399
+ ).to(tl.float32)
1400
+ # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
1401
+ scale = tl.exp(dA_cs_last)
1402
+ acc += chunk_states * scale
1403
+
1404
+ states = acc.to(states_ptr.dtype.element_ty)
1405
+
1406
+ states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
1407
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1408
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1409
+ states_ptrs = states_ptr + (
1410
+ offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
1411
+ )
1412
+ c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
1413
+ tl.store(states_ptrs, states, mask=c_mask)
1414
+
1415
+
1416
+ def _chunk_cumsum_fwd(
1417
+ dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
1418
+ ):
1419
+ batch, seqlen, nheads = dt.shape
1420
+ assert A.shape == (nheads,)
1421
+ if dt_bias is not None:
1422
+ assert dt_bias.shape == (nheads,)
1423
+ nchunks = math.ceil(seqlen / chunk_size)
1424
+ dt_out = torch.empty(
1425
+ batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1426
+ )
1427
+ dA_cumsum = torch.empty(
1428
+ batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1429
+ )
1430
+ grid_chunk_cs = lambda META: (
1431
+ batch,
1432
+ nchunks,
1433
+ triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1434
+ )
1435
+ with torch.cuda.device(dt.device.index):
1436
+ _chunk_cumsum_fwd_kernel[grid_chunk_cs](
1437
+ dt,
1438
+ A,
1439
+ dt_bias,
1440
+ dt_out,
1441
+ dA_cumsum,
1442
+ batch,
1443
+ seqlen,
1444
+ nheads,
1445
+ chunk_size,
1446
+ dt_limit[0],
1447
+ dt_limit[1],
1448
+ dt.stride(0),
1449
+ dt.stride(1),
1450
+ dt.stride(2),
1451
+ A.stride(0),
1452
+ dt_bias.stride(0) if dt_bias is not None else 0,
1453
+ dt_out.stride(0),
1454
+ dt_out.stride(2),
1455
+ dt_out.stride(1),
1456
+ dt_out.stride(3),
1457
+ dA_cumsum.stride(0),
1458
+ dA_cumsum.stride(2),
1459
+ dA_cumsum.stride(1),
1460
+ dA_cumsum.stride(3),
1461
+ dt_softplus,
1462
+ HAS_DT_BIAS=dt_bias is not None,
1463
+ BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1464
+ )
1465
+ return dA_cumsum, dt_out
1466
+
1467
+
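# A minimal sketch (illustrative only, not part of the upstream file) of the
# launch-grid convention used by _chunk_cumsum_fwd above: the autotuner calls the
# grid lambda with the chosen config's meta-parameters, so the third grid axis
# shrinks as BLOCK_SIZE_H grows. The output shape mirrors the torch.empty
# allocations in the wrapper; the BLOCK_SIZE_H values below are hypothetical.
import math

def _sketch_chunk_cumsum_launch(batch=2, seqlen=100, nheads=8, chunk_size=32):
    nchunks = math.ceil(seqlen / chunk_size)              # 4 chunks for seqlen=100
    out_shape = (batch, nheads, nchunks, chunk_size)      # dt_out and dA_cumsum shape
    grid = lambda META: (batch, nchunks, math.ceil(nheads / META["BLOCK_SIZE_H"]))
    # Triton evaluates the grid callable per autotune config, e.g.:
    assert grid({"BLOCK_SIZE_H": 2}) == (2, 4, 4)
    assert grid({"BLOCK_SIZE_H": 8}) == (2, 4, 1)
    return out_shape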
1468
+ def _chunk_cumsum_bwd(
1469
+ ddA,
1470
+ ddt_out,
1471
+ dt,
1472
+ A,
1473
+ dt_bias=None,
1474
+ dt_softplus=False,
1475
+ dt_limit=(0.0, float("inf")),
1476
+ ddt=None,
1477
+ ):
1478
+ batch, seqlen, nheads = dt.shape
1479
+ _, _, nchunks, chunk_size = ddA.shape
1480
+ assert ddA.shape == (batch, nheads, nchunks, chunk_size)
1481
+ assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
1482
+ assert A.shape == (nheads,)
1483
+ if dt_bias is not None:
1484
+ assert dt_bias.shape == (nheads,)
1485
+ ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
1486
+ else:
1487
+ ddt_bias = None
1488
+ if ddt is not None:
1489
+ assert ddt.shape == dt.shape
1490
+ else:
1491
+ ddt = torch.empty_like(dt)
1492
+ dA = torch.empty_like(A, dtype=torch.float32)
1493
+ grid_chunk_cs = lambda META: (
1494
+ batch,
1495
+ nchunks,
1496
+ triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1497
+ )
1498
+ with torch.cuda.device(dt.device.index):
1499
+ _chunk_cumsum_bwd_kernel[grid_chunk_cs](
1500
+ ddA,
1501
+ ddt_out,
1502
+ dt,
1503
+ A,
1504
+ dt_bias,
1505
+ ddt,
1506
+ dA,
1507
+ ddt_bias,
1508
+ batch,
1509
+ seqlen,
1510
+ nheads,
1511
+ chunk_size,
1512
+ dt_limit[0],
1513
+ dt_limit[1],
1514
+ ddA.stride(0),
1515
+ ddA.stride(2),
1516
+ ddA.stride(1),
1517
+ ddA.stride(3),
1518
+ ddt_out.stride(0),
1519
+ ddt_out.stride(2),
1520
+ ddt_out.stride(1),
1521
+ ddt_out.stride(3),
1522
+ dt.stride(0),
1523
+ dt.stride(1),
1524
+ dt.stride(2),
1525
+ A.stride(0),
1526
+ dt_bias.stride(0) if dt_bias is not None else 0,
1527
+ ddt.stride(0),
1528
+ ddt.stride(1),
1529
+ ddt.stride(2),
1530
+ dA.stride(0),
1531
+ ddt_bias.stride(0) if ddt_bias is not None else 0,
1532
+ dt_softplus,
1533
+ HAS_DT_BIAS=dt_bias is not None,
1534
+ BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1535
+ )
1536
+ return ddt, dA, ddt_bias
1537
+
1538
+
1539
+ def _chunk_state_fwd(
1540
+ B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
1541
+ ):
1542
+ batch, seqlen, nheads, headdim = x.shape
1543
+ _, _, nchunks, chunk_size = dt.shape
1544
+ _, _, ngroups, dstate = B.shape
1545
+ assert nheads % ngroups == 0
1546
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1547
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1548
+ assert dA_cumsum.shape == dt.shape
1549
+ if seq_idx is not None:
1550
+ assert seq_idx.shape == (batch, seqlen)
1551
+ if states is not None:
1552
+ assert states.shape == (batch, nchunks, nheads, headdim, dstate)
1553
+ else:
1554
+ states_dtype = torch.float32 if states_in_fp32 else B.dtype
1555
+ states = torch.empty(
1556
+ (batch, nchunks, nheads, headdim, dstate),
1557
+ device=x.device,
1558
+ dtype=states_dtype,
1559
+ )
1560
+ grid = lambda META: (
1561
+ triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1562
+ * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1563
+ batch * nchunks,
1564
+ nheads,
1565
+ )
1566
+ with torch.cuda.device(x.device.index):
1567
+ _chunk_state_fwd_kernel[grid](
1568
+ x,
1569
+ B,
1570
+ states,
1571
+ dt,
1572
+ dA_cumsum,
1573
+ seq_idx,
1574
+ headdim,
1575
+ dstate,
1576
+ chunk_size,
1577
+ batch,
1578
+ seqlen,
1579
+ nheads // ngroups,
1580
+ x.stride(0),
1581
+ x.stride(1),
1582
+ x.stride(2),
1583
+ x.stride(3),
1584
+ B.stride(0),
1585
+ B.stride(1),
1586
+ B.stride(2),
1587
+ B.stride(-1),
1588
+ states.stride(0),
1589
+ states.stride(1),
1590
+ states.stride(2),
1591
+ states.stride(3),
1592
+ states.stride(4),
1593
+ dt.stride(0),
1594
+ dt.stride(2),
1595
+ dt.stride(1),
1596
+ dt.stride(3),
1597
+ dA_cumsum.stride(0),
1598
+ dA_cumsum.stride(2),
1599
+ dA_cumsum.stride(1),
1600
+ dA_cumsum.stride(3),
1601
+ *(
1602
+ (seq_idx.stride(0), seq_idx.stride(1))
1603
+ if seq_idx is not None
1604
+ else (0, 0)
1605
+ ),
1606
+ HAS_SEQ_IDX=seq_idx is not None,
1607
+ )
1608
+ return states
1609
+
1610
+
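# A minimal sketch (illustrative only, not part of the upstream file) of how the
# wrappers above flatten the 2D output tiling into grid axis 0 and how the kernels
# recover (pid_m, pid_n) from tl.program_id(axis=0) with // and %. Depending on the
# kernel, the two tiled dimensions are hdim/dstate or chunk_size/hdim; the block
# sizes below are hypothetical stand-ins for the autotuned BLOCK_SIZE_M/N.
def _sketch_tile_unflatten(dim_m=80, dim_n=96, block_m=64, block_n=32):
    cdiv = lambda a, b: (a + b - 1) // b
    num_pid_m = cdiv(dim_m, block_m)      # tiles along the first output axis
    num_pid_n = cdiv(dim_n, block_n)      # tiles along the second output axis
    # Grid axis 0 enumerates all (m, n) tiles, exactly like
    # triton.cdiv(dim_m, BLOCK_SIZE_M) * triton.cdiv(dim_n, BLOCK_SIZE_N).
    for pid in range(num_pid_m * num_pid_n):
        pid_m, pid_n = pid // num_pid_n, pid % num_pid_n
        assert 0 <= pid_m < num_pid_m and 0 <= pid_n < num_pid_n
        assert pid == pid_m * num_pid_n + pid_n   # the round-trip is lossless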
1611
+ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
1612
+ batch, seqlen, nheads, headdim = x.shape
1613
+ _, _, nchunks, chunk_size = dt.shape
1614
+ _, _, ngroups, dstate = B.shape
1615
+ assert nheads % ngroups == 0
1616
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1617
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1618
+ assert dA_cumsum.shape == dt.shape
1619
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1620
+ if dx is not None:
1621
+ assert dx.shape == x.shape
1622
+ else:
1623
+ dx = torch.empty_like(x)
1624
+ ddt = torch.empty(
1625
+ batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1626
+ )
1627
+ ddA_cumsum = torch.empty(
1628
+ batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
1629
+ )
1630
+ grid_dx = lambda META: (
1631
+ triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1632
+ * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1633
+ batch * nchunks,
1634
+ nheads,
1635
+ )
1636
+ with torch.cuda.device(x.device.index):
1637
+ _chunk_state_bwd_dx_kernel[grid_dx](
1638
+ x,
1639
+ B,
1640
+ dstates,
1641
+ dt,
1642
+ dA_cumsum,
1643
+ dx,
1644
+ ddt,
1645
+ ddA_cumsum,
1646
+ chunk_size,
1647
+ headdim,
1648
+ dstate,
1649
+ batch,
1650
+ seqlen,
1651
+ nheads // ngroups,
1652
+ x.stride(0),
1653
+ x.stride(1),
1654
+ x.stride(2),
1655
+ x.stride(3),
1656
+ B.stride(0),
1657
+ B.stride(1),
1658
+ B.stride(2),
1659
+ B.stride(-1),
1660
+ dstates.stride(0),
1661
+ dstates.stride(1),
1662
+ dstates.stride(2),
1663
+ dstates.stride(3),
1664
+ dstates.stride(4),
1665
+ dt.stride(0),
1666
+ dt.stride(2),
1667
+ dt.stride(1),
1668
+ dt.stride(3),
1669
+ dA_cumsum.stride(0),
1670
+ dA_cumsum.stride(2),
1671
+ dA_cumsum.stride(1),
1672
+ dA_cumsum.stride(3),
1673
+ dx.stride(0),
1674
+ dx.stride(1),
1675
+ dx.stride(2),
1676
+ dx.stride(3),
1677
+ ddt.stride(0),
1678
+ ddt.stride(2),
1679
+ ddt.stride(1),
1680
+ ddt.stride(3),
1681
+ ddA_cumsum.stride(0),
1682
+ ddA_cumsum.stride(2),
1683
+ ddA_cumsum.stride(1),
1684
+ ddA_cumsum.stride(3),
1685
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1686
+ )
1687
+ return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
1688
+
1689
+
1690
+ def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
1691
+ batch, seqlen, nheads, headdim = x.shape
1692
+ _, _, nchunks, chunk_size = dt.shape
1693
+ dstate = dstates.shape[-1]
1694
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1695
+ assert dA_cumsum.shape == dt.shape
1696
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1697
+ if seq_idx is not None:
1698
+ assert seq_idx.shape == (batch, seqlen)
1699
+ if B is not None:
1700
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1701
+ B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
1702
+ # Use torch.empty since the Triton kernel will call init_to_zero
1703
+ ddA_cumsum = torch.empty(
1704
+ batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1705
+ )
1706
+ ddA_cumsum_strides = (
1707
+ ddA_cumsum.stride(0),
1708
+ ddA_cumsum.stride(2),
1709
+ ddA_cumsum.stride(1),
1710
+ ddA_cumsum.stride(3),
1711
+ )
1712
+ else:
1713
+ B_strides = (0, 0, 0, 0)
1714
+ ddA_cumsum = None
1715
+ ddA_cumsum_strides = (0, 0, 0, 0)
1716
+ nheads_ngroups_ratio = nheads // ngroups
1717
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
1718
+ nheads_per_program = max(
1719
+ min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
1720
+ )
1721
+ nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
1722
+ dB = torch.empty(
1723
+ batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
1724
+ )
1725
+ grid_db = lambda META: (
1726
+ triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1727
+ * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1728
+ batch * nchunks,
1729
+ nsplits * ngroups,
1730
+ )
1731
+ with torch.cuda.device(x.device.index):
1732
+ _chunk_state_bwd_db_kernel[grid_db](
1733
+ x,
1734
+ dstates,
1735
+ B,
1736
+ dt,
1737
+ dA_cumsum,
1738
+ seq_idx,
1739
+ dB,
1740
+ ddA_cumsum,
1741
+ chunk_size,
1742
+ dstate,
1743
+ headdim,
1744
+ batch,
1745
+ seqlen,
1746
+ nheads,
1747
+ nheads_per_program,
1748
+ ngroups,
1749
+ x.stride(0),
1750
+ x.stride(1),
1751
+ x.stride(2),
1752
+ x.stride(3),
1753
+ dstates.stride(0),
1754
+ dstates.stride(1),
1755
+ dstates.stride(2),
1756
+ dstates.stride(3),
1757
+ dstates.stride(4),
1758
+ *B_strides,
1759
+ dt.stride(0),
1760
+ dt.stride(2),
1761
+ dt.stride(1),
1762
+ dt.stride(3),
1763
+ dA_cumsum.stride(0),
1764
+ dA_cumsum.stride(2),
1765
+ dA_cumsum.stride(1),
1766
+ dA_cumsum.stride(3),
1767
+ *(
1768
+ (seq_idx.stride(0), seq_idx.stride(1))
1769
+ if seq_idx is not None
1770
+ else (0, 0)
1771
+ ),
1772
+ dB.stride(0),
1773
+ dB.stride(1),
1774
+ dB.stride(2),
1775
+ dB.stride(3),
1776
+ dB.stride(4),
1777
+ *ddA_cumsum_strides,
1778
+ HAS_DDA_CS=ddA_cumsum is not None,
1779
+ HAS_SEQ_IDX=seq_idx is not None,
1780
+ BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
1781
+ )
1782
+ dB = dB.sum(2)
1783
+ if ddA_cumsum is not None:
1784
+ # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
1785
+ # to the state of the chunk.
1786
+ # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1787
+ # But it's easier to just do the cumsum for all elements, the result will be the same.
1788
+ torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
1789
+ return dB if B is None else (dB, ddA_cumsum)
1790
+
1791
+
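# A minimal sketch (illustrative only, not part of the upstream file) of the comment
# in _chunk_state_bwd_db above: because position 0 of ddA_cumsum is left at zero (it
# never contributes to the chunk state), running torch.cumsum over the full last
# dimension gives the same result as restricting the cumsum to positions 1..N-1.
import torch

def _sketch_cumsum_with_leading_zero():
    v = torch.tensor([0.0, 0.5, -1.0, 2.0])            # position 0 is zero
    full = torch.cumsum(v, dim=-1)                      # cumsum over everything
    restricted = v.clone()
    torch.cumsum(restricted[1:], dim=-1, out=restricted[1:])
    assert torch.allclose(full, restricted)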
1792
+ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
1793
+ batch, seqlen, nheads, headdim = x.shape
1794
+ _, _, nchunks, chunk_size = dt.shape
1795
+ _, _, ngroups, dstate = B.shape
1796
+ assert nheads % ngroups == 0
1797
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1798
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1799
+ assert dA_cumsum.shape == dt.shape
1800
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1801
+ if seq_idx is not None:
1802
+ assert seq_idx.shape == (batch, seqlen)
1803
+ # Use torch.empty since the Triton kernel will call init_to_zero
1804
+ ddA_cumsum = torch.empty(
1805
+ batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1806
+ )
1807
+ grid_ddtcs = lambda META: (
1808
+ triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1809
+ * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1810
+ batch * nchunks,
1811
+ nheads,
1812
+ )
1813
+ with torch.cuda.device(x.device.index):
1814
+ _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
1815
+ x,
1816
+ B,
1817
+ dstates,
1818
+ dt,
1819
+ dA_cumsum,
1820
+ seq_idx,
1821
+ ddA_cumsum,
1822
+ chunk_size,
1823
+ headdim,
1824
+ dstate,
1825
+ batch,
1826
+ seqlen,
1827
+ nheads // ngroups,
1828
+ x.stride(0),
1829
+ x.stride(1),
1830
+ x.stride(2),
1831
+ x.stride(3),
1832
+ B.stride(0),
1833
+ B.stride(1),
1834
+ B.stride(2),
1835
+ B.stride(-1),
1836
+ dstates.stride(0),
1837
+ dstates.stride(1),
1838
+ dstates.stride(2),
1839
+ dstates.stride(3),
1840
+ dstates.stride(4),
1841
+ dt.stride(0),
1842
+ dt.stride(2),
1843
+ dt.stride(1),
1844
+ dt.stride(3),
1845
+ dA_cumsum.stride(0),
1846
+ dA_cumsum.stride(2),
1847
+ dA_cumsum.stride(1),
1848
+ dA_cumsum.stride(3),
1849
+ *(
1850
+ (seq_idx.stride(0), seq_idx.stride(1))
1851
+ if seq_idx is not None
1852
+ else (0, 0)
1853
+ ),
1854
+ ddA_cumsum.stride(0),
1855
+ ddA_cumsum.stride(2),
1856
+ ddA_cumsum.stride(1),
1857
+ ddA_cumsum.stride(3),
1858
+ HAS_SEQ_IDX=seq_idx is not None,
1859
+ BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
1860
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1861
+ )
1862
+ torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1863
+ return ddA_cumsum
1864
+
1865
+
1866
+ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
1867
+ total_seqlen, nheads, headdim = x.shape
1868
+ _, nchunks, chunk_size = dt.shape
1869
+ _, ngroups, dstate = B.shape
1870
+ batch = cu_seqlens.shape[0] - 1
1871
+ cu_seqlens = cu_seqlens.contiguous()
1872
+ assert nheads % ngroups == 0
1873
+ assert B.shape == (total_seqlen, ngroups, dstate)
1874
+ assert dt.shape == (nheads, nchunks, chunk_size)
1875
+ assert dA_cumsum.shape == dt.shape
1876
+ assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
1877
+ states = torch.empty(
1878
+ batch,
1879
+ nheads,
1880
+ headdim,
1881
+ dstate,
1882
+ dtype=chunk_states.dtype,
1883
+ device=chunk_states.device,
1884
+ )
1885
+ grid = lambda META: (
1886
+ triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1887
+ * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1888
+ batch,
1889
+ nheads,
1890
+ )
1891
+ with torch.cuda.device(x.device.index):
1892
+ _chunk_state_varlen_kernel[grid](
1893
+ x,
1894
+ B,
1895
+ dt,
1896
+ dA_cumsum,
1897
+ chunk_states,
1898
+ cu_seqlens,
1899
+ states,
1900
+ headdim,
1901
+ dstate,
1902
+ chunk_size,
1903
+ total_seqlen,
1904
+ nheads // ngroups,
1905
+ x.stride(0),
1906
+ x.stride(1),
1907
+ x.stride(2),
1908
+ B.stride(0),
1909
+ B.stride(1),
1910
+ B.stride(2),
1911
+ dt.stride(1),
1912
+ dt.stride(0),
1913
+ dt.stride(2),
1914
+ dA_cumsum.stride(1),
1915
+ dA_cumsum.stride(0),
1916
+ dA_cumsum.stride(2),
1917
+ chunk_states.stride(0),
1918
+ chunk_states.stride(1),
1919
+ chunk_states.stride(2),
1920
+ chunk_states.stride(3),
1921
+ states.stride(0),
1922
+ states.stride(1),
1923
+ states.stride(2),
1924
+ states.stride(3),
1925
+ )
1926
+ return states
1927
+
1928
+
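# A minimal sketch (illustrative only, not part of the upstream file) of the
# cu_seqlens convention consumed by chunk_state_varlen above: cu_seqlens holds
# cumulative sequence lengths of a packed batch, so sequence b occupies rows
# [cu_seqlens[b], cu_seqlens[b+1]) of the flattened (total_seqlen, ...) tensors,
# and the kernel finds its final chunk as (end_idx - 1) // chunk_size.
import torch

def _sketch_cu_seqlens(chunk_size=32):
    seqlens = [10, 50, 3]                                # three packed sequences
    cu_seqlens = torch.tensor([0, 10, 60, 63])           # cumulative boundaries
    assert cu_seqlens.shape[0] - 1 == len(seqlens)       # batch, as in the wrapper
    for b, length in enumerate(seqlens):
        start_idx = int(cu_seqlens[b])
        end_idx = int(cu_seqlens[b + 1])
        assert end_idx - start_idx == length
    last_chunks = [(int(cu_seqlens[b + 1]) - 1) // chunk_size for b in range(len(seqlens))]
    assert last_chunks == [0, 1, 1]                      # chunk holding each last token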
1929
+ class ChunkStateFn(torch.autograd.Function):
1930
+
1931
+ @staticmethod
1932
+ def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
1933
+ batch, seqlen, nheads, headdim = x.shape
1934
+ _, _, nchunks, chunk_size = dt.shape
1935
+ assert seqlen <= nchunks * chunk_size
1936
+ _, _, ngroups, dstate = B.shape
1937
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1938
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1939
+ assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
1940
+ if B.stride(-1) != 1:
1941
+ B = B.contiguous()
1942
+ if (
1943
+ x.stride(-1) != 1 and x.stride(1) != 1
1944
+ ): # Either M or K dimension should be contiguous
1945
+ x = x.contiguous()
1946
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
1947
+ ctx.save_for_backward(B, x, dt, dA_cumsum)
1948
+ return states
1949
+
1950
+ @staticmethod
1951
+ def backward(ctx, dstates):
1952
+ B, x, dt, dA_cumsum = ctx.saved_tensors
1953
+ batch, seqlen, nheads, headdim = x.shape
1954
+ _, _, nchunks, chunk_size = dt.shape
1955
+ _, _, ngroups, dstate = B.shape
1956
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1957
+ if dstates.stride(-1) != 1:
1958
+ dstates = dstates.contiguous()
1959
+ dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
1960
+ dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
1961
+ dB = dB.to(B.dtype)
1962
+ return dB, dx, ddt, ddA_cumsum, None
1963
+
1964
+
1965
+ def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
1966
+ """
1967
+ Argument:
1968
+ B: (batch, seqlen, ngroups, dstate)
1969
+ x: (batch, seqlen, nheads, headdim)
1970
+ dt: (batch, nheads, nchunks, chunk_size)
1971
+ dA_cumsum: (batch, nheads, nchunks, chunk_size)
1972
+ Return:
1973
+ states: (batch, nchunks, nheads, headdim, dstate)
1974
+ """
1975
+ return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
1976
+
1977
+
1978
+ def chunk_state_ref(B, x, dt, dA_cumsum):
1979
+ """
1980
+ Argument:
1981
+ B: (batch, seqlen, ngroups, dstate)
1982
+ x: (batch, seqlen, nheads, headdim)
1983
+ dt: (batch, nheads, nchunks, chunk_size)
1984
+ dA_cumsum: (batch, nheads, nchunks, chunk_size)
1985
+ Return:
1986
+ states: (batch, nchunks, nheads, headdim, dstate)
1987
+ """
1988
+ # Check constraints.
1989
+ batch, seqlen, nheads, headdim = x.shape
1990
+ dstate = B.shape[-1]
1991
+ _, _, nchunks, chunk_size = dt.shape
1992
+ assert seqlen <= nchunks * chunk_size
1993
+ assert x.shape == (batch, seqlen, nheads, headdim)
1994
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1995
+ ngroups = B.shape[2]
1996
+ assert nheads % ngroups == 0
1997
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1998
+ B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
1999
+ assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
2000
+ if seqlen < nchunks * chunk_size:
2001
+ x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2002
+ B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2003
+ x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
2004
+ B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
2005
+ decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
2006
+ return torch.einsum(
2007
+ "bclhn,bhcl,bhcl,bclhp->bchpn",
2008
+ B.to(x.dtype),
2009
+ decay_states.to(x.dtype),
2010
+ dt.to(x.dtype),
2011
+ x,
2012
+ )
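# A minimal usage sketch (illustrative only, not part of the upstream file) for the
# pure-PyTorch reference above. The sizes are arbitrary; on CUDA the Triton path
# chunk_state(...) is expected to produce a matching
# (batch, nchunks, nheads, headdim, dstate) tensor up to numerical precision.
import torch

def _sketch_chunk_state_ref_shapes():
    batch, seqlen, nheads, headdim = 2, 100, 8, 16
    ngroups, dstate, chunk_size = 2, 32, 64
    nchunks = -(-seqlen // chunk_size)                   # ceil division -> 2
    x = torch.randn(batch, seqlen, nheads, headdim)
    B = torch.randn(batch, seqlen, ngroups, dstate)
    dt = torch.rand(batch, nheads, nchunks, chunk_size)
    # Any values of the right shape suffice for a shape check; real dA_cumsum is a
    # per-chunk cumulative sum, so a monotone tensor is used here for realism.
    dA_cumsum = torch.cumsum(-torch.rand(batch, nheads, nchunks, chunk_size), dim=-1)
    states = chunk_state_ref(B, x, dt, dA_cumsum)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)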
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py ADDED
@@ -0,0 +1,1884 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """This file expects triton==2.1.0 or 2.2.0.
4
+ """
5
+
6
+ from typing import Optional
7
+
8
+ import math
9
+ from packaging import version
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from ...utils.torch import custom_bwd, custom_fwd
15
+
16
+ import triton
17
+ import triton.language as tl
18
+
19
+ from einops import rearrange, repeat
20
+
21
+ try:
22
+ from causal_conv1d import causal_conv1d_fn
23
+ import causal_conv1d_cuda
24
+ except ImportError:
25
+ causal_conv1d_fn, causal_conv1d_cuda = None, None
26
+
27
+ from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
28
+ from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
29
+ from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
30
+ from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
31
+ from .ssd_chunk_state import chunk_state, chunk_state_ref
32
+ from .ssd_chunk_state import chunk_state_varlen
33
+ from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
34
+ from .ssd_state_passing import state_passing, state_passing_ref
35
+ from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
36
+ from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
37
+ from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
38
+ from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
39
+ from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
40
+ from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
41
+ from .k_activations import _swiglu_fwd, _swiglu_bwd
42
+
43
+ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
44
+
45
+
46
+ def init_to_zero(names):
47
+ return lambda nargs: [
48
+ nargs[name].zero_() for name in names if nargs[name] is not None
49
+ ]
50
+
51
+
52
+ @triton.autotune(
53
+ configs=[
54
+ triton.Config(
55
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
56
+ num_stages=3,
57
+ num_warps=8,
58
+ pre_hook=init_to_zero(["ddt_ptr"]),
59
+ ),
60
+ triton.Config(
61
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
62
+ num_stages=4,
63
+ num_warps=4,
64
+ pre_hook=init_to_zero(["ddt_ptr"]),
65
+ ),
66
+ triton.Config(
67
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
68
+ num_stages=4,
69
+ num_warps=4,
70
+ pre_hook=init_to_zero(["ddt_ptr"]),
71
+ ),
72
+ triton.Config(
73
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
74
+ num_stages=4,
75
+ num_warps=4,
76
+ pre_hook=init_to_zero(["ddt_ptr"]),
77
+ ),
78
+ triton.Config(
79
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
80
+ num_stages=4,
81
+ num_warps=4,
82
+ pre_hook=init_to_zero(["ddt_ptr"]),
83
+ ),
84
+ triton.Config(
85
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
86
+ num_stages=4,
87
+ num_warps=4,
88
+ pre_hook=init_to_zero(["ddt_ptr"]),
89
+ ),
90
+ triton.Config(
91
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
92
+ num_stages=5,
93
+ num_warps=4,
94
+ pre_hook=init_to_zero(["ddt_ptr"]),
95
+ ),
96
+ triton.Config(
97
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
98
+ num_stages=5,
99
+ num_warps=4,
100
+ pre_hook=init_to_zero(["ddt_ptr"]),
101
+ ),
102
+ triton.Config(
103
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
104
+ num_stages=4,
105
+ num_warps=4,
106
+ pre_hook=init_to_zero(["ddt_ptr"]),
107
+ ),
108
+ ],
109
+ key=["chunk_size", "hdim", "dstate"],
110
+ )
111
+ @triton.jit
112
+ def _chunk_scan_chunk_state_bwd_dx_kernel(
113
+ # Pointers to matrices
114
+ x_ptr,
115
+ cb_ptr,
116
+ dout_ptr,
117
+ dt_ptr,
118
+ dA_cumsum_ptr,
119
+ seq_idx_ptr,
120
+ D_ptr,
121
+ b_ptr,
122
+ dstates_ptr,
123
+ dx_ptr,
124
+ ddt_ptr,
125
+ dD_ptr,
126
+ # Matrix dimensions
127
+ chunk_size,
128
+ hdim,
129
+ dstate,
130
+ batch,
131
+ seqlen,
132
+ nheads_ngroups_ratio,
133
+ # Strides
134
+ stride_x_batch,
135
+ stride_x_seqlen,
136
+ stride_x_head,
137
+ stride_x_hdim,
138
+ stride_cb_batch,
139
+ stride_cb_chunk,
140
+ stride_cb_head,
141
+ stride_cb_csize_m,
142
+ stride_cb_csize_k,
143
+ stride_dout_batch,
144
+ stride_dout_seqlen,
145
+ stride_dout_head,
146
+ stride_dout_hdim,
147
+ stride_dt_batch,
148
+ stride_dt_chunk,
149
+ stride_dt_head,
150
+ stride_dt_csize,
151
+ stride_dA_cs_batch,
152
+ stride_dA_cs_chunk,
153
+ stride_dA_cs_head,
154
+ stride_dA_cs_csize,
155
+ stride_seq_idx_batch,
156
+ stride_seq_idx_seqlen,
157
+ stride_D_head,
158
+ stride_b_batch,
159
+ stride_b_seqlen,
160
+ stride_b_head,
161
+ stride_b_dstate,
162
+ stride_dstates_batch,
163
+ stride_dstates_chunk,
164
+ stride_dstates_head,
165
+ stride_dstates_hdim,
166
+ stride_dstates_dstate,
167
+ stride_dx_batch,
168
+ stride_dx_seqlen,
169
+ stride_dx_head,
170
+ stride_dx_hdim,
171
+ stride_ddt_batch,
172
+ stride_ddt_chunk,
173
+ stride_ddt_head,
174
+ stride_ddt_csize,
175
+ stride_dD_batch,
176
+ stride_dD_chunk,
177
+ stride_dD_head,
178
+ stride_dD_csize,
179
+ stride_dD_hdim,
180
+ # Meta-parameters
181
+ HAS_D: tl.constexpr,
182
+ D_HAS_HDIM: tl.constexpr,
183
+ HAS_SEQ_IDX: tl.constexpr,
184
+ BLOCK_SIZE_M: tl.constexpr,
185
+ BLOCK_SIZE_N: tl.constexpr,
186
+ BLOCK_SIZE_K: tl.constexpr,
187
+ BLOCK_SIZE_DSTATE: tl.constexpr,
188
+ IS_TRITON_22: tl.constexpr,
189
+ ):
190
+ pid_bc = tl.program_id(axis=1)
191
+ pid_c = pid_bc // batch
192
+ pid_b = pid_bc - pid_c * batch
193
+ pid_h = tl.program_id(axis=2)
194
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
195
+ pid_m = tl.program_id(axis=0) // num_pid_n
196
+ pid_n = tl.program_id(axis=0) % num_pid_n
197
+ x_ptr += (
198
+ pid_b * stride_x_batch
199
+ + pid_c * chunk_size * stride_x_seqlen
200
+ + pid_h * stride_x_head
201
+ )
202
+ cb_ptr += (
203
+ pid_b * stride_cb_batch
204
+ + pid_c * stride_cb_chunk
205
+ + (pid_h // nheads_ngroups_ratio) * stride_cb_head
206
+ )
207
+ dout_ptr += (
208
+ pid_b * stride_dout_batch
209
+ + pid_c * chunk_size * stride_dout_seqlen
210
+ + pid_h * stride_dout_head
211
+ )
212
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
213
+ ddt_ptr += (
214
+ pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
215
+ )
216
+ dA_cumsum_ptr += (
217
+ pid_b * stride_dA_cs_batch
218
+ + pid_c * stride_dA_cs_chunk
219
+ + pid_h * stride_dA_cs_head
220
+ )
221
+ b_ptr += (
222
+ pid_b * stride_b_batch
223
+ + pid_c * chunk_size * stride_b_seqlen
224
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
225
+ )
226
+ dstates_ptr += (
227
+ pid_b * stride_dstates_batch
228
+ + pid_c * stride_dstates_chunk
229
+ + pid_h * stride_dstates_head
230
+ )
231
+ if HAS_SEQ_IDX:
232
+ seq_idx_ptr += (
233
+ pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
234
+ )
235
+
236
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
237
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
238
+
239
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
240
+
241
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
242
+
243
+ dA_cs_m = tl.load(
244
+ dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
245
+ mask=offs_m < chunk_size_limit,
246
+ other=0.0,
247
+ ).to(tl.float32)
248
+
249
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
250
+ tl.float32
251
+ )
252
+ if not HAS_SEQ_IDX:
253
+ scale = tl.exp(dA_cs_last - dA_cs_m)
254
+ else:
255
+ seq_idx_m = tl.load(
256
+ seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
257
+ mask=offs_m < chunk_size_limit,
258
+ other=-1,
259
+ )
260
+ seq_idx_last = tl.load(
261
+ seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
262
+ )
263
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
264
+ # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
265
+ # However, we're getting an error with the Triton compiler 2.1.0 for that code path:
266
+ # Unexpected mma -> mma layout conversion
267
+ # Triton 2.2.0 fixes this
268
+ offs_dstate = tl.arange(
269
+ 0,
270
+ (
271
+ BLOCK_SIZE_DSTATE
272
+ if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
273
+ else BLOCK_SIZE_K
274
+ ),
275
+ )
276
+ b_ptrs = b_ptr + (
277
+ offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
278
+ )
279
+ dstates_ptrs = dstates_ptr + (
280
+ offs_n[None, :] * stride_dstates_hdim
281
+ + offs_dstate[:, None] * stride_dstates_dstate
282
+ )
283
+ if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
284
+ b = tl.load(
285
+ b_ptrs,
286
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
287
+ other=0.0,
288
+ )
289
+ dstates = tl.load(
290
+ dstates_ptrs,
291
+ mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
292
+ other=0.0,
293
+ )
294
+ dstates = dstates.to(b_ptr.dtype.element_ty)
295
+ acc = tl.dot(b, dstates) * scale[:, None]
296
+ else:
297
+ for k in range(0, dstate, BLOCK_SIZE_K):
298
+ b = tl.load(
299
+ b_ptrs,
300
+ mask=(offs_m[:, None] < chunk_size_limit)
301
+ & (offs_dstate[None, :] < dstate - k),
302
+ other=0.0,
303
+ )
304
+ dstates = tl.load(
305
+ dstates_ptrs,
306
+ mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
307
+ other=0.0,
308
+ )
309
+ dstates = dstates.to(b_ptr.dtype.element_ty)
310
+ acc += tl.dot(b, dstates)
311
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
312
+ dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
313
+ acc *= scale[:, None]
314
+
315
+ # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
316
+ # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
317
+ # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
318
+ # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
319
+ # ddt = tl.sum(acc * x, axis=1) * dt_m
320
+ # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
321
+ # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
322
+
323
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
324
+ cb_ptrs = cb_ptr + (
325
+ offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
326
+ )
327
+ dout_ptrs = dout_ptr + (
328
+ offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
329
+ )
330
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
331
+ K_MAX = chunk_size_limit
332
+ K_MIN = pid_m * BLOCK_SIZE_M
333
+ cb_ptrs += K_MIN * stride_cb_csize_k
334
+ dout_ptrs += K_MIN * stride_dout_seqlen
335
+ dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
336
+ for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
337
+ k = tl.multiple_of(k, BLOCK_SIZE_K)
338
+ # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
339
+ cb = tl.load(
340
+ cb_ptrs,
341
+ mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
342
+ other=0.0,
343
+ )
344
+ dout = tl.load(
345
+ dout_ptrs,
346
+ mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
347
+ other=0.0,
348
+ )
349
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
350
+ tl.float32
351
+ )
352
+ cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
353
+ # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
354
+ # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
355
+ # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
356
+ # This will cause NaN in acc, and hence NaN in dx and ddt.
357
+ mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
358
+ cb = tl.where(mask, cb, 0.0)
359
+ cb = cb.to(dout_ptr.dtype.element_ty)
360
+ acc += tl.dot(cb, dout)
361
+ cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
362
+ dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
363
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
364
+
365
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
366
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
367
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
368
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
369
+ dx = acc * dt_m[:, None]
370
+ dx_ptr += (
371
+ pid_b * stride_dx_batch
372
+ + pid_c * chunk_size * stride_dx_seqlen
373
+ + pid_h * stride_dx_head
374
+ )
375
+ dx_ptrs = dx_ptr + (
376
+ offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
377
+ )
378
+ if HAS_D:
379
+ dout_res_ptrs = dout_ptr + (
380
+ offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
381
+ )
382
+ dout_res = tl.load(
383
+ dout_res_ptrs,
384
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
385
+ other=0.0,
386
+ ).to(tl.float32)
387
+ if D_HAS_HDIM:
388
+ D = tl.load(
389
+ D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
390
+ ).to(tl.float32)
391
+ else:
392
+ D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
393
+ dx += dout_res * D
394
+ tl.store(
395
+ dx_ptrs,
396
+ dx,
397
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
398
+ )
399
+
400
+ x_ptrs = x_ptr + (
401
+ offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
402
+ )
403
+ x = tl.load(
404
+ x_ptrs,
405
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
406
+ other=0.0,
407
+ ).to(tl.float32)
408
+ if HAS_D:
409
+ dD_ptr += (
410
+ pid_b * stride_dD_batch
411
+ + pid_c * stride_dD_chunk
412
+ + pid_h * stride_dD_head
413
+ + pid_m * stride_dD_csize
414
+ )
415
+ if D_HAS_HDIM:
416
+ dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
417
+ dD = tl.sum(dout_res * x, axis=0)
418
+ tl.store(dD_ptrs, dD, mask=offs_n < hdim)
419
+ else:
420
+ dD = tl.sum(dout_res * x)
421
+ tl.store(dD_ptr, dD)
422
+ ddt = tl.sum(acc * x, axis=1)
423
+ ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
424
+ tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
425
+
426
+
427
+ def _chunk_scan_chunk_state_bwd_dx(
428
+ x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
429
+ ):
430
+ batch, seqlen, nheads, headdim = x.shape
431
+ _, _, nchunks, chunk_size = dt.shape
432
+ _, _, ngroups, dstate = B.shape
433
+ assert nheads % ngroups == 0
434
+ assert B.shape == (batch, seqlen, ngroups, dstate)
435
+ assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
436
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
437
+ assert dA_cumsum.shape == dt.shape
438
+ assert dout.shape == x.shape
439
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
440
+ if seq_idx is not None:
441
+ assert seq_idx.shape == (batch, seqlen)
442
+ if D is not None:
443
+ assert D.shape == (nheads, headdim) or D.shape == (nheads,)
444
+ assert D.stride(-1) == 1
445
+ BLOCK_SIZE_min = 32
446
+ dD = torch.empty(
447
+ triton.cdiv(chunk_size, BLOCK_SIZE_min),
448
+ batch,
449
+ nchunks,
450
+ nheads,
451
+ headdim if D.dim() == 2 else 1,
452
+ device=D.device,
453
+ dtype=torch.float32,
454
+ )
455
+ else:
456
+ dD = None
457
+ dD_strides = (
458
+ (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
459
+ if D is not None
460
+ else (0, 0, 0, 0, 0)
461
+ )
462
+ if dx is None:
463
+ dx = torch.empty_like(x)
464
+ else:
465
+ assert dx.shape == x.shape
466
+ ddt = torch.empty(
467
+ batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
468
+ )
469
+ grid_dx = lambda META: (
470
+ triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
471
+ * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
472
+ batch * nchunks,
473
+ nheads,
474
+ )
475
+ with torch.cuda.device(x.device.index):
476
+ _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
477
+ x,
478
+ CB,
479
+ dout,
480
+ dt,
481
+ dA_cumsum,
482
+ seq_idx,
483
+ D,
484
+ B,
485
+ dstates,
486
+ dx,
487
+ ddt,
488
+ dD,
489
+ chunk_size,
490
+ headdim,
491
+ dstate,
492
+ batch,
493
+ seqlen,
494
+ nheads // ngroups,
495
+ x.stride(0),
496
+ x.stride(1),
497
+ x.stride(2),
498
+ x.stride(3),
499
+ CB.stride(0),
500
+ CB.stride(1),
501
+ CB.stride(2),
502
+ CB.stride(-1),
503
+ CB.stride(-2),
504
+ dout.stride(0),
505
+ dout.stride(1),
506
+ dout.stride(2),
507
+ dout.stride(3),
508
+ dt.stride(0),
509
+ dt.stride(2),
510
+ dt.stride(1),
511
+ dt.stride(3),
512
+ dA_cumsum.stride(0),
513
+ dA_cumsum.stride(2),
514
+ dA_cumsum.stride(1),
515
+ dA_cumsum.stride(3),
516
+ *(
517
+ (seq_idx.stride(0), seq_idx.stride(1))
518
+ if seq_idx is not None
519
+ else (0, 0)
520
+ ),
521
+ D.stride(0) if D is not None else 0,
522
+ B.stride(0),
523
+ B.stride(1),
524
+ B.stride(2),
525
+ B.stride(3),
526
+ dstates.stride(0),
527
+ dstates.stride(1),
528
+ dstates.stride(2),
529
+ dstates.stride(3),
530
+ dstates.stride(4),
531
+ dx.stride(0),
532
+ dx.stride(1),
533
+ dx.stride(2),
534
+ dx.stride(3),
535
+ ddt.stride(0),
536
+ ddt.stride(2),
537
+ ddt.stride(1),
538
+ ddt.stride(3),
539
+ dD_strides[1],
540
+ dD_strides[2],
541
+ dD_strides[3],
542
+ dD_strides[0],
543
+ dD_strides[4],
544
+ D is not None,
545
+ D.dim() == 2 if D is not None else True,
546
+ HAS_SEQ_IDX=seq_idx is not None,
547
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
548
+ IS_TRITON_22=TRITON_22
549
+ )
550
+ if D is not None:
551
+ BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
552
+ "BLOCK_SIZE_M"
553
+ ]
554
+ n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
555
+ dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
556
+ if D.dim() == 1:
557
+ dD = rearrange(dD, "h 1 -> h")
558
+ return dx, ddt.to(dtype=dt.dtype), dD
559
+
560
+
561
+ def _mamba_chunk_scan_combined_fwd(
562
+ x,
563
+ dt,
564
+ A,
565
+ B,
566
+ C,
567
+ chunk_size,
568
+ D=None,
569
+ z=None,
570
+ dt_bias=None,
571
+ initial_states=None,
572
+ seq_idx=None,
573
+ cu_seqlens=None,
574
+ dt_softplus=False,
575
+ dt_limit=(0.0, float("inf")),
576
+ ):
577
+ batch, seqlen, nheads, headdim = x.shape
578
+ _, _, ngroups, dstate = B.shape
579
+ assert nheads % ngroups == 0
580
+ assert B.shape == (batch, seqlen, ngroups, dstate)
581
+ assert x.shape == (batch, seqlen, nheads, headdim)
582
+ assert dt.shape == (batch, seqlen, nheads)
583
+ assert A.shape == (nheads,)
584
+ assert C.shape == B.shape
585
+ if z is not None:
586
+ assert z.shape == x.shape
587
+ if D is not None:
588
+ assert D.shape == (nheads, headdim) or D.shape == (nheads,)
589
+ if seq_idx is not None:
590
+ assert seq_idx.shape == (batch, seqlen)
591
+ if B.stride(-1) != 1:
592
+ B = B.contiguous()
593
+ if C.stride(-1) != 1:
594
+ C = C.contiguous()
595
+ if (
596
+ x.stride(-1) != 1 and x.stride(1) != 1
597
+ ): # Either M or K dimension should be contiguous
598
+ x = x.contiguous()
599
+ if (
600
+ z is not None and z.stride(-1) != 1 and z.stride(1) != 1
601
+ ): # Either M or K dimension should be contiguous
602
+ z = z.contiguous()
603
+ if D is not None and D.stride(-1) != 1:
604
+ D = D.contiguous()
605
+ if initial_states is not None:
606
+ assert initial_states.shape == (batch, nheads, headdim, dstate)
607
+ # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
608
+ # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
609
+ # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
610
+ # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
611
+ dA_cumsum, dt = _chunk_cumsum_fwd(
612
+ dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
613
+ )
614
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
615
+ # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
616
+ # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
617
+ # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
618
+ states, final_states = _state_passing_fwd(
619
+ rearrange(states, "... p n -> ... (p n)"),
620
+ dA_cumsum[:, :, :, -1],
621
+ initial_states=(
622
+ rearrange(initial_states, "... p n -> ... (p n)")
623
+ if initial_states is not None
624
+ else None
625
+ ),
626
+ seq_idx=seq_idx,
627
+ chunk_size=chunk_size,
628
+ out_dtype=C.dtype,
629
+ )
630
+ states, final_states = [
631
+ rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
632
+ ]
633
+ # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
634
+ # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
635
+ CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
636
+ out, out_x = _chunk_scan_fwd(
637
+ CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
638
+ )
639
+ if cu_seqlens is None:
640
+ return out, out_x, dt, dA_cumsum, states, final_states
641
+ else:
642
+ assert (
643
+ batch == 1
644
+ ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
645
+ varlen_states = chunk_state_varlen(
646
+ B.squeeze(0),
647
+ x.squeeze(0),
648
+ dt.squeeze(0),
649
+ dA_cumsum.squeeze(0),
650
+ cu_seqlens,
651
+ states.squeeze(0),
652
+ )
653
+ return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
654
+
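# Editorial sketch (not part of this diff): a slow, pure-PyTorch reference for the per-timestep
# SSM recurrence that the chunked forward above is equivalent to (up to numerics). It ignores
# z, dt_bias, dt_softplus, dt_limit, initial_states and seq_idx; the function name and shapes
# below are illustrative only.
import torch
from einops import repeat

def ssd_recurrence_reference(x, dt, A, B, C, D=None):
    # x: (b, l, h, p), dt: (b, l, h), A: (h,), B/C: (b, l, g, n), D: (h,) or None
    b, l, h, p = x.shape
    g, n = B.shape[2], B.shape[3]
    Bh = repeat(B, "b l g n -> b l (g r) n", r=h // g)  # broadcast each group to its heads
    Ch = repeat(C, "b l g n -> b l (g r) n", r=h // g)
    state = torch.zeros(b, h, p, n, dtype=torch.float32, device=x.device)
    ys = []
    for t in range(l):
        dA = torch.exp(dt[:, t].float() * A.float())                   # (b, h)
        dBx = torch.einsum("bh,bhp,bhn->bhpn",
                           dt[:, t].float(), x[:, t].float(), Bh[:, t].float())
        state = state * dA[:, :, None, None] + dBx                     # h_t = exp(dt*A) h_{t-1} + dt * x_t B_t
        ys.append(torch.einsum("bhpn,bhn->bhp", state, Ch[:, t].float()))
    out = torch.stack(ys, dim=1)                                       # (b, l, h, p)
    if D is not None:
        out = out + x.float() * D.float()[None, None, :, None]         # skip connection y += D * x
    return out.to(x.dtype)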
655
+
656
+ def _mamba_chunk_scan_combined_bwd(
657
+ dout,
658
+ x,
659
+ dt,
660
+ A,
661
+ B,
662
+ C,
663
+ out,
664
+ chunk_size,
665
+ D=None,
666
+ z=None,
667
+ dt_bias=None,
668
+ initial_states=None,
669
+ dfinal_states=None,
670
+ seq_idx=None,
671
+ dt_softplus=False,
672
+ dt_limit=(0.0, float("inf")),
673
+ dx=None,
674
+ ddt=None,
675
+ dB=None,
676
+ dC=None,
677
+ dz=None,
678
+ recompute_output=False,
679
+ ):
680
+ if dout.stride(-1) != 1:
681
+ dout = dout.contiguous()
682
+ batch, seqlen, nheads, headdim = x.shape
683
+ nchunks = math.ceil(seqlen / chunk_size)
684
+ _, _, ngroups, dstate = B.shape
685
+ assert dout.shape == (batch, seqlen, nheads, headdim)
686
+ assert dt.shape == (batch, seqlen, nheads)
687
+ assert A.shape == (nheads,)
688
+ assert nheads % ngroups == 0
689
+ assert B.shape == (batch, seqlen, ngroups, dstate)
690
+ assert C.shape == B.shape
691
+ assert out.shape == x.shape
692
+ if initial_states is not None:
693
+ assert initial_states.shape == (batch, nheads, headdim, dstate)
694
+ if seq_idx is not None:
695
+ assert seq_idx.shape == (batch, seqlen)
696
+ if dx is not None:
697
+ assert dx.shape == x.shape
698
+ if dB is not None:
699
+ assert dB.shape == B.shape
700
+ dB_given = dB
701
+ else:
702
+ dB_given = torch.empty_like(B)
703
+ if dC is not None:
704
+ assert dC.shape == C.shape
705
+ dC_given = dC
706
+ else:
707
+ dC_given = torch.empty_like(C)
708
+ if dz is not None:
709
+ assert z is not None
710
+ assert dz.shape == z.shape
711
+ if ddt is not None:
712
+ assert ddt.shape == dt.shape
713
+ ddt_given = ddt
714
+ else:
715
+ ddt_given = torch.empty_like(dt)
716
+ # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
717
+ # "[CUDA]: invalid device context" (e.g. during varlen test), and cloning makes it work. Idk why.
718
+ dt_in = dt.clone()
719
+ dA_cumsum, dt = _chunk_cumsum_fwd(
720
+ dt_in,
721
+ A,
722
+ chunk_size,
723
+ dt_bias=dt_bias,
724
+ dt_softplus=dt_softplus,
725
+ dt_limit=dt_limit,
726
+ )
727
+ CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
728
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
729
+ states, _ = _state_passing_fwd(
730
+ rearrange(states, "... p n -> ... (p n)"),
731
+ dA_cumsum[:, :, :, -1],
732
+ initial_states=(
733
+ rearrange(initial_states, "... p n -> ... (p n)")
734
+ if initial_states is not None
735
+ else None
736
+ ),
737
+ seq_idx=seq_idx,
738
+ chunk_size=chunk_size,
739
+ )
740
+ states = rearrange(states, "... (p n) -> ... p n", n=dstate)
741
+ if z is not None:
742
+ dz, dout, dD, *rest = _chunk_scan_bwd_dz(
743
+ x,
744
+ z,
745
+ out,
746
+ dout,
747
+ chunk_size=chunk_size,
748
+ has_ddAcs=False,
749
+ D=D,
750
+ dz=dz,
751
+ recompute_output=recompute_output,
752
+ )
753
+ outz = rest[0] if recompute_output else out
754
+ else:
755
+ dz = None
756
+ outz = out
757
+ dstates = _chunk_scan_bwd_dstates(
758
+ C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
759
+ )
760
+ # dstates has length nchunks, containing the gradient to initial states at index 0 and
761
+ # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
762
+ # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
763
+ # will be used in matmul in the next kernels.
764
+ dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
765
+ rearrange(states, "... p n -> ... (p n)"),
766
+ dA_cumsum[:, :, :, -1],
767
+ rearrange(dstates, "... p n -> ... (p n)"),
768
+ dfinal_states=(
769
+ rearrange(dfinal_states, "... p n -> ... (p n)")
770
+ if dfinal_states is not None
771
+ else None
772
+ ),
773
+ seq_idx=seq_idx,
774
+ has_initial_states=initial_states is not None,
775
+ dstates_dtype=x.dtype,
776
+ states_dtype=x.dtype,
777
+ chunk_size=chunk_size,
778
+ )
779
+ # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
780
+ # gradient to the final states at index (nchunks - 1)
781
+ # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
782
+ # The final states are not stored.
783
+ states = rearrange(states, "... (p n) -> ... p n", n=dstate)
784
+ dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
785
+ dinitial_states = (
786
+ rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
787
+ if dinitial_states is not None
788
+ else None
789
+ )
790
+ dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
791
+ x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
792
+ )
793
+ # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
794
+ dB, ddA_next = _chunk_state_bwd_db(
795
+ x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
796
+ )
797
+ # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
798
+ dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
799
+ states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
800
+ )
801
+ # Computing ddA with the dcb kernel is much slower, so we're not using it for now
802
+ dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
803
+ # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
804
+ dCB = dCB.to(CB.dtype)
805
+ _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
806
+ _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
807
+ # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
808
+ # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
809
+ if z is None:
810
+ dD = dD_from_x
811
+ # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
812
+ # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
813
+ # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
814
+ # be a lot of underflow.
815
+
816
+ # This is already done as part of bwd_dC kernel
817
+ # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
818
+ ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
819
+ ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
820
+ # This is already done as part of bwd_dB kernel
821
+ # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
822
+ # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
823
+ ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
824
+ ddA += ddA_next + ddA_prev
825
+
826
+ ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
827
+ ddA,
828
+ ddt,
829
+ dt_in,
830
+ A,
831
+ dt_bias=dt_bias,
832
+ dt_softplus=dt_softplus,
833
+ dt_limit=dt_limit,
834
+ ddt=ddt_given,
835
+ )
836
+
837
+ # These 2 lines are just for testing ddt and dA as computed by the old code
838
+ # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
839
+ # ddt_given.copy_(ddt)
840
+
841
+ return_vals = (
842
+ dx,
843
+ ddt_given,
844
+ dA,
845
+ dB_given,
846
+ dC_given,
847
+ dD,
848
+ dz,
849
+ ddt_bias,
850
+ dinitial_states,
851
+ )
852
+ return return_vals if not recompute_output else (*return_vals, outz)
853
+
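# Editorial sketch (not part of this diff): one way to sanity-check this hand-written backward
# is to backpropagate a random dout through the public mamba_chunk_scan_combined wrapper and
# through autograd on the slow ssd_recurrence_reference sketch shown earlier, then compare the
# gradients. Assumes a CUDA device with the Triton kernels importable; tolerances are indicative.
import torch

def check_combined_bwd(chunk_size=64, b=1, l=128, h=4, p=16, g=1, n=8, device="cuda"):
    torch.manual_seed(0)
    x = torch.randn(b, l, h, p, device=device, requires_grad=True)
    dt = torch.rand(b, l, h, device=device, requires_grad=True)
    A = (-torch.rand(h, device=device)).requires_grad_()
    B = torch.randn(b, l, g, n, device=device, requires_grad=True)
    C = torch.randn(b, l, g, n, device=device, requires_grad=True)
    dout = torch.randn(b, l, h, p, device=device)

    out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size)
    out.backward(dout)
    fused_grads = [t.grad.clone() for t in (x, dt, A, B, C)]

    for t in (x, dt, A, B, C):
        t.grad = None
    ssd_recurrence_reference(x, dt, A, B, C).backward(dout)
    for name, gf, t in zip(("x", "dt", "A", "B", "C"), fused_grads, (x, dt, A, B, C)):
        print(name, (gf - t.grad).abs().max().item())   # expect small values at fp32 tolerance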
854
+
855
+ def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
856
+ """
857
+ Argument:
858
+ dout: (batch, seqlen, nheads, headdim)
859
+ x: (batch, seqlen, nheads, headdim)
860
+ dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
861
+ A: (nheads) or (dim, dstate)
862
+ B: (batch, seqlen, ngroups, dstate)
863
+ C: (batch, seqlen, ngroups, dstate)
864
+ D: (nheads, headdim) or (nheads,)
865
+ z: (batch, seqlen, nheads, headdim)
866
+ Return:
867
+ ddt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
+ dA: (nheads) or (dim, dstate)
868
+ """
869
+ import selective_scan
870
+
871
+ batch, seqlen, nheads, headdim = x.shape
872
+ chunk_size = dt.shape[-1]
873
+ _, _, ngroups, dstate = B.shape
874
+ assert nheads % ngroups == 0
875
+ x = rearrange(x, "b l h p -> b (h p) l")
876
+ squeeze_dt = dt.dim() == 4
877
+ if dt.dim() == 4:
878
+ dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
879
+ dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
880
+ squeeze_A = A.dim() == 1
881
+ if A.dim() == 1:
882
+ A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
883
+ else:
884
+ A = A.to(dtype=torch.float32)
885
+ B = rearrange(B, "b l g n -> b g n l")
886
+ C = rearrange(C, "b l g n -> b g n l")
887
+ if D is not None:
888
+ if D.dim() == 2:
889
+ D = rearrange(D, "h p -> (h p)")
890
+ else:
891
+ D = repeat(D, "h -> (h p)", p=headdim)
892
+ if z is not None:
893
+ z = rearrange(z, "b l h p -> b (h p) l")
894
+
895
+ if x.stride(-1) != 1:
896
+ x = x.contiguous()
897
+ if dt.stride(-1) != 1:
898
+ dt = dt.contiguous()
899
+ if D is not None:
900
+ D = D.contiguous()
901
+ if B.stride(-1) != 1:
902
+ B = B.contiguous()
903
+ if C.stride(-1) != 1:
904
+ C = C.contiguous()
905
+ if z is not None and z.stride(-1) != 1:
906
+ z = z.contiguous()
907
+ _, intermediate, *rest = selective_scan.fwd(
908
+ x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
909
+ )
910
+ if z is not None:
911
+ out = rest[0]
912
+ else:
913
+ out = None
914
+
915
+ dout = rearrange(dout, "b l h p -> b (h p) l")
916
+
917
+ if dout.stride(-1) != 1:
918
+ dout = dout.contiguous()
919
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
920
+ # backward of selective_scan with the backward of chunk).
921
+ # Here we just pass in None and dz will be allocated in the C++ code.
922
+ _, ddt, dA, *rest = selective_scan.bwd(
923
+ x,
924
+ dt.to(dtype=x.dtype),
925
+ A,
926
+ B,
927
+ C,
928
+ D,
929
+ z,
930
+ None,
931
+ dout,
932
+ intermediate,
933
+ out,
934
+ None,
935
+ False,
936
+ False, # option to recompute out_z, not used here
937
+ )
938
+ ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
939
+ if squeeze_dt:
940
+ ddt = ddt.float().sum(dim=2)
941
+ if squeeze_A:
942
+ dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
943
+ return ddt, dA
944
+
945
+
946
+ class MambaChunkScanCombinedFn(torch.autograd.Function):
947
+
948
+ @staticmethod
949
+ def forward(
950
+ ctx,
951
+ x,
952
+ dt,
953
+ A,
954
+ B,
955
+ C,
956
+ chunk_size,
957
+ D=None,
958
+ z=None,
959
+ dt_bias=None,
960
+ initial_states=None,
961
+ seq_idx=None,
962
+ cu_seqlens=None,
963
+ dt_softplus=False,
964
+ dt_limit=(0.0, float("inf")),
965
+ return_final_states=False,
966
+ return_varlen_states=False,
967
+ ):
968
+ ctx.dt_dtype = dt.dtype
969
+ if not return_varlen_states:
970
+ cu_seqlens = None
971
+ else:
972
+ assert (
973
+ cu_seqlens is not None
974
+ ), "cu_seqlens must be provided if return_varlen_states is True"
975
+ out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
976
+ _mamba_chunk_scan_combined_fwd(
977
+ x,
978
+ dt,
979
+ A,
980
+ B,
981
+ C,
982
+ chunk_size,
983
+ D=D,
984
+ z=z,
985
+ dt_bias=dt_bias,
986
+ initial_states=initial_states,
987
+ seq_idx=seq_idx,
988
+ cu_seqlens=cu_seqlens,
989
+ dt_softplus=dt_softplus,
990
+ dt_limit=dt_limit,
991
+ )
992
+ )
993
+ ctx.save_for_backward(
994
+ out if z is None else out_x,
995
+ x,
996
+ dt,
997
+ dA_cumsum,
998
+ A,
999
+ B,
1000
+ C,
1001
+ D,
1002
+ z,
1003
+ dt_bias,
1004
+ initial_states,
1005
+ seq_idx,
1006
+ )
1007
+ ctx.dt_softplus = dt_softplus
1008
+ ctx.chunk_size = chunk_size
1009
+ ctx.dt_limit = dt_limit
1010
+ ctx.return_final_states = return_final_states
1011
+ ctx.return_varlen_states = return_varlen_states
1012
+ if not return_varlen_states:
1013
+ return out if not return_final_states else (out, final_states)
1014
+ else:
1015
+ varlen_states = rest[0]
1016
+ return (
1017
+ (out, varlen_states)
1018
+ if not return_final_states
1019
+ else (out, final_states, varlen_states)
1020
+ )
1021
+
1022
+ @staticmethod
1023
+ def backward(ctx, dout, *args):
1024
+ out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
1025
+ ctx.saved_tensors
1026
+ )
1027
+ assert (
1028
+ not ctx.return_varlen_states
1029
+ ), "return_varlen_states is not supported in backward"
1030
+ dfinal_states = args[0] if ctx.return_final_states else None
1031
+ dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
1032
+ _mamba_chunk_scan_combined_bwd(
1033
+ dout,
1034
+ x,
1035
+ dt,
1036
+ A,
1037
+ B,
1038
+ C,
1039
+ out,
1040
+ ctx.chunk_size,
1041
+ D=D,
1042
+ z=z,
1043
+ dt_bias=dt_bias,
1044
+ initial_states=initial_states,
1045
+ dfinal_states=dfinal_states,
1046
+ seq_idx=seq_idx,
1047
+ dt_softplus=ctx.dt_softplus,
1048
+ dt_limit=ctx.dt_limit,
1049
+ )
1050
+ )
1051
+ return (
1052
+ dx,
1053
+ ddt,
1054
+ dA,
1055
+ dB,
1056
+ dC,
1057
+ None,
1058
+ dD,
1059
+ dz,
1060
+ ddt_bias,
1061
+ dinitial_states,
1062
+ None,
1063
+ None,
1064
+ None,
1065
+ None,
1066
+ None,
1067
+ None,
1068
+ )
1069
+
1070
+
1071
+ def mamba_chunk_scan_combined(
1072
+ x,
1073
+ dt,
1074
+ A,
1075
+ B,
1076
+ C,
1077
+ chunk_size,
1078
+ D=None,
1079
+ z=None,
1080
+ dt_bias=None,
1081
+ initial_states=None,
1082
+ seq_idx=None,
1083
+ cu_seqlens=None,
1084
+ dt_softplus=False,
1085
+ dt_limit=(0.0, float("inf")),
1086
+ return_final_states=False,
1087
+ return_varlen_states=False,
1088
+ ):
1089
+ """
1090
+ Argument:
1091
+ x: (batch, seqlen, nheads, headdim)
1092
+ dt: (batch, seqlen, nheads)
1093
+ A: (nheads)
1094
+ B: (batch, seqlen, ngroups, dstate)
1095
+ C: (batch, seqlen, ngroups, dstate)
1096
+ chunk_size: int
1097
+ D: (nheads, headdim) or (nheads,)
1098
+ z: (batch, seqlen, nheads, headdim)
1099
+ dt_bias: (nheads,)
1100
+ initial_states: (batch, nheads, headdim, dstate)
1101
+ seq_idx: (batch, seqlen)
1102
+ cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
1103
+ dt_softplus: Whether to apply softplus to dt
+ dt_limit: (min, max) range that dt is clamped to after dt_bias/softplus
+ return_final_states: whether to also return final_states of shape (batch, nheads, headdim, dstate)
+ return_varlen_states: whether to also return per-sequence end states computed from cu_seqlens
+ Return:
+ out: (batch, seqlen, nheads, headdim)
1106
+ """
1107
+ return MambaChunkScanCombinedFn.apply(
1108
+ x,
1109
+ dt,
1110
+ A,
1111
+ B,
1112
+ C,
1113
+ chunk_size,
1114
+ D,
1115
+ z,
1116
+ dt_bias,
1117
+ initial_states,
1118
+ seq_idx,
1119
+ cu_seqlens,
1120
+ dt_softplus,
1121
+ dt_limit,
1122
+ return_final_states,
1123
+ return_varlen_states,
1124
+ )
1125
+
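# Editorial usage sketch for mamba_chunk_scan_combined (not part of this diff); the tensor
# sizes are arbitrary and a CUDA device with the Triton kernels above is assumed.
import torch

batch, seqlen, nheads, headdim, ngroups, dstate, chunk_size = 2, 512, 8, 64, 1, 16, 128
device, dtype = "cuda", torch.bfloat16

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=dtype)
A = -torch.rand(nheads, device=device)                # negative so the state decays
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
D = torch.randn(nheads, device=device)

out, final_states = mamba_chunk_scan_combined(
    x, dt, A, B, C, chunk_size, D=D, dt_softplus=True, return_final_states=True
)
print(out.shape)           # torch.Size([2, 512, 8, 64])
print(final_states.shape)  # torch.Size([2, 8, 64, 16])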
1126
+
1127
+ def mamba_chunk_scan(
1128
+ x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1129
+ ):
1130
+ """
1131
+ Argument:
1132
+ x: (batch, seqlen, nheads, headdim)
1133
+ dt: (batch, seqlen, nheads)
1134
+ A: (nheads)
1135
+ B: (batch, seqlen, ngroups, dstate)
1136
+ C: (batch, seqlen, ngroups, dstate)
1137
+ D: (nheads, headdim) or (nheads,)
1138
+ z: (batch, seqlen, nheads, headdim)
1139
+ dt_bias: (nheads,)
1140
+ Return:
1141
+ out: (batch, seqlen, nheads, headdim)
1142
+ """
1143
+ batch, seqlen, nheads, headdim = x.shape
1144
+ dstate = B.shape[-1]
1145
+ if seqlen % chunk_size != 0:
1146
+ dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1147
+ dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1148
+ dt = dt.float() # We want high precision for this before cumsum
1149
+ if dt_bias is not None:
1150
+ dt = dt + rearrange(dt_bias, "h -> h 1 1")
1151
+ if dt_softplus:
1152
+ dt = F.softplus(dt)
1153
+ dA = dt * rearrange(A, "h -> h 1 1")
1155
+ dA_cumsum = torch.cumsum(dA, dim=-1)
1156
+ # 1. Compute the state for each chunk
1157
+ states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
1158
+ # 2. Pass the state to all the chunks by weighted cumsum.
1159
+ states = rearrange(
1160
+ state_passing(
1161
+ rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1162
+ )[0],
1163
+ "... (p n) -> ... p n",
1164
+ n=dstate,
1165
+ )
1166
+ # 3. Compute the output for each chunk
1167
+ out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1168
+ return out
1169
+
1170
+
1171
+ def ssd_chunk_scan_combined_ref(
1172
+ x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
1173
+ ):
1174
+ """
1175
+ Argument:
1176
+ x: (batch, seqlen, nheads, headdim)
1177
+ dt: (batch, seqlen, nheads)
1178
+ A: (nheads)
1179
+ B: (batch, seqlen, ngroups, dstate)
1180
+ C: (batch, seqlen, ngroups, dstate)
1181
+ D: (nheads, headdim) or (nheads,)
1182
+ z: (batch, seqlen, nheads, headdim)
1183
+ dt_bias: (nheads,)
1184
+ Return:
1185
+ out: (batch, seqlen, nheads, headdim)
1186
+ """
1187
+ batch, seqlen, nheads, headdim = x.shape
1188
+ dstate = B.shape[-1]
1189
+ if seqlen % chunk_size != 0:
1190
+ dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
1191
+ dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
1192
+ dt = dt.float() # We want high precision for this before cumsum
1193
+ if dt_bias is not None:
1194
+ dt = dt + rearrange(dt_bias, "h -> h 1 1")
1195
+ if dt_softplus:
1196
+ dt = F.softplus(dt)
1197
+ dA = dt * rearrange(A, "h -> h 1 1")
1198
+ dA_cumsum = torch.cumsum(dA, dim=-1)
1199
+ # 1. Compute the state for each chunk
1200
+ states = chunk_state_ref(B, x, dt, dA_cumsum)
1201
+ states_dtype = states.dtype
1202
+ if states.dtype not in [torch.float32, torch.float64]:
1203
+ states = states.to(torch.float32)
1204
+ # 2. Pass the state to all the chunks by weighted cumsum.
1205
+ # state_passing_ref is much less numerically stable
1206
+ states = rearrange(
1207
+ state_passing_ref(
1208
+ rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
1209
+ )[0],
1210
+ "... (p n) -> ... p n",
1211
+ n=dstate,
1212
+ )
1213
+ states = states.to(states_dtype)
1214
+ # 3. Compute the output for each chunk
1215
+ out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
1216
+ return out
1217
+
1218
+
1219
+ def ssd_selective_scan(
1220
+ x,
1221
+ dt,
1222
+ A,
1223
+ B,
1224
+ C,
1225
+ D=None,
1226
+ z=None,
1227
+ dt_bias=None,
1228
+ dt_softplus=False,
1229
+ dt_limit=(0.0, float("inf")),
1230
+ ):
1231
+ """
1232
+ Argument:
1233
+ x: (batch, seqlen, nheads, headdim)
1234
+ dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1235
+ A: (nheads) or (dim, dstate)
1236
+ B: (batch, seqlen, ngroups, dstate)
1237
+ C: (batch, seqlen, ngroups, dstate)
1238
+ D: (nheads, headdim) or (nheads,)
1239
+ z: (batch, seqlen, nheads, headdim)
1240
+ dt_bias: (nheads,) or (nheads, headdim)
1241
+ Return:
1242
+ out: (batch, seqlen, nheads, headdim)
1243
+ """
1244
+ from ..selective_scan_interface import selective_scan_fn
1245
+
1246
+ batch, seqlen, nheads, headdim = x.shape
1247
+ _, _, ngroups, dstate = B.shape
1248
+ x = rearrange(x, "b l h p -> b (h p) l")
1249
+ if dt.dim() == 3:
1250
+ dt = repeat(dt, "b l h -> b l h p", p=headdim)
1251
+ dt = rearrange(dt, "b l h p -> b (h p) l")
1252
+ if A.dim() == 1:
1253
+ A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
1254
+ else:
1255
+ A = A.to(dtype=torch.float32)
1256
+ B = rearrange(B, "b l g n -> b g n l")
1257
+ C = rearrange(C, "b l g n -> b g n l")
1258
+ if D is not None:
1259
+ if D.dim() == 2:
1260
+ D = rearrange(D, "h p -> (h p)")
1261
+ else:
1262
+ D = repeat(D, "h -> (h p)", p=headdim)
1263
+ if z is not None:
1264
+ z = rearrange(z, "b l h p -> b (h p) l")
1265
+ if dt_bias is not None:
1266
+ if dt_bias.dim() == 1:
1267
+ dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
1268
+ dt_bias = rearrange(dt_bias, "h p -> (h p)")
1269
+ if dt_limit != (0.0, float("inf")):
1270
+ if dt_bias is not None:
1271
+ dt = dt + rearrange(dt_bias, "d -> d 1")
1272
+ if dt_softplus:
1273
+ dt = F.softplus(dt)
1274
+ dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
1275
+ dt_bias = None
1276
+ dt_softplus = None
1277
+ out = selective_scan_fn(
1278
+ x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
1279
+ )
1280
+ return rearrange(out, "b (h p) l -> b l h p", p=headdim)
1281
+
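# Editorial sketch (not part of this diff): how ssd_selective_scan maps the multi-head Mamba-2
# layout onto the flat channel layout expected by selective_scan_fn. CPU-only, shapes arbitrary.
import torch
from einops import rearrange, repeat

b, l, h, p, g, n = 2, 16, 4, 8, 1, 6
x = torch.randn(b, l, h, p)
dt = torch.randn(b, l, h)
A = -torch.rand(h)
B = torch.randn(b, l, g, n)

x_flat = rearrange(x, "b l h p -> b (h p) l")                 # (b, d, l) with d = h * p
dt_flat = rearrange(repeat(dt, "b l h -> b l h p", p=p), "b l h p -> b (h p) l")
A_mat = repeat(A, "h -> (h p) n", p=p, n=n)                   # per-channel copy of the scalar A
B_flat = rearrange(B, "b l g n -> b g n l")

print(x_flat.shape, dt_flat.shape, A_mat.shape, B_flat.shape)
# torch.Size([2, 32, 16]) torch.Size([2, 32, 16]) torch.Size([32, 6]) torch.Size([2, 1, 6, 16])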
1282
+
1283
+ def mamba_conv1d_scan_ref(
1284
+ xBC,
1285
+ conv1d_weight,
1286
+ conv1d_bias,
1287
+ dt,
1288
+ A,
1289
+ chunk_size,
1290
+ D=None,
1291
+ z=None,
1292
+ dt_bias=None,
1293
+ dt_softplus=False,
1294
+ dt_limit=(0.0, float("inf")),
1295
+ activation="silu",
1296
+ headdim=None,
1297
+ ngroups=1,
1298
+ ):
1299
+ """
1300
+ Argument:
1301
+ xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
1302
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
1303
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
1304
+ dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
1305
+ A: (nheads)
1306
+ D: (nheads, headdim) or (nheads,)
1307
+ z: (batch, seqlen, dim)
1308
+ dt_bias: (nheads) or (nheads, headdim)
1309
+ headdim: if D is 1D and z is None, headdim must be passed in
1310
+ Return:
1311
+ out: (batch, seqlen, dim)
1312
+ """
1313
+ batch, seqlen, nheads = dt.shape[:3]
1314
+ assert nheads % ngroups == 0
1315
+ if z is not None:
1316
+ dim = z.shape[-1]
1317
+ assert dim % nheads == 0
1318
+ headdim = dim // nheads
1319
+ else:
1320
+ if D.dim() == 1:
1321
+ assert headdim is not None
1322
+ else:
1323
+ headdim = D.shape[1]
1324
+ dim = nheads * headdim
1325
+ xBC = rearrange(
1326
+ causal_conv1d_fn(
1327
+ rearrange(xBC, "b s d -> b d s"),
1328
+ conv1d_weight,
1329
+ conv1d_bias,
1330
+ activation=activation,
1331
+ ),
1332
+ "b d s -> b s d",
1333
+ )
1334
+ dstate = (xBC.shape[-1] - dim) // ngroups // 2
1335
+ x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
1336
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1337
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1338
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1339
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1340
+ out = ssd_selective_scan(
1341
+ x,
1342
+ dt.to(x.dtype),
1343
+ A,
1344
+ B,
1345
+ C,
1346
+ D=D.float(),
1347
+ z=z,
1348
+ dt_bias=dt_bias,
1349
+ dt_softplus=dt_softplus,
1350
+ dt_limit=dt_limit,
1351
+ )
1352
+ return rearrange(out, "b s h p -> b s (h p)")
1353
+
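# Editorial sketch (not part of this diff): the causal depthwise convolution that
# causal_conv1d_fn applies over the packed xBC channels, expressed with plain F.conv1d
# (the SiLU activation and seq_idx handling of the real path are omitted).
# conv1d_weight: (channels, width) with channels = dim + 2 * ngroups * dstate.
import torch
import torch.nn.functional as F

def causal_depthwise_conv_reference(xBC, conv1d_weight, conv1d_bias):
    # xBC: (batch, seqlen, channels)
    x = xBC.transpose(1, 2)                                   # (batch, channels, seqlen)
    width = conv1d_weight.shape[-1]
    x = F.pad(x, (width - 1, 0))                              # left-pad so the conv is causal
    out = F.conv1d(x, conv1d_weight.unsqueeze(1), conv1d_bias,
                   groups=conv1d_weight.shape[0])             # depthwise: one filter per channel
    return out.transpose(1, 2)                                # back to (batch, seqlen, channels)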
1354
+
1355
+ class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
1356
+
1357
+ @staticmethod
1358
+ @custom_fwd
1359
+ def forward(
1360
+ ctx,
1361
+ zxbcdt,
1362
+ conv1d_weight,
1363
+ conv1d_bias,
1364
+ dt_bias,
1365
+ A,
1366
+ D,
1367
+ chunk_size,
1368
+ initial_states=None,
1369
+ seq_idx=None,
1370
+ dt_limit=(0.0, float("inf")),
1371
+ return_final_states=False,
1372
+ activation="silu",
1373
+ rmsnorm_weight=None,
1374
+ rmsnorm_eps=1e-6,
1375
+ outproj_weight=None,
1376
+ outproj_bias=None,
1377
+ headdim=None,
1378
+ ngroups=1,
1379
+ norm_before_gate=True,
1380
+ ):
1381
+ assert activation in [None, "silu", "swish"]
1382
+ if D.dim() == 1:
1383
+ assert headdim is not None
1384
+ (nheads,) = D.shape
1385
+ else:
1386
+ nheads, headdim = D.shape
1387
+ batch, seqlen, _ = zxbcdt.shape
1388
+ dim = nheads * headdim
1389
+ assert nheads % ngroups == 0
1390
+ dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
1391
+ d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
1392
+ assert d_nonssm >= 0
1393
+ assert zxbcdt.shape == (
1394
+ batch,
1395
+ seqlen,
1396
+ 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
1397
+ )
1398
+ assert dt_bias.shape == (nheads,)
1399
+ assert A.shape == (nheads,)
1400
+ zx0, z, xBC, dt = torch.split(
1401
+ zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
1402
+ )
1403
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
1404
+ xBC_conv = rearrange(
1405
+ causal_conv1d_cuda.causal_conv1d_fwd(
1406
+ rearrange(xBC, "b s d -> b d s"),
1407
+ conv1d_weight,
1408
+ conv1d_bias,
1409
+ seq_idx,
1410
+ None,
1411
+ None,
1412
+ activation in ["silu", "swish"],
1413
+ ),
1414
+ "b d s -> b s d",
1415
+ )
1416
+ x, B, C = torch.split(
1417
+ xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
1418
+ )
1419
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1420
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1421
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1422
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
1423
+ if rmsnorm_weight is None:
1424
+ out, out_x, dt_out, dA_cumsum, states, final_states = (
1425
+ _mamba_chunk_scan_combined_fwd(
1426
+ x,
1427
+ dt,
1428
+ A,
1429
+ B,
1430
+ C,
1431
+ chunk_size=chunk_size,
1432
+ D=D,
1433
+ z=z,
1434
+ dt_bias=dt_bias,
1435
+ initial_states=initial_states,
1436
+ seq_idx=seq_idx,
1437
+ dt_softplus=True,
1438
+ dt_limit=dt_limit,
1439
+ )
1440
+ )
1441
+ out = rearrange(out, "b s h p -> b s (h p)")
1442
+ rstd = None
1443
+ if d_nonssm > 0:
1444
+ out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
1445
+ else:
1446
+ out_x, _, dt_out, dA_cumsum, states, final_states = (
1447
+ _mamba_chunk_scan_combined_fwd(
1448
+ x,
1449
+ dt,
1450
+ A,
1451
+ B,
1452
+ C,
1453
+ chunk_size=chunk_size,
1454
+ D=D,
1455
+ z=None,
1456
+ dt_bias=dt_bias,
1457
+ initial_states=initial_states,
1458
+ seq_idx=seq_idx,
1459
+ dt_softplus=True,
1460
+ dt_limit=dt_limit,
1461
+ )
1462
+ )
1463
+ # reshape input data into 2D tensor
1464
+ x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
1465
+ z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1466
+ rmsnorm_weight = rmsnorm_weight.contiguous()
1467
+ if d_nonssm == 0:
1468
+ out = None
1469
+ else:
1470
+ out01 = torch.empty(
1471
+ (batch, seqlen, d_nonssm + dim),
1472
+ dtype=x_rms.dtype,
1473
+ device=x_rms.device,
1474
+ )
1475
+ out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
1476
+ _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
1477
+ out, _, rstd = _layer_norm_fwd(
1478
+ x_rms,
1479
+ rmsnorm_weight,
1480
+ None,
1481
+ rmsnorm_eps,
1482
+ z_rms,
1483
+ out=out,
1484
+ group_size=dim // ngroups,
1485
+ norm_before_gate=norm_before_gate,
1486
+ is_rms_norm=True,
1487
+ )
1488
+ if d_nonssm == 0:
1489
+ out = rearrange(out, "(b s) d -> b s d", b=batch)
1490
+ else:
1491
+ out = out01
1492
+ ctx.outproj_weight_dtype = (
1493
+ outproj_weight.dtype if outproj_weight is not None else None
1494
+ )
1495
+ if outproj_weight is not None:
1496
+ if torch.is_autocast_enabled():
1497
+ dtype = torch.get_autocast_gpu_dtype()
1498
+ out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
1499
+ outproj_bias = (
1500
+ outproj_bias.to(dtype) if outproj_bias is not None else None
1501
+ )
1502
+ out = F.linear(out, outproj_weight, outproj_bias)
1503
+ else:
1504
+ assert outproj_bias is None
1505
+ ctx.save_for_backward(
1506
+ zxbcdt,
1507
+ conv1d_weight,
1508
+ conv1d_bias,
1509
+ out_x,
1510
+ A,
1511
+ D,
1512
+ dt_bias,
1513
+ initial_states,
1514
+ seq_idx,
1515
+ rmsnorm_weight,
1516
+ rstd,
1517
+ outproj_weight,
1518
+ outproj_bias,
1519
+ )
1520
+ ctx.dt_limit = dt_limit
1521
+ ctx.return_final_states = return_final_states
1522
+ ctx.activation = activation
1523
+ ctx.rmsnorm_eps = rmsnorm_eps
1524
+ ctx.norm_before_gate = norm_before_gate
1525
+ ctx.chunk_size = chunk_size
1526
+ ctx.headdim = headdim
1527
+ ctx.ngroups = ngroups
1528
+ return out if not return_final_states else (out, final_states)
1529
+
1530
+ @staticmethod
1531
+ @custom_bwd
1532
+ def backward(ctx, dout, *args):
1533
+ (
1534
+ zxbcdt,
1535
+ conv1d_weight,
1536
+ conv1d_bias,
1537
+ out,
1538
+ A,
1539
+ D,
1540
+ dt_bias,
1541
+ initial_states,
1542
+ seq_idx,
1543
+ rmsnorm_weight,
1544
+ rstd,
1545
+ outproj_weight,
1546
+ outproj_bias,
1547
+ ) = ctx.saved_tensors
1548
+ dfinal_states = args[0] if ctx.return_final_states else None
1549
+ headdim = ctx.headdim
1550
+ nheads = D.shape[0]
1551
+ dim = nheads * headdim
1552
+ assert nheads % ctx.ngroups == 0
1553
+ dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
1554
+ d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
1555
+ assert d_nonssm >= 0
1556
+ recompute_output = outproj_weight is not None
1557
+ if recompute_output:
1558
+ out_recompute = torch.empty(
1559
+ *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
1560
+ )
1561
+ out0_recompute, out1_recompute = out_recompute.split(
1562
+ [d_nonssm, dim], dim=-1
1563
+ )
1564
+ zx0, z, xBC, dt = torch.split(
1565
+ zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1566
+ )
1567
+ # Recompute x, B, C
1568
+ xBC_conv = rearrange(
1569
+ causal_conv1d_cuda.causal_conv1d_fwd(
1570
+ rearrange(xBC, "b s d -> b d s"),
1571
+ conv1d_weight,
1572
+ conv1d_bias,
1573
+ seq_idx,
1574
+ None,
1575
+ None,
1576
+ ctx.activation in ["silu", "swish"],
1577
+ ),
1578
+ "b d s -> b s d",
1579
+ )
1580
+ x, B, C = torch.split(
1581
+ xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1582
+ )
1583
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1584
+ B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
1585
+ C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
1586
+ dzxbcdt = torch.empty_like(zxbcdt)
1587
+ dzx0, dz, dxBC_given, ddt_given = torch.split(
1588
+ dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
1589
+ )
1590
+ dxBC = torch.empty_like(xBC)
1591
+ dx, dB, dC = torch.split(
1592
+ dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
1593
+ )
1594
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
1595
+ dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
1596
+ dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
1597
+ dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
1598
+ if outproj_weight is not None:
1599
+ dout_og = dout
1600
+ dout = F.linear(dout, outproj_weight.t())
1601
+ if d_nonssm > 0:
1602
+ dout0, dout = dout.split([d_nonssm, dim], dim=-1)
1603
+ _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
1604
+ dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
1605
+ if rmsnorm_weight is None:
1606
+ dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
1607
+ dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
1608
+ _mamba_chunk_scan_combined_bwd(
1609
+ dout,
1610
+ x,
1611
+ dt,
1612
+ A,
1613
+ B,
1614
+ C,
1615
+ out,
1616
+ ctx.chunk_size,
1617
+ D=D,
1618
+ z=z,
1619
+ dt_bias=dt_bias,
1620
+ initial_states=initial_states,
1621
+ dfinal_states=dfinal_states,
1622
+ seq_idx=seq_idx,
1623
+ dt_softplus=True,
1624
+ dt_limit=ctx.dt_limit,
1625
+ dx=dx,
1626
+ ddt=ddt_given,
1627
+ dB=dB,
1628
+ dC=dC,
1629
+ dz=dz,
1630
+ recompute_output=recompute_output,
1631
+ )
1632
+ )
1633
+ out_for_linear = (
1634
+ rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
1635
+ )
1636
+ drmsnorm_weight = None
1637
+ else:
1638
+ batch = dout.shape[0]
1639
+ dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
1640
+ dz = rearrange(dz, "b l d -> (b l) d")
1641
+ x_rms = rearrange(out, "b s h p -> (b s) (h p)")
1642
+ z_rms = rearrange(z, "b s h p -> (b s) (h p)")
1643
+ out1_recompute = (
1644
+ rearrange(out1_recompute, "b s d -> (b s) d")
1645
+ if recompute_output
1646
+ else None
1647
+ )
1648
+ dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
1649
+ dy_rms,
1650
+ x_rms,
1651
+ rmsnorm_weight,
1652
+ None,
1653
+ ctx.rmsnorm_eps,
1654
+ None,
1655
+ rstd,
1656
+ z_rms,
1657
+ group_size=dim // ctx.ngroups,
1658
+ norm_before_gate=ctx.norm_before_gate,
1659
+ is_rms_norm=True,
1660
+ recompute_output=recompute_output,
1661
+ dz=dz,
1662
+ out=out1_recompute if recompute_output else None,
1663
+ )
1664
+ out_for_linear = out_recompute if recompute_output else None
1665
+ dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
1666
+ dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
1667
+ _mamba_chunk_scan_combined_bwd(
1668
+ dout,
1669
+ x,
1670
+ dt,
1671
+ A,
1672
+ B,
1673
+ C,
1674
+ out,
1675
+ ctx.chunk_size,
1676
+ D=D,
1677
+ z=None,
1678
+ dt_bias=dt_bias,
1679
+ initial_states=initial_states,
1680
+ dfinal_states=dfinal_states,
1681
+ seq_idx=seq_idx,
1682
+ dt_softplus=True,
1683
+ dt_limit=ctx.dt_limit,
1684
+ dx=dx,
1685
+ ddt=ddt_given,
1686
+ dB=dB,
1687
+ dC=dC,
1688
+ )
1689
+ )
1690
+
1691
+ if outproj_weight is not None:
1692
+ doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
1693
+ doutproj_bias = (
1694
+ dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
1695
+ )
1696
+ else:
1697
+ doutproj_weight, doutproj_bias = None, None
1698
+ dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
1699
+ dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
1700
+ rearrange(xBC, "b s d -> b d s"),
1701
+ conv1d_weight,
1702
+ conv1d_bias,
1703
+ rearrange(dxBC, "b s d -> b d s"),
1704
+ seq_idx,
1705
+ None,
1706
+ None,
1707
+ dxBC_given,
1708
+ False,
1709
+ ctx.activation in ["silu", "swish"],
1710
+ )
1711
+ dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
1712
+ return (
1713
+ dzxbcdt,
1714
+ dweight,
1715
+ dbias,
1716
+ ddt_bias,
1717
+ dA,
1718
+ dD,
1719
+ None,
1720
+ dinitial_states,
1721
+ None,
1722
+ None,
1723
+ None,
1724
+ None,
1725
+ drmsnorm_weight,
1726
+ None,
1727
+ doutproj_weight,
1728
+ doutproj_bias,
1729
+ None,
1730
+ None,
1731
+ None,
1732
+ )
1733
+
1734
+
1735
+ def mamba_split_conv1d_scan_combined(
1736
+ zxbcdt,
1737
+ conv1d_weight,
1738
+ conv1d_bias,
1739
+ dt_bias,
1740
+ A,
1741
+ D,
1742
+ chunk_size,
1743
+ initial_states=None,
1744
+ seq_idx=None,
1745
+ dt_limit=(0.0, float("inf")),
1746
+ return_final_states=False,
1747
+ activation="silu",
1748
+ rmsnorm_weight=None,
1749
+ rmsnorm_eps=1e-6,
1750
+ outproj_weight=None,
1751
+ outproj_bias=None,
1752
+ headdim=None,
1753
+ ngroups=1,
1754
+ norm_before_gate=True,
1755
+ ):
1756
+ """
1757
+ Argument:
1758
+ zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
1759
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
1760
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
1761
+ dt_bias: (nheads,)
1762
+ A: (nheads)
1763
+ D: (nheads, headdim) or (nheads,)
1764
+ initial_states: (batch, nheads, headdim, dstate)
1765
+ seq_idx: (batch, seqlen), int32
1766
+ rmsnorm_weight: (dim,)
1767
+ outproj_weight: (out_dim, dim)
1768
+ outproj_bias: (out_dim,)
1769
+ headdim: if D is 1D, headdim must be passed in
1770
+ norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
1771
+ Return:
1772
+ out: (batch, seqlen, dim)
1773
+ """
1774
+ return MambaSplitConv1dScanCombinedFn.apply(
1775
+ zxbcdt,
1776
+ conv1d_weight,
1777
+ conv1d_bias,
1778
+ dt_bias,
1779
+ A,
1780
+ D,
1781
+ chunk_size,
1782
+ initial_states,
1783
+ seq_idx,
1784
+ dt_limit,
1785
+ return_final_states,
1786
+ activation,
1787
+ rmsnorm_weight,
1788
+ rmsnorm_eps,
1789
+ outproj_weight,
1790
+ outproj_bias,
1791
+ headdim,
1792
+ ngroups,
1793
+ norm_before_gate,
1794
+ )
1795
+
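# Editorial sketch (not part of this diff): how the packed zxbcdt projection splits into z,
# the conv input xBC = [x, B, C], and dt, mirroring the widths asserted in the autograd
# Function above (d_nonssm = 0 case). Shapes are arbitrary.
import torch

batch, seqlen, nheads, headdim, ngroups, dstate = 2, 32, 4, 16, 1, 8
dim = nheads * headdim
d_in_proj = 2 * dim + 2 * ngroups * dstate + nheads           # width of zxbcdt

zxbcdt = torch.randn(batch, seqlen, d_in_proj)
z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
print(z.shape, x.shape, B.shape, C.shape, dt.shape)
# (2, 32, 64) (2, 32, 64) (2, 32, 8) (2, 32, 8) (2, 32, 4)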
1796
+
1797
+ def mamba_split_conv1d_scan_ref(
1798
+ zxbcdt,
1799
+ conv1d_weight,
1800
+ conv1d_bias,
1801
+ dt_bias,
1802
+ A,
1803
+ D,
1804
+ chunk_size,
1805
+ dt_limit=(0.0, float("inf")),
1806
+ activation="silu",
1807
+ rmsnorm_weight=None,
1808
+ rmsnorm_eps=1e-6,
1809
+ outproj_weight=None,
1810
+ outproj_bias=None,
1811
+ headdim=None,
1812
+ ngroups=1,
1813
+ norm_before_gate=True,
1814
+ ):
1815
+ """
1816
+ Argument:
1817
+ zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
1818
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
1819
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
1820
+ dt_bias: (nheads,)
1821
+ A: (nheads)
1822
+ D: (nheads, headdim) or (nheads,)
1823
+ rmsnorm_weight: (dim,)
1824
+ outproj_weight: (out_dim, dim)
1825
+ outproj_bias: (out_dim,)
1826
+ headdim: if D is 1D, headdim must be passed in
1827
+ norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
1828
+ Return:
1829
+ out: (batch, seqlen, dim)
1830
+ """
1831
+ if D.dim() == 1:
1832
+ assert headdim is not None
1833
+ (nheads,) = D.shape
1834
+ else:
1835
+ nheads, headdim = D.shape
1836
+ assert nheads % ngroups == 0
1837
+ batch, seqlen, _ = zxbcdt.shape
1838
+ dim = nheads * headdim
1839
+ dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
1840
+ assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
1841
+ assert dt_bias.shape == (nheads,)
1842
+ assert A.shape == (nheads,)
1843
+ if rmsnorm_weight is not None:
1844
+ assert rmsnorm_weight.shape == (dim,)
1845
+ z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
1846
+ xBC = rearrange(
1847
+ causal_conv1d_fn(
1848
+ rearrange(xBC, "b s d -> b d s"),
1849
+ conv1d_weight,
1850
+ conv1d_bias,
1851
+ activation=activation,
1852
+ ),
1853
+ "b d s -> b s d",
1854
+ )
1855
+ x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
1856
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
1857
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
1858
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
1859
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
1860
+ out = ssd_selective_scan(
1861
+ x,
1862
+ dt.to(x.dtype),
1863
+ A,
1864
+ B,
1865
+ C,
1866
+ D=D.float(),
1867
+ z=z if rmsnorm_weight is None else None,
1868
+ dt_bias=dt_bias,
1869
+ dt_softplus=True,
1870
+ dt_limit=dt_limit,
1871
+ )
1872
+ out = rearrange(out, "b s h p -> b s (h p)")
1873
+ if rmsnorm_weight is not None:
1874
+ out = rmsnorm_fn(
1875
+ out,
1876
+ rmsnorm_weight,
1877
+ None,
1878
+ z=rearrange(z, "b l h p -> b l (h p)"),
1879
+ eps=rmsnorm_eps,
1880
+ norm_before_gate=norm_before_gate,
1881
+ )
1882
+ if outproj_weight is not None:
1883
+ out = F.linear(out, outproj_weight, outproj_bias)
1884
+ return out
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_state_passing.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ @triton.autotune(
17
+ configs=[
18
+ triton.Config({'BLOCK_SIZE': 64}),
19
+ triton.Config({'BLOCK_SIZE': 128}),
20
+ triton.Config({'BLOCK_SIZE': 256}),
21
+ triton.Config({'BLOCK_SIZE': 512}),
22
+ triton.Config({'BLOCK_SIZE': 1024}),
23
+ triton.Config({'BLOCK_SIZE': 2048}),
24
+ ],
25
+ key=['dim'],
26
+ )
27
+ @triton.jit
28
+ def _state_passing_fwd_kernel(
29
+ # Pointers to matrices
30
+ states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
31
+ # Matrix dimensions
32
+ dim, nchunks, seqlen, chunk_size,
33
+ # Strides
34
+ stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
35
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
36
+ stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
37
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
38
+ stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
39
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
40
+ # Meta-parameters
41
+ HAS_INITSTATES: tl.constexpr,
42
+ HAS_SEQ_IDX: tl.constexpr,
43
+ BLOCK_SIZE: tl.constexpr,
44
+ ):
45
+ pid_b = tl.program_id(axis=1)
46
+ pid_h = tl.program_id(axis=2)
47
+ pid_m = tl.program_id(axis=0)
48
+ states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
49
+ dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
50
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
51
+ final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
52
+ if HAS_INITSTATES:
53
+ initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
54
+ if HAS_SEQ_IDX:
55
+ seq_idx_ptr += pid_b * stride_seq_idx_batch
56
+
57
+ offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
58
+ states_ptrs = states_ptr + offs_m * stride_states_dim
59
+ out_ptrs = out_ptr + offs_m * stride_out_dim
60
+ final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
61
+
62
+ if not HAS_INITSTATES:
63
+ states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
64
+ else:
65
+ initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
66
+ states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
67
+ tl.store(out_ptrs, states, mask=offs_m < dim)
68
+ out_ptrs += stride_out_chunk
69
+ seq_idx = 0
70
+ for c in range(nchunks):
71
+ new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
72
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
73
+ scale = tl.exp(dA_cs)
74
+ if HAS_SEQ_IDX:
75
+ seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
76
+ scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
77
+ seq_idx = seq_idx_new
78
+ states = scale * states + new_states
79
+ if c < nchunks - 1:
80
+ tl.store(out_ptrs, states, mask=offs_m < dim)
81
+ else:
82
+ tl.store(final_states_ptrs, states, mask=offs_m < dim)
83
+ states_ptrs += stride_states_chunk
84
+ dA_cs_ptr += stride_dA_cs_chunk
85
+ out_ptrs += stride_out_chunk
86
+
87
+
88
+ @triton.autotune(
89
+ configs=[
90
+ triton.Config({'BLOCK_SIZE': 64}),
91
+ triton.Config({'BLOCK_SIZE': 128}),
92
+ triton.Config({'BLOCK_SIZE': 256}),
93
+ triton.Config({'BLOCK_SIZE': 512}),
94
+ triton.Config({'BLOCK_SIZE': 1024}),
95
+ triton.Config({'BLOCK_SIZE': 2048}),
96
+ ],
97
+ key=['dim'],
98
+ )
99
+ @triton.jit
100
+ def _state_passing_bwd_kernel(
101
+ # Pointers to matrices
102
+ dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
103
+ dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
104
+ # Matrix dimensions
105
+ dim, nchunks, seqlen, chunk_size,
106
+ # Strides
107
+ stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
108
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
109
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
110
+ stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
111
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
112
+ stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
113
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
114
+ stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
115
+ # Meta-parameters
116
+ CONVERT_STATES: tl.constexpr,
117
+ HAS_DFINAL_STATES: tl.constexpr,
118
+ HAS_DINITSTATES: tl.constexpr,
119
+ HAS_SEQ_IDX: tl.constexpr,
120
+ BLOCK_SIZE: tl.constexpr,
121
+ ):
122
+ pid_b = tl.program_id(axis=1)
123
+ pid_h = tl.program_id(axis=2)
124
+ pid_m = tl.program_id(axis=0)
125
+ dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
126
+ dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
127
+ ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
128
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
129
+ dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
130
+ if CONVERT_STATES:
131
+ states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
132
+ if HAS_DFINAL_STATES:
133
+ dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
134
+ if HAS_DINITSTATES:
135
+ dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
136
+ if HAS_SEQ_IDX:
137
+ seq_idx_ptr += pid_b * stride_seq_idx_batch
138
+
139
+ offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
140
+ dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
141
+ out_ptrs = out_ptr + offs_m * stride_out_dim
142
+ dout_ptrs = dout_ptr + offs_m * stride_dout_dim
143
+ if CONVERT_STATES:
144
+ states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
145
+
146
+ if HAS_DFINAL_STATES:
147
+ dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
148
+ else:
149
+ dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
150
+ tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
151
+ if HAS_SEQ_IDX:
152
+ seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
153
+ dstates_ptrs -= stride_dstates_chunk
154
+ for c in range(nchunks - 1):
155
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
156
+ scale = tl.exp(dA_cs)
157
+ if HAS_SEQ_IDX:
158
+ seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
159
+ scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
160
+ seq_idx = seq_idx_new
161
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
162
+ if CONVERT_STATES:
163
+ tl.store(states_converted_ptrs, out, mask=offs_m < dim)
164
+ ddA = tl.sum(out * dstates) * scale
165
+ tl.store(ddA_cs_ptr, ddA)
166
+ dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
167
+ dstates = scale * dstates + dout
168
+ tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
169
+ dout_ptrs -= stride_dout_chunk
170
+ dstates_ptrs -= stride_dstates_chunk
171
+ dA_cs_ptr -= stride_dA_cs_chunk
172
+ ddA_cs_ptr -= stride_ddA_cs_chunk
173
+ out_ptrs -= stride_out_chunk
174
+ if CONVERT_STATES:
175
+ states_converted_ptrs -= stride_out_chunk
176
+ if CONVERT_STATES:
177
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
178
+ tl.store(states_converted_ptrs, out, mask=offs_m < dim)
179
+ if not HAS_DINITSTATES:
180
+ tl.store(ddA_cs_ptr, 0.0)
181
+ else:
182
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
183
+ scale = tl.exp(dA_cs)
184
+ if HAS_SEQ_IDX:
185
+ scale = tl.where(seq_idx == 0, scale, 0.0)
186
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
187
+ ddA = tl.sum(out * dstates) * scale
188
+ tl.store(ddA_cs_ptr, ddA)
189
+ dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
190
+ dstates = scale * dstates + dout
191
+ tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
192
+
193
+
194
+ def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
195
+ out_dtype=None):
196
+ batch, nchunks, nheads, dim = states.shape
197
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
198
+ if initial_states is not None:
199
+ assert initial_states.shape == (batch, nheads, dim)
200
+ if seq_idx is not None:
201
+ assert chunk_size is not None
202
+ seqlen = seq_idx.shape[-1]
203
+ assert seq_idx.shape == (batch, seqlen)
204
+ out_dtype = states.dtype if out_dtype is None else out_dtype
205
+ out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
206
+ final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
207
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
208
+ with torch.cuda.device(states.device.index):
209
+ _state_passing_fwd_kernel[grid](
210
+ states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
211
+ dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
212
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3),
213
+ out.stride(0), out.stride(1), out.stride(2), out.stride(3),
214
+ final_states.stride(0), final_states.stride(1), final_states.stride(2),
215
+ dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
216
+ *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
217
+ if initial_states is not None else (0, 0, 0)),
218
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
219
+ HAS_INITSTATES=initial_states is not None,
220
+ HAS_SEQ_IDX=seq_idx is not None,
221
+ )
222
+ return out, final_states
223
+
224
+
225
+ def _state_passing_bwd(
226
+ states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
227
+ dstates_dtype=None, states_dtype=None, chunk_size=None
228
+ ):
229
+ """
230
+ states contains the initial_states at index 0. The final states are not included in states.
231
+ """
232
+ batch, nchunks, nheads, dim = states.shape
233
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
234
+ assert dout.shape == (batch, nchunks, nheads, dim)
235
+ if seq_idx is not None:
236
+ assert chunk_size is not None
237
+ seqlen = seq_idx.shape[-1]
238
+ assert seq_idx.shape == (batch, seqlen)
239
+ dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
240
+ if states_dtype is not None and states_dtype != states.dtype:
241
+ states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
242
+ assert states_converted.stride() == states.stride()
243
+ else:
244
+ states_converted = None
245
+ if has_initial_states:
246
+ dinitstates = torch.empty_like(dstates[:, 0])
247
+ else:
248
+ dinitstates = None
249
+ if dfinal_states is not None:
250
+ assert dfinal_states.shape == (batch, nheads, dim)
251
+ BLOCK_SIZE_min = 64
252
+ n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
253
+ ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
254
+ dtype=torch.float32, device=dA_chunk_cumsum.device)
255
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
256
+ with torch.cuda.device(dout.device.index):
257
+ _state_passing_bwd_kernel[grid](
258
+ dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
259
+ dstates, ddA_chunk_cumsum, dinitstates, states_converted,
260
+ dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
261
+ dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
262
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3),
263
+ dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
264
+ *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
265
+ if dfinal_states is not None else (0, 0, 0)),
266
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
267
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
268
+ ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
269
+ *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
270
+ if dinitstates is not None else (0, 0, 0)),
271
+ CONVERT_STATES=states_converted is not None,
272
+ HAS_DFINAL_STATES=dfinal_states is not None,
273
+ HAS_DINITSTATES=dinitstates is not None,
274
+ HAS_SEQ_IDX=seq_idx is not None,
275
+ )
276
+ BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
277
+ n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
278
+ ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
279
+ if states_dtype is not None and states_dtype == states.dtype:
280
+ states_converted = states
281
+ return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
282
+
283
+
284
+ class StatePassingFn(torch.autograd.Function):
285
+
286
+ @staticmethod
287
+ def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
288
+ batch, nchunks, nheads, dim = states.shape
289
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
290
+ if states.stride(-1) != 1:
291
+ states = states.contiguous()
292
+ out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
293
+ ctx.save_for_backward(out, dA_chunk_cumsum)
294
+ ctx.has_initial_states = initial_states is not None
295
+ return out, final_states
296
+
297
+ @staticmethod
298
+ def backward(ctx, dout, dfinal_states):
299
+ out, dA_chunk_cumsum = ctx.saved_tensors
300
+ batch, nchunks, nheads, dim = out.shape
301
+ assert dout.shape == (batch, nchunks, nheads, dim)
302
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
303
+ assert dfinal_states.shape == (batch, nheads, dim)
304
+ if dout.stride(-1) != 1:
305
+ dout = dout.contiguous()
306
+ dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
307
+ out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states
308
+ )
309
+ return dstates, ddA_chunk_cumsum, dinitstates
310
+
311
+
312
+ def state_passing(states, dA_chunk_cumsum, initial_states=None):
313
+ """
314
+ Argument:
315
+ states: (batch, nchunks, nheads, dim)
316
+ dA_chunk_cumsum: (batch, nheads, nchunks)
317
+ initial_states: (batch, nheads, dim)
318
+ Return:
319
+ out: (batch, nchunks, nheads, dim)
320
+ final_states: (batch, nheads, dim)
321
+ """
322
+ return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
323
+
324
+
325
+ def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
326
+ """
327
+ Argument:
328
+ states: (batch, nchunks, nheads, dim)
329
+ dA_chunk_cumsum: (batch, nheads, nchunks)
330
+ initial_states: (batch, nheads, dim)
331
+ Return:
332
+ out: (batch, nchunks, nheads, dim)
333
+ final_states: (batch, nheads, dim)
334
+ """
335
+ if initial_states is None:
336
+ initial_states = torch.zeros_like(states[:, 0])
337
+ states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
338
+ dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
339
+ dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
340
+ nchunks = dA_chunk_cumsum.shape[-1]
341
+ # (batch, nheads, nchunks, nchunks)
342
+ dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
343
+ # (batch, nheads, nchunks, nchunks)
344
+ decay_chunk = torch.exp(dt_chunk_segment_sum)
345
+ causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
346
+ decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
347
+ out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
348
+ return out[:, :-1], out[:, -1]
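
As a quick sanity check on the shapes documented above, the pure-PyTorch reference can be exercised on small random tensors. This is an illustrative sketch, not part of the diff; it assumes the file is importable as mamba_ssm.ops.triton.ssd_state_passing and that triton and einops are installed, since the module imports them at the top.

# Illustrative sketch (not part of this diff): shape check for state_passing_ref.
import torch
from mamba_ssm.ops.triton.ssd_state_passing import state_passing_ref  # assumed import path

batch, nchunks, nheads, dim = 2, 4, 3, 8
states = torch.randn(batch, nchunks, nheads, dim)
dA_chunk_cumsum = torch.randn(batch, nheads, nchunks)
initial_states = torch.zeros(batch, nheads, dim)

out, final_states = state_passing_ref(states, dA_chunk_cumsum, initial_states)
assert out.shape == (batch, nchunks, nheads, dim)
assert final_states.shape == (batch, nheads, dim)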
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/__init__.py ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/generation.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import gc
3
+ import time
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Callable, Optional, Sequence, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
+
16
+
17
+ @dataclass
18
+ class InferenceParams:
19
+ """Inference parameters that are passed to the main model in order
20
+ to efficiently calculate and store the context during inference."""
21
+
22
+ max_seqlen: int
23
+ max_batch_size: int
24
+ seqlen_offset: int = 0
25
+ batch_size_offset: int = 0
26
+ key_value_memory_dict: dict = field(default_factory=dict)
27
+ lengths_per_sample: Optional[Tensor] = None
28
+
29
+ def reset(self, max_seqlen, max_batch_size):
30
+ self.max_seqlen = max_seqlen
31
+ self.max_batch_size = max_batch_size
32
+ self.seqlen_offset = 0
33
+ if self.lengths_per_sample is not None:
34
+ self.lengths_per_sample.zero_()
35
+
36
+
37
+ def modify_logits_for_min_p_filtering(logits, min_p):
38
+ """Set the logits for none min_p values to -inf. Done in-place."""
39
+ if min_p <= 0.0 or min_p >= 1.0:
40
+ return
41
+ indices_to_remove = logits < min_p
42
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
43
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
44
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
45
+ def modify_logits_for_top_k_filtering(logits, top_k):
46
+ """Set the logits for none top-k values to -inf. Done in-place."""
47
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
48
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
49
+
50
+
51
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
52
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
53
+ def modify_logits_for_top_p_filtering(logits, top_p):
54
+ """Set the logits for none top-p values to -inf. Done in-place."""
55
+ if top_p <= 0.0 or top_p >= 1.0:
56
+ return
57
+ # First sort and calculate cumulative sum of probabilities.
58
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
59
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
60
+ # Remove tokens whose cumulative probability is at most (1 - top_p); the highest-probability token is always kept
61
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
62
+ # scatter sorted tensors to original indexing
63
+ indices_to_remove = sorted_indices_to_remove.scatter(
64
+ 1, sorted_indices, sorted_indices_to_remove
65
+ )
66
+ logits.masked_fill_(indices_to_remove, float("-inf"))
67
+
68
+
69
+ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
70
+ """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
71
+ logits: (batch_size, vocab_size)
72
+ prev_output_tokens: (batch_size, seq_len)
73
+ """
74
+ if repetition_penalty == 1.0:
75
+ return logits
76
+ score = torch.gather(logits, 1, prev_output_tokens)
77
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
78
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
79
+ logits.scatter_(1, prev_output_tokens, score)
80
+ return logits
81
+
82
+
83
+ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
84
+ """Sample from top-k logits.
85
+ Arguments:
86
+ logits: Tensor of shape (batch_size, vocab_size)
87
+ """
88
+ if top_k == 1: # Short-circuit for greedy decoding
89
+ return logits.argmax(dim=-1)
90
+ else:
91
+ if top_p > 0.0:
92
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
93
+ if top_k > 0:
94
+ top_k = min(top_k, logits.size(-1)) # Safety check
95
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
96
+ if temperature != 1.0:
97
+ logits_top /= temperature
98
+ modify_logits_for_top_p_filtering(logits_top, top_p)
99
+ return indices[
100
+ torch.arange(indices.shape[0], device=indices.device),
101
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
102
+ ]
103
+ else:
104
+ if min_p > 0.0:
105
+ logits_top = logits.clone()
106
+ max_prob = logits_top[..., 0].item()
107
+ min_prob = max_prob * min_p
108
+ modify_logits_for_min_p_filtering(logits_top, min_prob)
109
+ if temperature != 1.0:
110
+ logits_top /= temperature
111
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
112
+ # Clone so that when we modify for top_p we don't change the original logits
113
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
114
+ modify_logits_for_top_p_filtering(logits_top, top_p)
115
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
116
+ dim=-1
117
+ )
118
+
119
+
120
+ @torch.inference_mode()
121
+ def decode(
122
+ input_ids,
123
+ model,
124
+ max_length,
125
+ top_k=1,
126
+ top_p=0.0,
127
+ min_p=0.0,
128
+ temperature=1.0,
129
+ repetition_penalty=1.0,
130
+ eos_token_id=None,
131
+ teacher_outputs=None,
132
+ vocab_size=None,
133
+ cg=False,
134
+ enable_timing=False,
135
+ output_scores=False,
136
+ streamer: Optional[TextStreamer] = None
137
+ ):
138
+ """Decoding, either greedy or with top-k or top-p sampling.
139
+ If top-k = 0, don't limit the number of candidates (pure sampling).
140
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
141
+ then top-p.
142
+ We assume that all sequences in the same batch have the same length.
143
+
144
+ Arguments:
145
+ input_ids: (batch, seq_len)
146
+ max_length: int
147
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
148
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
149
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
150
+ sequences: (batch, max_length)
151
+ scores: tuples of (batch, vocab_size)
152
+ """
153
+ if streamer is not None:
154
+ streamer.put(input_ids.cpu())
155
+
156
+ batch_size, seqlen_og = input_ids.shape
157
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
158
+ if cg:
159
+ if not hasattr(model, "_decoding_cache"):
160
+ model._decoding_cache = None
161
+ model._decoding_cache = update_graph_cache(
162
+ model,
163
+ model._decoding_cache,
164
+ batch_size,
165
+ seqlen_og,
166
+ max_length,
167
+ )
168
+ inference_params = model._decoding_cache.inference_params
169
+ inference_params.reset(max_length, batch_size)
170
+ else:
171
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
172
+
173
+ def get_logits(input_ids, inference_params):
174
+ decoding = inference_params.seqlen_offset > 0
175
+ if decoding:
176
+ position_ids = torch.full(
177
+ (batch_size, 1),
178
+ inference_params.seqlen_offset,
179
+ dtype=torch.long,
180
+ device=input_ids.device,
181
+ )
182
+ else:
183
+ position_ids = None
184
+ if not cg or not decoding:
185
+ logits = model(
186
+ input_ids,
187
+ position_ids=position_ids,
188
+ inference_params=inference_params,
189
+ num_last_tokens=1,
190
+ ).logits.squeeze(dim=1)
191
+ else:
192
+ logits = model._decoding_cache.run(
193
+ input_ids, position_ids, inference_params.seqlen_offset
194
+ ).squeeze(dim=1)
195
+ return logits[..., :vocab_size] if vocab_size is not None else logits
196
+
197
+ def sample_tokens(logits, inference_params):
198
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
199
+ token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
200
+ else:
201
+ token = teacher_outputs[:, inference_params.seqlen_offset]
202
+ # return rearrange(token, "b -> b 1")
203
+ return token.unsqueeze(1)
204
+
205
+ def should_stop(current_token, inference_params):
206
+ if inference_params.seqlen_offset == 0:
207
+ return False
208
+ if eos_token_id is not None and (current_token == eos_token_id).all():
209
+ return True
210
+ if inference_params.seqlen_offset >= max_length - 1:
211
+ return True
212
+ return False
213
+
214
+ start = torch.cuda.Event(enable_timing=enable_timing)
215
+ end = torch.cuda.Event(enable_timing=enable_timing)
216
+
217
+ if enable_timing:
218
+ start.record()
219
+ scores, sequences = [], [input_ids]
220
+ sequences_cat = input_ids
221
+ while not should_stop(sequences[-1], inference_params):
222
+ logits = get_logits(sequences[-1], inference_params)
223
+ if output_scores:
224
+ scores.append(logits.clone())
225
+ inference_params.seqlen_offset += sequences[-1].shape[1]
226
+ if repetition_penalty == 1.0:
227
+ sampled_tokens = sample_tokens(logits, inference_params)
228
+ else:
229
+ logits = modify_logit_for_repetition_penalty(
230
+ logits, sequences_cat, repetition_penalty
231
+ )
232
+ sampled_tokens = sample_tokens(logits, inference_params)
233
+ sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
234
+ sequences.append(sampled_tokens)
235
+ if streamer is not None:
236
+ streamer.put(sampled_tokens.cpu())
237
+ if streamer is not None:
238
+ streamer.end()
239
+ if enable_timing:
240
+ end.record()
241
+ torch.cuda.synchronize()
242
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
243
+ output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
244
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
245
+
246
+
247
+ class GenerationMixin:
248
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
249
+ raise NotImplementedError
250
+
251
+ def generate(
252
+ self,
253
+ input_ids,
254
+ max_length,
255
+ top_k=1,
256
+ top_p=0.0,
257
+ min_p=0.0,
258
+ temperature=1.0,
259
+ return_dict_in_generate=False,
260
+ output_scores=False,
261
+ **kwargs,
262
+ ):
263
+ output = decode(
264
+ input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, output_scores=output_scores, **kwargs
265
+ )
266
+ if not output_scores:
267
+ output.scores = None
268
+ return output if return_dict_in_generate else output.sequences
269
+
270
+
271
+ @dataclass
272
+ class DecodingCGCache:
273
+ max_batch_size: int = 0
274
+ max_seqlen: int = 0
275
+ device = None
276
+ dtype = None
277
+ callables: dict = field(default_factory=dict)
278
+ mempool = None
279
+ inference_params: Optional[InferenceParams] = None
280
+ run: Optional[Callable] = None
281
+
282
+
283
+ @torch.inference_mode()
284
+ def update_graph_cache(
285
+ model,
286
+ cache,
287
+ batch_size,
288
+ seqlen_og,
289
+ max_seqlen,
290
+ decoding_seqlens=(1,),
291
+ dtype=None,
292
+ n_warmups=2,
293
+ ):
294
+ if cache is None:
295
+ cache = DecodingCGCache()
296
+ param_example = next(iter(model.parameters()))
297
+ device = param_example.device
298
+ if dtype is None:
299
+ dtype = param_example.dtype
300
+ if (
301
+ (device, dtype) != (cache.device, cache.dtype)
302
+ or batch_size > cache.max_batch_size
303
+ or max_seqlen > cache.max_seqlen
304
+ ): # Invalidate the cache
305
+ cache.callables = {}
306
+ cache.mempool = None
307
+ cache.inference_params = None
308
+ gc.collect()
309
+ cache.device, cache.dtype = device, dtype
310
+ cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
311
+ assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
312
+ inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
313
+ lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
314
+ cache.inference_params = InferenceParams(
315
+ max_seqlen=max_seqlen,
316
+ max_batch_size=batch_size,
317
+ seqlen_offset=seqlen_og,
318
+ key_value_memory_dict=inf_cache,
319
+ lengths_per_sample=lengths_per_sample,
320
+ )
321
+ cache.mempool = torch.cuda.graphs.graph_pool_handle()
322
+ for decoding_seqlen in decoding_seqlens:
323
+ if (batch_size, decoding_seqlen) not in cache.callables:
324
+ cache.callables[batch_size, decoding_seqlen] = capture_graph(
325
+ model,
326
+ cache.inference_params,
327
+ batch_size,
328
+ max_seqlen,
329
+ decoding_seqlen=decoding_seqlen,
330
+ mempool=cache.mempool,
331
+ n_warmups=n_warmups,
332
+ )
333
+
334
+ def dispatch(input_ids, position_ids, seqlen):
335
+ batch_size, decoding_seqlen = input_ids.shape[:2]
336
+ return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
337
+
338
+ cache.run = dispatch
339
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
340
+ return cache
341
+
342
+
343
+ def capture_graph(
344
+ model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
345
+ ):
346
+ device = next(iter(model.parameters())).device
347
+ input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
348
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
349
+ seqlen_offset_og = inference_params.seqlen_offset
350
+ inference_params.seqlen_offset = max_seqlen - decoding_seqlen
351
+ inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
352
+
353
+ # Warmup before capture
354
+ s = torch.cuda.Stream()
355
+ s.wait_stream(torch.cuda.current_stream())
356
+ with torch.cuda.stream(s):
357
+ for _ in range(n_warmups):
358
+ logits = model(
359
+ input_ids,
360
+ position_ids=position_ids,
361
+ inference_params=inference_params,
362
+ num_last_tokens=decoding_seqlen,
363
+ ).logits
364
+ s.synchronize()
365
+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
366
+ # which requires that graph launch and non-captured launch to not overlap (I think,
367
+ # that's how I interpret the documentation). I'm not sure if this is required.
368
+ if torch.distributed.is_initialized():
369
+ torch.distributed.barrier()
370
+ torch.cuda.current_stream().wait_stream(s)
371
+ # Captures the graph
372
+ # To allow capture, automatically sets a side stream as the current stream in the context
373
+ graph = torch.cuda.CUDAGraph()
374
+ with torch.cuda.graph(graph, pool=mempool):
375
+ logits = model(
376
+ input_ids,
377
+ position_ids=position_ids,
378
+ inference_params=inference_params,
379
+ num_last_tokens=decoding_seqlen,
380
+ ).logits
381
+
382
+ def run(new_input_ids, new_position_ids, seqlen):
383
+ inference_params.lengths_per_sample[:] = seqlen
384
+ input_ids.copy_(new_input_ids)
385
+ position_ids.copy_(new_position_ids)
386
+ graph.replay()
387
+ return logits.clone()
388
+
389
+ inference_params.seqlen_offset = seqlen_offset_og
390
+ return run
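
A hedged usage sketch of the generation API defined above; `model` and `tokenizer` are assumptions (a CUDA-resident MambaLMHeadModel and a compatible Hugging Face tokenizer) and are not defined in this commit.

# Hypothetical usage (model/tokenizer are assumptions, not defined in this diff).
import torch

input_ids = tokenizer("Mamba is a", return_tensors="pt").input_ids.to("cuda")
out = model.generate(
    input_ids,
    max_length=64,
    top_k=1,                      # greedy; e.g. top_k=0, top_p=0.9 for nucleus sampling
    temperature=1.0,
    cg=True,                      # capture a CUDA graph for the per-token decode step
    return_dict_in_generate=True,
    output_scores=False,
)
print(tokenizer.batch_decode(out.sequences.cpu()))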
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/hf.py ADDED
@@ -0,0 +1,23 @@
1
+ import json
2
+
3
+ import torch
4
+
5
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6
+ from transformers.utils.hub import cached_file
7
+
8
+
9
+ def load_config_hf(model_name):
10
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11
+ return json.load(open(resolved_archive_file))
12
+
13
+
14
+ def load_state_dict_hf(model_name, device=None, dtype=None):
15
+ # If not fp32, then we don't want to load directly to the GPU
16
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18
+ state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
19
+ # Convert dtype before moving to GPU to save memory
20
+ if dtype is not None:
21
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23
+ return state_dict
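
A minimal usage sketch of the two helpers above; the checkpoint name is only an example, and nothing here is part of the commit itself.

# Hypothetical usage of the helpers above (checkpoint name is just an example).
import torch

config = load_config_hf("state-spaces/mamba-130m")
state_dict = load_state_dict_hf("state-spaces/mamba-130m", device="cpu", dtype=torch.float32)
print(sorted(config)[:5], len(state_dict))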
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/torch.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+ from functools import partial
3
+ from typing import Callable
4
+
5
+ def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
6
+ def decorator(*args, **kwargs):
7
+ if cuda_amp_deprecated:
8
+ kwargs["device_type"] = "cuda"
9
+ return dec(*args, **kwargs)
10
+ return decorator
11
+
12
+
13
+ if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
14
+ deprecated = True
15
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
16
+ else:
17
+ deprecated = False
18
+ from torch.cuda.amp import custom_fwd, custom_bwd
19
+
20
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
21
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
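
A minimal sketch of how the wrapped decorators above are meant to be applied to a custom autograd function (this mirrors their use in tensor_parallel.py further down). It assumes a CUDA device and that custom_fwd/custom_bwd from this module are in scope.

# Sketch only: the wrappers fill in device_type="cuda" automatically on newer PyTorch.
import torch

class Square(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(4, device="cuda", requires_grad=True)
with torch.autocast("cuda", dtype=torch.float16):
    Square.apply(x).sum().backward()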
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ __version__ = "2.2.4"
2
+
3
+ from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4
+ from .modules.mamba_simple import Mamba
5
+ from .modules.mamba2 import Mamba2
6
+ from .models.mixer_seq_simple import MambaLMHeadModel
7
+
8
+ __all__ = [
9
+ "selective_scan_fn",
10
+ "mamba_inner_fn",
11
+ "Mamba",
12
+ "Mamba2",
13
+ "MambaLMHeadModel",
14
+ ]
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_nmrmresto7zfi.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:854cfdd1c899869de1c88a6a56de1494a3d4a0edd1a04412167599485bc1093e
3
+ size 247806288
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _mamba_ssm_nmrmresto7zfi
3
+ ops = torch.ops._mamba_ssm_nmrmresto7zfi
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_mamba_ssm_nmrmresto7zfi::{op_name}"
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/__init__.py ADDED
File without changes
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/distributed_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.distributed import ProcessGroup
6
+
7
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
8
+ # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
9
+ # version of PyTorch. The following 4 lines are for backward compatibility with
10
+ # older PyTorch.
11
+ if "all_gather_into_tensor" not in dir(torch.distributed):
12
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
13
+ if "reduce_scatter_tensor" not in dir(torch.distributed):
14
+ torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
15
+
16
+
17
+ # Raw operation, does not support autograd, but does support async
18
+ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
19
+ world_size = torch.distributed.get_world_size(process_group)
20
+ output = torch.empty(
21
+ world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
22
+ )
23
+ handle = torch.distributed.all_gather_into_tensor(
24
+ output, input_.contiguous(), group=process_group, async_op=async_op
25
+ )
26
+ return output, handle
27
+
28
+
29
+ # Raw operation, does not support autograd, but does support async
30
+ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
31
+ world_size = torch.distributed.get_world_size(process_group)
32
+ assert input_.shape[0] % world_size == 0
33
+ output = torch.empty(
34
+ input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
35
+ )
36
+ handle = torch.distributed.reduce_scatter_tensor(
37
+ output, input_.contiguous(), group=process_group, async_op=async_op
38
+ )
39
+ return output, handle
40
+
41
+
42
+ # Raw operation, does not support autograd, but does support async
43
+ def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
44
+ input_ = input_.contiguous()
45
+ handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
46
+ return input_, handle
47
+
48
+
49
+ class AllGatherFunc(torch.autograd.Function):
50
+ """Gather the input from sequence parallel region and concatenate."""
51
+
52
+ @staticmethod
53
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
54
+ ctx.process_group = process_group
55
+ output, _ = all_gather_raw(input_, process_group)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output: Tensor):
60
+ grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
61
+ return grad_input, None
62
+
63
+
64
+ # Supports autograd, but does not support async
65
+ all_gather = AllGatherFunc.apply
66
+
67
+
68
+ class ReduceScatterFunc(torch.autograd.Function):
69
+ """Reduce scatter the input from the sequence parallel region and concatenate."""
70
+
71
+ @staticmethod
72
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
73
+ ctx.process_group = process_group
74
+ output, _ = reduce_scatter_raw(input_, process_group)
75
+ return output
76
+
77
+ @staticmethod
78
+ def backward(ctx, grad_output: Tensor):
79
+ grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
80
+ return grad_input, None
81
+
82
+
83
+ # Supports autograd, but does not support async
84
+ reduce_scatter = ReduceScatterFunc.apply
85
+
86
+
87
+ class AllReduceFunc(torch.autograd.Function):
88
+ """Gather the input from sequence parallel region and concatenate."""
89
+
90
+ @staticmethod
91
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
92
+ ctx.process_group = process_group
93
+ output, _ = all_reduce_raw(input_, process_group)
94
+ return output
95
+
96
+ @staticmethod
97
+ def backward(ctx, grad_output: Tensor):
98
+ return grad_output, None
99
+
100
+
101
+ # Supports autograd, but does not support async
102
+ all_reduce = AllReduceFunc.apply
103
+
104
+
105
+ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
106
+ # We want to iterate over parameters with _shared_params=True in the same order,
107
+ # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
108
+ params_shared = {
109
+ name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
110
+ }
111
+ for _, p in sorted(params_shared.items()):
112
+ with torch.no_grad():
113
+ # Broadcast needs src to be global rank, not group rank
114
+ torch.distributed.broadcast(
115
+ p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
116
+ )
117
+
118
+
119
+ # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
120
+ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
121
+ # We want to iterate over parameters with _sequence_parallel=True in the same order,
122
+ # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
123
+ params_seqparallel = {
124
+ name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
125
+ }
126
+ grads = [p.grad for _, p in sorted(params_seqparallel.items())]
127
+ if grads:
128
+ with torch.no_grad():
129
+ coalesced = torch._utils._flatten_dense_tensors(grads)
130
+ torch.distributed.all_reduce(coalesced, group=process_group)
131
+ for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
132
+ buf.copy_(synced)
133
+
134
+
135
+ def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
136
+ """Get the dim for the local rank derived from splitting dim on world_size processes.
137
+
138
+ The split may not be even across the world_size processes.
139
+ """
140
+ multiple = dim // multiple_of
141
+ div = multiple // world_size
142
+ mod = multiple % world_size
143
+ local_multiple = div + int(local_rank < mod)
144
+ return local_multiple * multiple_of
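
Since get_dim_for_local_rank is a pure function, its uneven-split behavior is easy to illustrate; a minimal sketch (not part of the diff):

# Sketch: splitting dim=4096 in multiples of 64 across 3 ranks (uneven split).
shards = [get_dim_for_local_rank(4096, world_size=3, local_rank=r, multiple_of=64) for r in range(3)]
print(shards)                 # [1408, 1344, 1344] -- the first rank gets one extra multiple
assert sum(shards) == 4096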
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py ADDED
@@ -0,0 +1,326 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch.distributed import ProcessGroup
10
+ from ..utils.torch import custom_bwd, custom_fwd
11
+
12
+ from einops import rearrange
13
+
14
+ from ..distributed.distributed_utils import (
15
+ all_gather_raw,
16
+ all_reduce,
17
+ all_reduce_raw,
18
+ reduce_scatter,
19
+ reduce_scatter_raw,
20
+ )
21
+
22
+
23
+ class ParallelLinearFunc(torch.autograd.Function):
24
+ @staticmethod
25
+ @custom_fwd
26
+ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
+ """
28
+ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
+ with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
+ """
31
+ ctx.compute_weight_gradient = weight.requires_grad
32
+ ctx.process_group = process_group
33
+ ctx.sequence_parallel = sequence_parallel
34
+
35
+ if torch.is_autocast_enabled():
36
+ x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
+ x = x.contiguous()
38
+ if process_group is not None and sequence_parallel:
39
+ # We want to kick off the all_gather early, before weight dtype conversion
40
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
+ else:
42
+ total_x = x
43
+
44
+ if torch.is_autocast_enabled():
45
+ weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
+ bias = (
47
+ bias.to(dtype=torch.get_autocast_gpu_dtype())
48
+ if bias is not None
49
+ else None
50
+ )
51
+ weight = weight.contiguous()
52
+ if process_group is not None and sequence_parallel:
53
+ handle_x.wait()
54
+ batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
55
+ batch_dim = batch_shape.numel()
56
+ # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
57
+ output = F.linear(total_x, weight, bias)
58
+ if ctx.compute_weight_gradient:
59
+ ctx.save_for_backward(x, weight)
60
+ else:
61
+ ctx.save_for_backward(weight)
62
+ return output
63
+
64
+ @staticmethod
65
+ @custom_bwd
66
+ def backward(ctx, grad_output):
67
+ grad_output = grad_output.contiguous()
68
+ process_group = ctx.process_group
69
+ sequence_parallel = ctx.sequence_parallel
70
+ if ctx.compute_weight_gradient:
71
+ x, weight = ctx.saved_tensors
72
+ if process_group is not None and sequence_parallel:
73
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
74
+ else:
75
+ total_x = x
76
+ else:
77
+ (weight,) = ctx.saved_tensors
78
+ total_x = None
79
+ batch_shape = grad_output.shape[:-1]
80
+ batch_dim = batch_shape.numel()
81
+ grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
82
+ if ctx.needs_input_grad[0]:
83
+ grad_input = F.linear(grad_output, weight.t())
84
+ grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
85
+ if process_group is not None:
86
+ reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
87
+ grad_input, handle_grad_input = reduce_fn(
88
+ grad_input, process_group, async_op=True
89
+ )
90
+ else:
91
+ grad_input = None
92
+ if ctx.needs_input_grad[1]:
93
+ assert ctx.compute_weight_gradient
94
+ if process_group is not None and sequence_parallel:
95
+ handle_x.wait()
96
+ grad_weight = torch.einsum(
97
+ "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
98
+ )
99
+ else:
100
+ grad_weight = None
101
+ grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
102
+ if process_group is not None and ctx.needs_input_grad[0]:
103
+ handle_grad_input.wait()
104
+ return grad_input, grad_weight, grad_bias, None, None
105
+
106
+
107
+ def parallel_linear_func(
108
+ x: Tensor,
109
+ weight: Tensor,
110
+ bias: Optional[Tensor] = None,
111
+ process_group: Optional[ProcessGroup] = None,
112
+ sequence_parallel: bool = True,
113
+ ):
114
+ return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
115
+
116
+
117
+ class ColumnParallelLinear(nn.Linear):
118
+ def __init__(
119
+ self,
120
+ in_features: int,
121
+ out_features: int,
122
+ process_group: ProcessGroup,
123
+ bias: bool = True,
124
+ sequence_parallel=True,
125
+ multiple_of=1,
126
+ device=None,
127
+ dtype=None,
128
+ ) -> None:
129
+ world_size = torch.distributed.get_world_size(process_group)
130
+ if out_features % multiple_of:
131
+ raise ValueError(
132
+ f"out_features ({out_features}) must be a multiple of {multiple_of}"
133
+ )
134
+ multiple = out_features // multiple_of
135
+ # We want to split @multiple across world_size, but it could be an uneven split
136
+ div = multiple // world_size
137
+ mod = multiple % world_size
138
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
139
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
140
+ super().__init__(
141
+ in_features,
142
+ local_multiple * multiple_of,
143
+ bias=bias,
144
+ device=device,
145
+ dtype=dtype,
146
+ )
147
+ self.process_group = process_group
148
+ self.sequence_parallel = sequence_parallel
149
+
150
+ def forward(self, x):
151
+ # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
152
+ # we do an all_gather of x before doing the matmul.
153
+ # If not, then the input is already gathered.
154
+ return parallel_linear_func(
155
+ x,
156
+ self.weight,
157
+ self.bias,
158
+ process_group=self.process_group,
159
+ sequence_parallel=self.sequence_parallel,
160
+ )
161
+
162
+
163
+ class RowParallelLinear(nn.Linear):
164
+ def __init__(
165
+ self,
166
+ in_features: int,
167
+ out_features: int,
168
+ process_group: ProcessGroup,
169
+ bias: bool = True,
170
+ sequence_parallel=True,
171
+ multiple_of=1,
172
+ device=None,
173
+ dtype=None,
174
+ ) -> None:
175
+ world_size = torch.distributed.get_world_size(process_group)
176
+ rank = torch.distributed.get_rank(process_group)
177
+ if in_features % multiple_of:
178
+ raise ValueError(
179
+ f"in_features ({in_features}) must be a multiple of {multiple_of}"
180
+ )
181
+ multiple = in_features // multiple_of
182
+ # We want to split @multiple across world_size, but it could be an uneven split
183
+ div = multiple // world_size
184
+ mod = multiple % world_size
185
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
186
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
187
+ # Only rank 0 will have bias
188
+ super().__init__(
189
+ local_multiple * multiple_of,
190
+ out_features,
191
+ bias=bias and rank == 0,
192
+ device=device,
193
+ dtype=dtype,
194
+ )
195
+ self.process_group = process_group
196
+ self.sequence_parallel = sequence_parallel
197
+
198
+ def forward(self, x):
199
+ """
200
+ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
201
+ a reduce_scatter of the result.
202
+ """
203
+ out = parallel_linear_func(x, self.weight, self.bias)
204
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
205
+ return reduce_fn(out, self.process_group)
206
+
207
+
208
+ class VocabParallelEmbedding(nn.Embedding):
209
+ def __init__(
210
+ self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
211
+ ):
212
+ self.process_group = process_group
213
+ if process_group is not None:
214
+ world_size = torch.distributed.get_world_size(process_group)
215
+ if num_embeddings % world_size != 0:
216
+ raise ValueError(
217
+ f"num_embeddings ({num_embeddings}) must be divisible by "
218
+ f"world_size ({world_size})"
219
+ )
220
+ if world_size > 1 and padding_idx is not None:
221
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
222
+ else:
223
+ world_size = 1
224
+ super().__init__(
225
+ num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
226
+ )
227
+
228
+ def forward(self, input: Tensor) -> Tensor:
229
+ if self.process_group is None:
230
+ return super().forward(input)
231
+ else:
232
+ rank = torch.distributed.get_rank(self.process_group)
233
+ vocab_size = self.num_embeddings
234
+ vocab_start_index, vocab_end_index = (
235
+ rank * vocab_size,
236
+ (rank + 1) * vocab_size,
237
+ )
238
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
239
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
240
+ input = input - vocab_start_index
241
+ input[input_ids_mask] = 0
242
+ embeddings = super().forward(input)
243
+ embeddings[input_ids_mask] = 0.0
244
+ return embeddings
245
+
246
+
247
+ class ColumnParallelEmbedding(nn.Embedding):
248
+ def __init__(
249
+ self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
250
+ ):
251
+ self.process_group = process_group
252
+ if process_group is not None:
253
+ world_size = torch.distributed.get_world_size(process_group)
254
+ if embedding_dim % world_size != 0:
255
+ raise ValueError(
256
+ f"embedding_dim ({embedding_dim}) must be divisible by "
257
+ f"world_size ({world_size})"
258
+ )
259
+ else:
260
+ world_size = 1
261
+ super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
262
+
263
+
264
+ class ParallelEmbeddings(nn.Module):
265
+ def __init__(
266
+ self,
267
+ embed_dim,
268
+ vocab_size,
269
+ max_position_embeddings,
270
+ process_group,
271
+ padding_idx=None,
272
+ sequence_parallel=True,
273
+ device=None,
274
+ dtype=None,
275
+ ):
276
+ """
277
+ If max_position_embeddings <= 0, there's no position embeddings
278
+ """
279
+ factory_kwargs = {"device": device, "dtype": dtype}
280
+ super().__init__()
281
+ self.process_group = process_group
282
+ self.sequence_parallel = sequence_parallel
283
+ self.word_embeddings = VocabParallelEmbedding(
284
+ vocab_size,
285
+ embed_dim,
286
+ padding_idx=padding_idx,
287
+ process_group=process_group,
288
+ **factory_kwargs,
289
+ )
290
+ self.max_position_embeddings = max_position_embeddings
291
+ if self.max_position_embeddings > 0:
292
+ self.position_embeddings = ColumnParallelEmbedding(
293
+ max_position_embeddings,
294
+ embed_dim,
295
+ process_group=process_group,
296
+ **factory_kwargs,
297
+ )
298
+
299
+ def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
300
+ """
301
+ input_ids: (batch, seqlen)
302
+ position_ids: (batch, seqlen)
303
+ """
304
+ batch_size, seqlen = input_ids.shape
305
+ world_size = torch.distributed.get_world_size(self.process_group)
306
+ embeddings = self.word_embeddings(input_ids)
307
+ if self.max_position_embeddings > 0:
308
+ if position_ids is None:
309
+ position_ids = torch.arange(
310
+ seqlen, dtype=torch.long, device=input_ids.device
311
+ )
312
+ position_embeddings = self.position_embeddings(position_ids)
313
+ if world_size <= 1:
314
+ embeddings = embeddings + position_embeddings
315
+ else:
316
+ partition_dim = self.position_embeddings.embedding_dim
317
+ rank = torch.distributed.get_rank(self.process_group)
318
+ embeddings[
319
+ ..., rank * partition_dim : (rank + 1) * partition_dim
320
+ ] += position_embeddings
321
+ if combine_batch_seqlen_dim:
322
+ embeddings = rearrange(embeddings, "b s d -> (b s) d")
323
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
324
+ return (
325
+ embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
326
+ )
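
A hypothetical single-process sketch of constructing the layers above; with world_size=1 the column split degenerates to a plain nn.Linear, which makes the sharded weight shape easy to inspect. The gloo backend, address, and port are illustrative assumptions, not values taken from this commit.

# Hypothetical single-process sketch (gloo backend, world_size=1).
import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

layer = ColumnParallelLinear(512, 1024, process_group=dist.group.WORLD, sequence_parallel=False)
print(layer.weight.shape)     # torch.Size([1024, 512]) -- no sharding when world_size == 1
dist.destroy_process_group()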
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/__init__.py ADDED
File without changes
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/config_mamba.py ADDED
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class MambaConfig:
6
+
7
+ d_model: int = 2560
8
+ d_intermediate: int = 0
9
+ n_layer: int = 64
10
+ vocab_size: int = 50277
11
+ ssm_cfg: dict = field(default_factory=dict)
12
+ attn_layer_idx: list = field(default_factory=list)
13
+ attn_cfg: dict = field(default_factory=dict)
14
+ rms_norm: bool = True
15
+ residual_in_fp32: bool = True
16
+ fused_add_norm: bool = True
17
+ pad_vocab_size_multiple: int = 8
18
+ tie_embeddings: bool = True
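
A minimal sketch of instantiating this config; the field values are illustrative and not tied to any released checkpoint. The "layer": "Mamba2" entry in ssm_cfg is the switch consumed by create_block in mixer_seq_simple.py below.

# Illustrative config for a small Mamba2-style model (values are examples only).
config = MambaConfig(
    d_model=768,
    n_layer=24,
    vocab_size=50277,
    ssm_cfg={"layer": "Mamba2"},
    rms_norm=True,
    fused_add_norm=True,
    tie_embeddings=True,
)
print(config)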
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,338 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ import json
6
+ import os
7
+ import copy
8
+
9
+ from collections import namedtuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .config_mamba import MambaConfig
15
+ from ..modules.mamba_simple import Mamba
16
+ from ..modules.mamba2 import Mamba2
17
+ from ..modules.mha import MHA
18
+ from ..modules.mlp import GatedMLP
19
+ from ..modules.block import Block
20
+ from ..utils.generation import GenerationMixin
21
+ from ..utils.hf import load_config_hf, load_state_dict_hf
22
+
23
+ try:
24
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
+ except ImportError:
26
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
+
28
+
29
+ def create_block(
30
+ d_model,
31
+ d_intermediate,
32
+ ssm_cfg=None,
33
+ attn_layer_idx=None,
34
+ attn_cfg=None,
35
+ norm_epsilon=1e-5,
36
+ rms_norm=False,
37
+ residual_in_fp32=False,
38
+ fused_add_norm=False,
39
+ layer_idx=None,
40
+ device=None,
41
+ dtype=None,
42
+ ):
43
+ if ssm_cfg is None:
44
+ ssm_cfg = {}
45
+ if attn_layer_idx is None:
46
+ attn_layer_idx = []
47
+ if attn_cfg is None:
48
+ attn_cfg = {}
49
+ factory_kwargs = {"device": device, "dtype": dtype}
50
+ if layer_idx not in attn_layer_idx:
51
+ # Create a copy of the config to modify
52
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
55
+ raise ValueError(
56
+ f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
57
+ )
58
+ mixer_cls = partial(
59
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
60
+ layer_idx=layer_idx,
61
+ **ssm_cfg,
62
+ **factory_kwargs,
63
+ )
64
+ else:
65
+ mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
66
+ norm_cls = partial(
67
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
68
+ )
69
+ if d_intermediate == 0:
70
+ mlp_cls = nn.Identity
71
+ else:
72
+ mlp_cls = partial(
73
+ GatedMLP,
74
+ hidden_features=d_intermediate,
75
+ out_features=d_model,
76
+ **factory_kwargs,
77
+ )
78
+ block = Block(
79
+ d_model,
80
+ mixer_cls,
81
+ mlp_cls,
82
+ norm_cls=norm_cls,
83
+ fused_add_norm=fused_add_norm,
84
+ residual_in_fp32=residual_in_fp32,
85
+ )
86
+ block.layer_idx = layer_idx
87
+ return block
88
+
89
+
90
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
91
+ def _init_weights(
92
+ module,
93
+ n_layer,
94
+ initializer_range=0.02, # Now only used for embedding layer.
95
+ rescale_prenorm_residual=True,
96
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
97
+ ):
98
+ if isinstance(module, nn.Linear):
99
+ if module.bias is not None:
100
+ if not getattr(module.bias, "_no_reinit", False):
101
+ nn.init.zeros_(module.bias)
102
+ elif isinstance(module, nn.Embedding):
103
+ nn.init.normal_(module.weight, std=initializer_range)
104
+
105
+ if rescale_prenorm_residual:
106
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
107
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
108
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
109
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
110
+ #
111
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
112
+ for name, p in module.named_parameters():
113
+ if name in ["out_proj.weight", "fc2.weight"]:
114
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
115
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
116
+ # We need to reinit p since this code could be called multiple times
117
+ # Having just p *= scale would repeatedly scale it down
118
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
119
+ with torch.no_grad():
120
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
121
+
122
+
123
+ class MixerModel(nn.Module):
124
+ def __init__(
125
+ self,
126
+ d_model: int,
127
+ n_layer: int,
128
+ d_intermediate: int,
129
+ vocab_size: int,
130
+ ssm_cfg=None,
131
+ attn_layer_idx=None,
132
+ attn_cfg=None,
133
+ norm_epsilon: float = 1e-5,
134
+ rms_norm: bool = False,
135
+ initializer_cfg=None,
136
+ fused_add_norm=False,
137
+ residual_in_fp32=False,
138
+ device=None,
139
+ dtype=None,
140
+ ) -> None:
141
+ factory_kwargs = {"device": device, "dtype": dtype}
142
+ super().__init__()
143
+ self.residual_in_fp32 = residual_in_fp32
144
+
145
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
146
+
147
+ # We change the order of residual and layer norm:
148
+ # Instead of LN -> Attn / MLP -> Add, we do:
149
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
150
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
151
+ # This is for performance reasons: we can fuse add + layer_norm.
152
+ self.fused_add_norm = fused_add_norm
153
+ if self.fused_add_norm:
154
+ if layer_norm_fn is None or rms_norm_fn is None:
155
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
156
+
157
+ self.layers = nn.ModuleList(
158
+ [
159
+ create_block(
160
+ d_model,
161
+ d_intermediate=d_intermediate,
162
+ ssm_cfg=ssm_cfg,
163
+ attn_layer_idx=attn_layer_idx,
164
+ attn_cfg=attn_cfg,
165
+ norm_epsilon=norm_epsilon,
166
+ rms_norm=rms_norm,
167
+ residual_in_fp32=residual_in_fp32,
168
+ fused_add_norm=fused_add_norm,
169
+ layer_idx=i,
170
+ **factory_kwargs,
171
+ )
172
+ for i in range(n_layer)
173
+ ]
174
+ )
175
+
176
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
177
+ d_model, eps=norm_epsilon, **factory_kwargs
178
+ )
179
+
180
+ self.apply(
181
+ partial(
182
+ _init_weights,
183
+ n_layer=n_layer,
184
+ **(initializer_cfg if initializer_cfg is not None else {}),
185
+ n_residuals_per_layer=(
186
+ 1 if d_intermediate == 0 else 2
187
+ ), # 2 if we have MLP
188
+ )
189
+ )
190
+
191
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
192
+ return {
193
+ i: layer.allocate_inference_cache(
194
+ batch_size, max_seqlen, dtype=dtype, **kwargs
195
+ )
196
+ for i, layer in enumerate(self.layers)
197
+ }
198
+
199
+ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
200
+ hidden_states = self.embedding(input_ids)
201
+ residual = None
202
+ for layer in self.layers:
203
+ hidden_states, residual = layer(
204
+ hidden_states,
205
+ residual,
206
+ inference_params=inference_params,
207
+ **mixer_kwargs,
208
+ )
209
+ if not self.fused_add_norm:
210
+ residual = (
211
+ (hidden_states + residual) if residual is not None else hidden_states
212
+ )
213
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
214
+ else:
215
+ # Set prenorm=False here since we don't need the residual
216
+ hidden_states = layer_norm_fn(
217
+ hidden_states,
218
+ self.norm_f.weight,
219
+ self.norm_f.bias,
220
+ eps=self.norm_f.eps,
221
+ residual=residual,
222
+ prenorm=False,
223
+ residual_in_fp32=self.residual_in_fp32,
224
+ is_rms_norm=isinstance(self.norm_f, RMSNorm),
225
+ )
226
+ return hidden_states
227
+
228
+
229
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
230
+
231
+ def __init__(
232
+ self,
233
+ config: MambaConfig,
234
+ initializer_cfg=None,
235
+ device=None,
236
+ dtype=None,
237
+ ) -> None:
238
+ self.config = config
239
+ d_model = config.d_model
240
+ n_layer = config.n_layer
241
+ d_intermediate = config.d_intermediate
242
+ vocab_size = config.vocab_size
243
+ ssm_cfg = config.ssm_cfg
244
+ attn_layer_idx = config.attn_layer_idx
245
+ attn_cfg = config.attn_cfg
246
+ rms_norm = config.rms_norm
247
+ residual_in_fp32 = config.residual_in_fp32
248
+ fused_add_norm = config.fused_add_norm
249
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
250
+ factory_kwargs = {"device": device, "dtype": dtype}
251
+
252
+ super().__init__()
253
+ if vocab_size % pad_vocab_size_multiple != 0:
254
+ vocab_size += pad_vocab_size_multiple - (
255
+ vocab_size % pad_vocab_size_multiple
256
+ )
257
+ self.backbone = MixerModel(
258
+ d_model=d_model,
259
+ n_layer=n_layer,
260
+ d_intermediate=d_intermediate,
261
+ vocab_size=vocab_size,
262
+ ssm_cfg=ssm_cfg,
263
+ attn_layer_idx=attn_layer_idx,
264
+ attn_cfg=attn_cfg,
265
+ rms_norm=rms_norm,
266
+ initializer_cfg=initializer_cfg,
267
+ fused_add_norm=fused_add_norm,
268
+ residual_in_fp32=residual_in_fp32,
269
+ **factory_kwargs,
270
+ )
271
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
272
+
273
+ # Initialize weights and apply final processing
274
+ self.apply(
275
+ partial(
276
+ _init_weights,
277
+ n_layer=n_layer,
278
+ **(initializer_cfg if initializer_cfg is not None else {}),
279
+ )
280
+ )
281
+ self.tie_weights()
282
+
283
+ def tie_weights(self):
284
+ if self.config.tie_embeddings:
285
+ self.lm_head.weight = self.backbone.embedding.weight
286
+
287
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
288
+ return self.backbone.allocate_inference_cache(
289
+ batch_size, max_seqlen, dtype=dtype, **kwargs
290
+ )
291
+
292
+ def forward(
293
+ self,
294
+ input_ids,
295
+ position_ids=None,
296
+ inference_params=None,
297
+ num_last_tokens=0,
298
+ **mixer_kwargs,
299
+ ):
300
+ """
301
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
302
+ num_last_tokens: if > 0, only return the logits for the last n tokens
303
+ """
304
+ hidden_states = self.backbone(
305
+ input_ids, inference_params=inference_params, **mixer_kwargs
306
+ )
307
+ if num_last_tokens > 0:
308
+ hidden_states = hidden_states[:, -num_last_tokens:]
309
+ lm_logits = self.lm_head(hidden_states)
310
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
311
+ return CausalLMOutput(logits=lm_logits)
312
+
313
+ @classmethod
314
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
315
+ config_data = load_config_hf(pretrained_model_name)
316
+ config = MambaConfig(**config_data)
317
+ model = cls(config, device=device, dtype=dtype, **kwargs)
318
+ model.load_state_dict(
319
+ load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
320
+ )
321
+ return model
322
+
323
+ def save_pretrained(self, save_directory):
324
+ """
325
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
326
+ Save the model and its configuration file to a directory.
327
+ """
328
+ # Ensure save_directory exists
329
+ os.makedirs(save_directory, exist_ok=True)
330
+
331
+ # Save the model's state_dict
332
+ model_path = os.path.join(save_directory, "pytorch_model.bin")
333
+ torch.save(self.state_dict(), model_path)
334
+
335
+ # Save the configuration of the model
336
+ config_path = os.path.join(save_directory, "config.json")
337
+ with open(config_path, "w") as f:
338
+ json.dump(self.config.__dict__, f, indent=4)
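A minimal usage sketch for the MambaLMHeadModel defined above (not part of the diff): it assumes a CUDA device with the compiled kernels available, import paths following this build's file layout, and an illustrative checkpoint id that is not guaranteed to match this config schema.

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Load a hypothetical HF checkpoint and run one forward pass.
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-130m", device="cuda", dtype=torch.float16
)
input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")
out = model(input_ids, num_last_tokens=1)  # CausalLMOutput namedtuple
print(out.logits.shape)  # (1, 1, padded vocab size)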
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/__init__.py ADDED
File without changes
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/block.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn
8
+
9
+
10
+ class Block(nn.Module):
11
+ def __init__(
12
+ self,
13
+ dim,
14
+ mixer_cls,
15
+ mlp_cls,
16
+ norm_cls=nn.LayerNorm,
17
+ fused_add_norm=False,
18
+ residual_in_fp32=False,
19
+ ):
20
+ """
21
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
22
+
23
+ This Block has a slightly different structure compared to a regular
24
+ prenorm Transformer block.
25
+ The standard block is: LN -> MHA/MLP -> Add.
26
+ [Ref: https://arxiv.org/abs/2002.04745]
27
+ Here we have: Add -> LN -> Mixer, returning both
28
+ the hidden_states (output of the mixer) and the residual.
29
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
30
+ The residual needs to be provided (except for the very first block).
31
+ """
32
+ super().__init__()
33
+ self.residual_in_fp32 = residual_in_fp32
34
+ self.fused_add_norm = fused_add_norm
35
+ self.norm = norm_cls(dim)
36
+ self.mixer = mixer_cls(dim)
37
+ if mlp_cls is not nn.Identity:
38
+ self.norm2 = norm_cls(dim)
39
+ self.mlp = mlp_cls(dim)
40
+ else:
41
+ self.mlp = None
42
+ if self.fused_add_norm:
43
+ assert RMSNorm is not None, "RMSNorm import fails"
44
+ assert isinstance(
45
+ self.norm, (nn.LayerNorm, RMSNorm)
46
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
47
+
48
+ def forward(
49
+ self,
50
+ hidden_states: Tensor,
51
+ residual: Optional[Tensor] = None,
52
+ inference_params=None,
53
+ **mixer_kwargs
54
+ ):
55
+ r"""Pass the input through the encoder layer.
56
+
57
+ Args:
58
+ hidden_states: the sequence to the encoder layer (required).
59
+ residual: hidden_states = Mixer(LN(residual))
60
+ """
61
+ if not self.fused_add_norm:
62
+ residual = (
63
+ (hidden_states + residual) if residual is not None else hidden_states
64
+ )
65
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
66
+ if self.residual_in_fp32:
67
+ residual = residual.to(torch.float32)
68
+ else:
69
+ hidden_states, residual = layer_norm_fn(
70
+ hidden_states,
71
+ self.norm.weight,
72
+ self.norm.bias,
73
+ residual=residual,
74
+ prenorm=True,
75
+ residual_in_fp32=self.residual_in_fp32,
76
+ eps=self.norm.eps,
77
+ is_rms_norm=isinstance(self.norm, RMSNorm),
78
+ )
79
+ hidden_states = self.mixer(
80
+ hidden_states, inference_params=inference_params, **mixer_kwargs
81
+ )
82
+
83
+ if self.mlp is not None:
84
+ if not self.fused_add_norm:
85
+ residual = hidden_states + residual
86
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
87
+ if self.residual_in_fp32:
88
+ residual = residual.to(torch.float32)
89
+ else:
90
+ hidden_states, residual = layer_norm_fn(
91
+ hidden_states,
92
+ self.norm2.weight,
93
+ self.norm2.bias,
94
+ residual=residual,
95
+ prenorm=True,
96
+ residual_in_fp32=self.residual_in_fp32,
97
+ eps=self.norm2.eps,
98
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
99
+ )
100
+ hidden_states = self.mlp(hidden_states)
101
+
102
+ return hidden_states, residual
103
+
104
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
105
+ return self.mixer.allocate_inference_cache(
106
+ batch_size, max_seqlen, dtype=dtype, **kwargs
107
+ )
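To make the Add -> LN -> Mixer bookkeeping in Block concrete, here is a tensor-level sketch of the unfused path using identity mixers and a plain LayerNorm; it is purely illustrative and not this module's API.

import torch
import torch.nn as nn

d = 16
norm = nn.LayerNorm(d)
mixer = lambda h: h  # stand-in for Mamba / MHA
hidden_states, residual = torch.randn(2, 4, d), None
for _ in range(3):  # three stacked "blocks"
    # Add -> LN -> Mixer, carrying both branches forward
    residual = hidden_states if residual is None else hidden_states + residual
    hidden_states = mixer(norm(residual))
out = norm(hidden_states + residual)  # final norm over the summed branches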
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba2.py ADDED
@@ -0,0 +1,502 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
+ except ImportError:
14
+ causal_conv1d_fn, causal_conv1d_update = None, None
15
+
16
+ try:
17
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
+ except ImportError:
19
+ causal_conv1d_varlen_states = None
20
+
21
+ try:
22
+ from ..ops.triton.selective_state_update import selective_state_update
23
+ except ImportError:
24
+ selective_state_update = None
25
+
26
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
27
+
28
+ from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
+ from ..distributed.distributed_utils import all_reduce, reduce_scatter
30
+
31
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
32
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
+
34
+ from huggingface_hub import PyTorchModelHubMixin
35
+
36
+
37
+ class Mamba2(nn.Module, PyTorchModelHubMixin):
38
+ def __init__(
39
+ self,
40
+ d_model,
41
+ d_state=128,
42
+ d_conv=4,
43
+ conv_init=None,
44
+ expand=2,
45
+ headdim=64,
46
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
47
+ ngroups=1,
48
+ A_init_range=(1, 16),
49
+ D_has_hdim=False,
50
+ rmsnorm=True,
51
+ norm_before_gate=False,
52
+ dt_min=0.001,
53
+ dt_max=0.1,
54
+ dt_init_floor=1e-4,
55
+ dt_limit=(0.0, float("inf")),
56
+ bias=False,
57
+ conv_bias=True,
58
+ # Fused kernel and sharding options
59
+ chunk_size=256,
60
+ use_mem_eff_path=True,
61
+ layer_idx=None, # Absorb kwarg for general module
62
+ process_group=None,
63
+ sequence_parallel=True,
64
+ device=None,
65
+ dtype=None,
66
+ ):
67
+ factory_kwargs = {"device": device, "dtype": dtype}
68
+ super().__init__()
69
+ self.d_model = d_model
70
+ self.d_state = d_state
71
+ self.d_conv = d_conv
72
+ self.conv_init = conv_init
73
+ self.expand = expand
74
+ self.process_group = process_group
75
+ self.sequence_parallel = sequence_parallel
76
+ self.world_size = 1 if process_group is None else process_group.size()
77
+ self.local_rank = 0 if process_group is None else process_group.rank()
78
+ self.d_inner = (self.expand * self.d_model) // self.world_size
79
+ assert self.d_inner * self.world_size == self.expand * self.d_model
80
+ self.headdim = headdim
81
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
82
+ assert ngroups % self.world_size == 0
83
+ self.ngroups = ngroups // self.world_size
84
+ assert self.d_ssm % self.headdim == 0
85
+ self.nheads = self.d_ssm // self.headdim
86
+ self.D_has_hdim = D_has_hdim
87
+ self.rmsnorm = rmsnorm
88
+ self.norm_before_gate = norm_before_gate
89
+ self.dt_limit = dt_limit
90
+ self.activation = "silu"
91
+ self.chunk_size = chunk_size
92
+ self.use_mem_eff_path = use_mem_eff_path
93
+ self.layer_idx = layer_idx
94
+
95
+ # Order: [z, x, B, C, dt]
96
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
97
+ if self.process_group is None:
98
+ self.in_proj = nn.Linear(
99
+ self.d_model, d_in_proj, bias=bias, **factory_kwargs
100
+ )
101
+ else:
102
+ self.in_proj = ColumnParallelLinear(
103
+ self.d_model,
104
+ d_in_proj * self.world_size,
105
+ bias=bias,
106
+ process_group=self.process_group,
107
+ sequence_parallel=self.sequence_parallel,
108
+ **factory_kwargs,
109
+ )
110
+
111
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
112
+ self.conv1d = nn.Conv1d(
113
+ in_channels=conv_dim,
114
+ out_channels=conv_dim,
115
+ bias=conv_bias,
116
+ kernel_size=d_conv,
117
+ groups=conv_dim,
118
+ padding=d_conv - 1,
119
+ **factory_kwargs,
120
+ )
121
+ if self.conv_init is not None:
122
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
123
+
124
+ self.act = nn.SiLU()
125
+
126
+ # Initialize log dt bias
127
+ dt = torch.exp(
128
+ torch.rand(self.nheads, **factory_kwargs)
129
+ * (math.log(dt_max) - math.log(dt_min))
130
+ + math.log(dt_min)
131
+ )
132
+ dt = torch.clamp(dt, min=dt_init_floor)
133
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
134
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
135
+ self.dt_bias = nn.Parameter(inv_dt)
136
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
137
+ # name.endswith("bias") in param_grouping.py
138
+ self.dt_bias._no_weight_decay = True
139
+
140
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
141
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
142
+ *A_init_range
143
+ )
144
+ A_log = torch.log(A).to(dtype=dtype)
145
+ self.A_log = nn.Parameter(A_log)
146
+ self.A_log._no_weight_decay = True
147
+
148
+ # D "skip" parameter
149
+ self.D = nn.Parameter(
150
+ torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
151
+ )
152
+ self.D._no_weight_decay = True
153
+
154
+ if self.rmsnorm:
155
+ assert RMSNormGated is not None
156
+ self.norm = RMSNormGated(
157
+ self.d_ssm,
158
+ eps=1e-5,
159
+ norm_before_gate=self.norm_before_gate,
160
+ group_size=self.d_ssm // ngroups,
161
+ **factory_kwargs,
162
+ )
163
+
164
+ if self.process_group is None:
165
+ self.out_proj = nn.Linear(
166
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
167
+ )
168
+ else:
169
+ self.out_proj = RowParallelLinear(
170
+ self.d_inner * self.world_size,
171
+ self.d_model,
172
+ bias=bias,
173
+ process_group=self.process_group,
174
+ sequence_parallel=self.sequence_parallel,
175
+ **factory_kwargs,
176
+ )
177
+
178
+ def forward(
179
+ self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
180
+ ):
181
+ """
182
+ u: (batch, seqlen, hidden_dim) if seqlen=None.
183
+ If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
184
+ split u during sequence parallel, we split the batch * seqlen dimension
185
+ (in case batch is small).
186
+ Returns: same shape as u
187
+ """
188
+ seqlen_og = seqlen
189
+ if seqlen is None:
190
+ batch, seqlen, dim = u.shape
191
+ else:
192
+ batch_seqlen, dim = u.shape
193
+ batch = batch_seqlen // seqlen
194
+
195
+ conv_state, ssm_state = None, None
196
+ if inference_params is not None:
197
+ inference_batch = (
198
+ cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
199
+ )
200
+ conv_state, ssm_state = self._get_states_from_cache(
201
+ inference_params, inference_batch
202
+ )
203
+ if inference_params.seqlen_offset > 0:
204
+ # The states are updated inplace
205
+ out, _, _ = self.step(u, conv_state, ssm_state)
206
+ return out
207
+
208
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
209
+ if seqlen_og is not None:
210
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
211
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
212
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
213
+ dt_limit_kwargs = (
214
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
215
+ )
216
+ if self.use_mem_eff_path and inference_params is None:
217
+ out = mamba_split_conv1d_scan_combined(
218
+ zxbcdt,
219
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
220
+ self.conv1d.bias,
221
+ self.dt_bias,
222
+ A,
223
+ D=(
224
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
225
+ if self.D_has_hdim
226
+ else self.D
227
+ ),
228
+ chunk_size=self.chunk_size,
229
+ seq_idx=seq_idx,
230
+ activation=self.activation,
231
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
232
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
233
+ outproj_weight=self.out_proj.weight,
234
+ outproj_bias=self.out_proj.bias,
235
+ headdim=None if self.D_has_hdim else self.headdim,
236
+ ngroups=self.ngroups,
237
+ norm_before_gate=self.norm_before_gate,
238
+ **dt_limit_kwargs,
239
+ )
240
+ if seqlen_og is not None:
241
+ out = rearrange(out, "b l d -> (b l) d")
242
+ if self.process_group is not None:
243
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
244
+ out = reduce_fn(out, self.process_group)
245
+ else:
246
+ d_mlp = (
247
+ zxbcdt.shape[-1]
248
+ - 2 * self.d_ssm
249
+ - 2 * self.ngroups * self.d_state
250
+ - self.nheads
251
+ ) // 2
252
+ z0, x0, z, xBC, dt = torch.split(
253
+ zxbcdt,
254
+ [
255
+ d_mlp,
256
+ d_mlp,
257
+ self.d_ssm,
258
+ self.d_ssm + 2 * self.ngroups * self.d_state,
259
+ self.nheads,
260
+ ],
261
+ dim=-1,
262
+ )
263
+ if conv_state is not None:
264
+ if cu_seqlens is None:
265
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
266
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
267
+ xBC_t = rearrange(xBC, "b l d -> b d l")
268
+ conv_state.copy_(
269
+ F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
270
+ ) # Update state (B D W)
271
+ else:
272
+ assert (
273
+ causal_conv1d_varlen_states is not None
274
+ ), "varlen inference requires causal_conv1d package"
275
+ assert (
276
+ batch == 1
277
+ ), "varlen inference only supports batch dimension 1"
278
+ conv_varlen_states = causal_conv1d_varlen_states(
279
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
280
+ )
281
+ conv_state.copy_(conv_varlen_states)
282
+ assert self.activation in ["silu", "swish"]
283
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
284
+ assert (
285
+ seq_idx is None
286
+ ), "varlen conv1d requires the causal_conv1d package"
287
+ xBC = self.act(
288
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
289
+ :, : -(self.d_conv - 1)
290
+ ]
291
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
292
+ else:
293
+ xBC = causal_conv1d_fn(
294
+ xBC.transpose(1, 2),
295
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
296
+ bias=self.conv1d.bias,
297
+ activation=self.activation,
298
+ seq_idx=seq_idx,
299
+ ).transpose(1, 2)
300
+ x, B, C = torch.split(
301
+ xBC,
302
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
303
+ dim=-1,
304
+ )
305
+ y = mamba_chunk_scan_combined(
306
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
307
+ dt,
308
+ A,
309
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
310
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
311
+ chunk_size=self.chunk_size,
312
+ D=(
313
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
314
+ if self.D_has_hdim
315
+ else self.D
316
+ ),
317
+ z=(
318
+ rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
319
+ if not self.rmsnorm
320
+ else None
321
+ ),
322
+ dt_bias=self.dt_bias,
323
+ dt_softplus=True,
324
+ seq_idx=seq_idx,
325
+ cu_seqlens=cu_seqlens,
326
+ **dt_limit_kwargs,
327
+ return_final_states=ssm_state is not None,
328
+ return_varlen_states=cu_seqlens is not None
329
+ and inference_params is not None,
330
+ )
331
+ if ssm_state is not None:
332
+ y, last_state, *rest = y
333
+ if cu_seqlens is None:
334
+ ssm_state.copy_(last_state)
335
+ else:
336
+ varlen_states = rest[0]
337
+ ssm_state.copy_(varlen_states)
338
+ y = rearrange(y, "b l h p -> b l (h p)")
339
+ if self.rmsnorm:
340
+ y = self.norm(y, z)
341
+ if d_mlp > 0:
342
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
343
+ if seqlen_og is not None:
344
+ y = rearrange(y, "b l d -> (b l) d")
345
+ out = self.out_proj(y)
346
+ return out
347
+
348
+ def step(self, hidden_states, conv_state, ssm_state):
349
+ dtype = hidden_states.dtype
350
+ assert (
351
+ hidden_states.shape[1] == 1
352
+ ), "Only support decoding with 1 token at a time for now"
353
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
354
+ d_mlp = (
355
+ zxbcdt.shape[-1]
356
+ - 2 * self.d_ssm
357
+ - 2 * self.ngroups * self.d_state
358
+ - self.nheads
359
+ ) // 2
360
+ z0, x0, z, xBC, dt = torch.split(
361
+ zxbcdt,
362
+ [
363
+ d_mlp,
364
+ d_mlp,
365
+ self.d_ssm,
366
+ self.d_ssm + 2 * self.ngroups * self.d_state,
367
+ self.nheads,
368
+ ],
369
+ dim=-1,
370
+ )
371
+
372
+ # Conv step
373
+ if causal_conv1d_update is None:
374
+ conv_state.copy_(
375
+ torch.roll(conv_state, shifts=-1, dims=-1)
376
+ ) # Update state (B D W)
377
+ conv_state[:, :, -1] = xBC
378
+ xBC = torch.sum(
379
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
380
+ ) # (B D)
381
+ if self.conv1d.bias is not None:
382
+ xBC = xBC + self.conv1d.bias
383
+ xBC = self.act(xBC).to(dtype=dtype)
384
+ else:
385
+ xBC = causal_conv1d_update(
386
+ xBC,
387
+ conv_state,
388
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
389
+ self.conv1d.bias,
390
+ self.activation,
391
+ )
392
+
393
+ x, B, C = torch.split(
394
+ xBC,
395
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
396
+ dim=-1,
397
+ )
398
+ A = -torch.exp(self.A_log.float()) # (nheads,)
399
+
400
+ # SSM step
401
+ if selective_state_update is None:
402
+ assert (
403
+ self.ngroups == 1
404
+ ), "Only support ngroups=1 for this inference code path"
405
+ # Discretize A and B
406
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
407
+ dA = torch.exp(dt * A) # (batch, nheads)
408
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
409
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
410
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
411
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
412
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
413
+ y = rearrange(y, "b h p -> b (h p)")
414
+ if not self.rmsnorm:
415
+ y = y * self.act(z) # (B D)
416
+ else:
417
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
418
+ dtype=torch.float32
419
+ )
420
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
421
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
422
+ D = repeat(self.D, "h -> h p", p=self.headdim)
423
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
424
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
425
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
426
+ if not self.rmsnorm:
427
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
428
+ y = selective_state_update(
429
+ ssm_state,
430
+ x_reshaped,
431
+ dt,
432
+ A,
433
+ B,
434
+ C,
435
+ D,
436
+ z=z if not self.rmsnorm else None,
437
+ dt_bias=dt_bias,
438
+ dt_softplus=True,
439
+ )
440
+ y = rearrange(y, "b h p -> b (h p)")
441
+ if self.rmsnorm:
442
+ y = self.norm(y, z)
443
+ if d_mlp > 0:
444
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
445
+ out = self.out_proj(y)
446
+ return out.unsqueeze(1), conv_state, ssm_state
447
+
448
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
449
+ device = self.out_proj.weight.device
450
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
451
+ conv_state = torch.zeros(
452
+ batch_size,
453
+ self.d_conv,
454
+ self.conv1d.weight.shape[0],
455
+ device=device,
456
+ dtype=conv_dtype,
457
+ ).transpose(1, 2)
458
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
459
+ ssm_state = torch.zeros(
460
+ batch_size,
461
+ self.nheads,
462
+ self.headdim,
463
+ self.d_state,
464
+ device=device,
465
+ dtype=ssm_dtype,
466
+ )
467
+ return conv_state, ssm_state
468
+
469
+ def _get_states_from_cache(
470
+ self, inference_params, batch_size, initialize_states=False
471
+ ):
472
+ assert self.layer_idx is not None
473
+ if self.layer_idx not in inference_params.key_value_memory_dict:
474
+ batch_shape = (batch_size,)
475
+ conv_state = torch.zeros(
476
+ batch_size,
477
+ self.d_conv,
478
+ self.conv1d.weight.shape[0],
479
+ device=self.conv1d.weight.device,
480
+ dtype=self.conv1d.weight.dtype,
481
+ ).transpose(1, 2)
482
+ ssm_state = torch.zeros(
483
+ batch_size,
484
+ self.nheads,
485
+ self.headdim,
486
+ self.d_state,
487
+ device=self.in_proj.weight.device,
488
+ dtype=self.in_proj.weight.dtype,
489
+ )
490
+ inference_params.key_value_memory_dict[self.layer_idx] = (
491
+ conv_state,
492
+ ssm_state,
493
+ )
494
+ else:
495
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
496
+ self.layer_idx
497
+ ]
498
+ # TODO: What if batch size changes between generation, and we reuse the same states?
499
+ if initialize_states:
500
+ conv_state.zero_()
501
+ ssm_state.zero_()
502
+ return conv_state, ssm_state
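A short instantiation sketch for the Mamba2 mixer above, assuming a CUDA device with the Triton kernels importable and import paths matching this build; with the defaults here (expand=2, headdim=64, d_state=128), d_model=256 gives d_inner=512 and 8 heads.

import torch
from mamba_ssm.modules.mamba2 import Mamba2

layer = Mamba2(d_model=256, layer_idx=0, device="cuda", dtype=torch.bfloat16)
x = torch.randn(2, 128, 256, device="cuda", dtype=torch.bfloat16)
y = layer(x)  # same shape as x: (2, 128, 256)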
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba2_simple.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ try:
11
+ from causal_conv1d import causal_conv1d_fn
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+
15
+ try:
16
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
17
+ except ImportError:
18
+ RMSNormGated, LayerNorm = None, None
19
+
20
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
21
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
22
+
23
+
24
+ class Mamba2Simple(nn.Module):
25
+ def __init__(
26
+ self,
27
+ d_model,
28
+ d_state=64,
29
+ d_conv=4,
30
+ conv_init=None,
31
+ expand=2,
32
+ headdim=128,
33
+ ngroups=1,
34
+ A_init_range=(1, 16),
35
+ dt_min=0.001,
36
+ dt_max=0.1,
37
+ dt_init_floor=1e-4,
38
+ dt_limit=(0.0, float("inf")),
39
+ learnable_init_states=False,
40
+ activation="swish",
41
+ bias=False,
42
+ conv_bias=True,
43
+ # Fused kernel and sharding options
44
+ chunk_size=256,
45
+ use_mem_eff_path=True,
46
+ layer_idx=None, # Absorb kwarg for general module
47
+ device=None,
48
+ dtype=None,
49
+ ):
50
+ factory_kwargs = {"device": device, "dtype": dtype}
51
+ super().__init__()
52
+ self.d_model = d_model
53
+ self.d_state = d_state
54
+ self.d_conv = d_conv
55
+ self.conv_init = conv_init
56
+ self.expand = expand
57
+ self.d_inner = self.expand * self.d_model
58
+ self.headdim = headdim
59
+ self.ngroups = ngroups
60
+ assert self.d_inner % self.headdim == 0
61
+ self.nheads = self.d_inner // self.headdim
62
+ self.dt_limit = dt_limit
63
+ self.learnable_init_states = learnable_init_states
64
+ self.activation = activation
65
+ self.chunk_size = chunk_size
66
+ self.use_mem_eff_path = use_mem_eff_path
67
+ self.layer_idx = layer_idx
68
+
69
+ # Order: [z, x, B, C, dt]
70
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
71
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
72
+
73
+ conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
74
+ self.conv1d = nn.Conv1d(
75
+ in_channels=conv_dim,
76
+ out_channels=conv_dim,
77
+ bias=conv_bias,
78
+ kernel_size=d_conv,
79
+ groups=conv_dim,
80
+ padding=d_conv - 1,
81
+ **factory_kwargs,
82
+ )
83
+ if self.conv_init is not None:
84
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
85
+ # self.conv1d.weight._no_weight_decay = True
86
+
87
+ if self.learnable_init_states:
88
+ self.init_states = nn.Parameter(
89
+ torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
90
+ )
91
+ self.init_states._no_weight_decay = True
92
+
93
+ self.act = nn.SiLU()
94
+
95
+ # Initialize log dt bias
96
+ dt = torch.exp(
97
+ torch.rand(self.nheads, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ )
101
+ dt = torch.clamp(dt, min=dt_init_floor)
102
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
103
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
104
+ self.dt_bias = nn.Parameter(inv_dt)
105
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
106
+ # name.endswith("bias") in param_grouping.py
107
+ self.dt_bias._no_weight_decay = True
108
+
109
+ # A parameter
110
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
111
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
112
+ *A_init_range
113
+ )
114
+ A_log = torch.log(A).to(dtype=dtype)
115
+ self.A_log = nn.Parameter(A_log)
116
+ # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
117
+ self.A_log._no_weight_decay = True
118
+
119
+ # D "skip" parameter
120
+ self.D = nn.Parameter(torch.ones(self.nheads, device=device))
121
+ self.D._no_weight_decay = True
122
+
123
+ # Extra normalization layer right before output projection
124
+ assert RMSNormGated is not None
125
+ self.norm = RMSNormGated(
126
+ self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
127
+ )
128
+
129
+ self.out_proj = nn.Linear(
130
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
131
+ )
132
+
133
+ def forward(self, u, seq_idx=None):
134
+ """
135
+ u: (B, L, D)
136
+ Returns: same shape as u
137
+ """
138
+ batch, seqlen, dim = u.shape
139
+
140
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
141
+ A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
142
+ initial_states = (
143
+ repeat(self.init_states, "... -> b ...", b=batch)
144
+ if self.learnable_init_states
145
+ else None
146
+ )
147
+ dt_limit_kwargs = (
148
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
149
+ )
150
+
151
+ if self.use_mem_eff_path:
152
+ # Fully fused path
153
+ out = mamba_split_conv1d_scan_combined(
154
+ zxbcdt,
155
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
156
+ self.conv1d.bias,
157
+ self.dt_bias,
158
+ A,
159
+ D=self.D,
160
+ chunk_size=self.chunk_size,
161
+ seq_idx=seq_idx,
162
+ activation=self.activation,
163
+ rmsnorm_weight=self.norm.weight,
164
+ rmsnorm_eps=self.norm.eps,
165
+ outproj_weight=self.out_proj.weight,
166
+ outproj_bias=self.out_proj.bias,
167
+ headdim=self.headdim,
168
+ ngroups=self.ngroups,
169
+ norm_before_gate=False,
170
+ initial_states=initial_states,
171
+ **dt_limit_kwargs,
172
+ )
173
+ else:
174
+ z, xBC, dt = torch.split(
175
+ zxbcdt,
176
+ [
177
+ self.d_inner,
178
+ self.d_inner + 2 * self.ngroups * self.d_state,
179
+ self.nheads,
180
+ ],
181
+ dim=-1,
182
+ )
183
+ dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
184
+ assert self.activation in ["silu", "swish"]
185
+
186
+ # 1D Convolution
187
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
188
+ xBC = self.act(
189
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
190
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
191
+ xBC = xBC[:, :seqlen, :]
192
+ else:
193
+ xBC = causal_conv1d_fn(
194
+ x=xBC.transpose(1, 2),
195
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
196
+ bias=self.conv1d.bias,
197
+ activation=self.activation,
198
+ ).transpose(1, 2)
199
+
200
+ # Split into 3 main branches: X, B, C
201
+ # These correspond to V, K, Q respectively in the SSM/attention duality
202
+ x, B, C = torch.split(
203
+ xBC,
204
+ [
205
+ self.d_inner,
206
+ self.ngroups * self.d_state,
207
+ self.ngroups * self.d_state,
208
+ ],
209
+ dim=-1,
210
+ )
211
+ y = mamba_chunk_scan_combined(
212
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
213
+ dt,
214
+ A,
215
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
216
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
217
+ chunk_size=self.chunk_size,
218
+ D=self.D,
219
+ z=None,
220
+ seq_idx=seq_idx,
221
+ initial_states=initial_states,
222
+ **dt_limit_kwargs,
223
+ )
224
+ y = rearrange(y, "b l h p -> b l (h p)")
225
+
226
+ # Multiply "gate" branch and apply extra normalization layer
227
+ y = self.norm(y, z)
228
+ out = self.out_proj(y)
229
+ return out
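Both Mamba2 and Mamba2Simple initialize dt_bias through the inverse of softplus so that softplus(dt_bias) lands in [dt_min, dt_max]; below is a small CPU-only check of that identity (only torch needed, independent of the modules above).

import math
import torch
import torch.nn.functional as F

# Sample dt log-uniformly in [dt_min, dt_max], as the modules do.
dt = torch.exp(
    torch.rand(8) * (math.log(0.1) - math.log(1e-3)) + math.log(1e-3)
).clamp(min=1e-4)
inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse softplus, as used above
assert torch.allclose(F.softplus(inv_dt), dt, rtol=1e-4, atol=1e-6)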
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from ..ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(
63
+ self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
64
+ )
65
+
66
+ self.conv1d = nn.Conv1d(
67
+ in_channels=self.d_inner,
68
+ out_channels=self.d_inner,
69
+ bias=conv_bias,
70
+ kernel_size=d_conv,
71
+ groups=self.d_inner,
72
+ padding=d_conv - 1,
73
+ **factory_kwargs,
74
+ )
75
+
76
+ self.activation = "silu"
77
+ self.act = nn.SiLU()
78
+
79
+ self.x_proj = nn.Linear(
80
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
81
+ )
82
+ self.dt_proj = nn.Linear(
83
+ self.dt_rank, self.d_inner, bias=True, **factory_kwargs
84
+ )
85
+
86
+ # Initialize special dt projection to preserve variance at initialization
87
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
88
+ if dt_init == "constant":
89
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
90
+ elif dt_init == "random":
91
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
96
+ dt = torch.exp(
97
+ torch.rand(self.d_inner, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ ).clamp(min=dt_init_floor)
101
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
102
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
103
+ with torch.no_grad():
104
+ self.dt_proj.bias.copy_(inv_dt)
105
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
106
+ self.dt_proj.bias._no_reinit = True
107
+
108
+ # S4D real initialization
109
+ A = repeat(
110
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
111
+ "n -> d n",
112
+ d=self.d_inner,
113
+ ).contiguous()
114
+ A_log = torch.log(A) # Keep A_log in fp32
115
+ self.A_log = nn.Parameter(A_log)
116
+ self.A_log._no_weight_decay = True
117
+
118
+ # D "skip" parameter
119
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
120
+ self.D._no_weight_decay = True
121
+
122
+ self.out_proj = nn.Linear(
123
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
124
+ )
125
+
126
+ def forward(self, hidden_states, inference_params=None):
127
+ """
128
+ hidden_states: (B, L, D)
129
+ Returns: same shape as hidden_states
130
+ """
131
+ batch, seqlen, dim = hidden_states.shape
132
+
133
+ conv_state, ssm_state = None, None
134
+ if inference_params is not None:
135
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
136
+ if inference_params.seqlen_offset > 0:
137
+ # The states are updated inplace
138
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
139
+ return out
140
+
141
+ # We do matmul and transpose BLH -> HBL at the same time
142
+ xz = rearrange(
143
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
144
+ "d (b l) -> b d l",
145
+ l=seqlen,
146
+ )
147
+ if self.in_proj.bias is not None:
148
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
149
+
150
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
151
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
152
+ if (
153
+ self.use_fast_path
154
+ and causal_conv1d_fn is not None
155
+ and inference_params is None
156
+ ): # Doesn't support outputting the states
157
+ out = mamba_inner_fn(
158
+ xz,
159
+ self.conv1d.weight,
160
+ self.conv1d.bias,
161
+ self.x_proj.weight,
162
+ self.dt_proj.weight,
163
+ self.out_proj.weight,
164
+ self.out_proj.bias,
165
+ A,
166
+ None, # input-dependent B
167
+ None, # input-dependent C
168
+ self.D.float(),
169
+ delta_bias=self.dt_proj.bias.float(),
170
+ delta_softplus=True,
171
+ )
172
+ else:
173
+ x, z = xz.chunk(2, dim=1)
174
+ # Compute short convolution
175
+ if conv_state is not None:
176
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
177
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
178
+ conv_state.copy_(
179
+ F.pad(x, (self.d_conv - x.shape[-1], 0))
180
+ ) # Update state (B D W)
181
+ if causal_conv1d_fn is None:
182
+ x = self.act(self.conv1d(x)[..., :seqlen])
183
+ else:
184
+ assert self.activation in ["silu", "swish"]
185
+ x = causal_conv1d_fn(
186
+ x=x,
187
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
188
+ bias=self.conv1d.bias,
189
+ activation=self.activation,
190
+ )
191
+
192
+ # We're careful here about the layout, to avoid extra transposes.
193
+ # We want dt to have d as the slowest moving dimension
194
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
195
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
196
+ dt, B, C = torch.split(
197
+ x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
198
+ )
199
+ dt = self.dt_proj.weight @ dt.t()
200
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
201
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
202
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
203
+ assert self.activation in ["silu", "swish"]
204
+ y = selective_scan_fn(
205
+ x,
206
+ dt,
207
+ A,
208
+ B,
209
+ C,
210
+ self.D.float(),
211
+ z=z,
212
+ delta_bias=self.dt_proj.bias.float(),
213
+ delta_softplus=True,
214
+ return_last_state=ssm_state is not None,
215
+ )
216
+ if ssm_state is not None:
217
+ y, last_state = y
218
+ ssm_state.copy_(last_state)
219
+ y = rearrange(y, "b d l -> b l d")
220
+ out = self.out_proj(y)
221
+ return out
222
+
223
+ def step(self, hidden_states, conv_state, ssm_state):
224
+ dtype = hidden_states.dtype
225
+ assert (
226
+ hidden_states.shape[1] == 1
227
+ ), "Only support decoding with 1 token at a time for now"
228
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
229
+ x, z = xz.chunk(2, dim=-1) # (B D)
230
+
231
+ # Conv step
232
+ if causal_conv1d_update is None:
233
+ conv_state.copy_(
234
+ torch.roll(conv_state, shifts=-1, dims=-1)
235
+ ) # Update state (B D W)
236
+ conv_state[:, :, -1] = x
237
+ x = torch.sum(
238
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
239
+ ) # (B D)
240
+ if self.conv1d.bias is not None:
241
+ x = x + self.conv1d.bias
242
+ x = self.act(x).to(dtype=dtype)
243
+ else:
244
+ x = causal_conv1d_update(
245
+ x,
246
+ conv_state,
247
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
248
+ self.conv1d.bias,
249
+ self.activation,
250
+ )
251
+
252
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
253
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
254
+ # Don't add dt_bias here
255
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
256
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
257
+
258
+ # SSM step
259
+ if selective_state_update is None:
260
+ # Discretize A and B
261
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
262
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
263
+ dB = torch.einsum("bd,bn->bdn", dt, B)
264
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
265
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
266
+ y = y + self.D.to(dtype) * x
267
+ y = y * self.act(z) # (B D)
268
+ else:
269
+ y = selective_state_update(
270
+ ssm_state,
271
+ x,
272
+ dt,
273
+ A,
274
+ B,
275
+ C,
276
+ self.D,
277
+ z=z,
278
+ dt_bias=self.dt_proj.bias,
279
+ dt_softplus=True,
280
+ )
281
+
282
+ out = self.out_proj(y)
283
+ return out.unsqueeze(1), conv_state, ssm_state
284
+
285
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
286
+ device = self.out_proj.weight.device
287
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
288
+ conv_state = torch.zeros(
289
+ batch_size,
290
+ self.d_model * self.expand,
291
+ self.d_conv,
292
+ device=device,
293
+ dtype=conv_dtype,
294
+ )
295
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
296
+ # ssm_dtype = torch.float32
297
+ ssm_state = torch.zeros(
298
+ batch_size,
299
+ self.d_model * self.expand,
300
+ self.d_state,
301
+ device=device,
302
+ dtype=ssm_dtype,
303
+ )
304
+ return conv_state, ssm_state
305
+
306
+ def _get_states_from_cache(
307
+ self, inference_params, batch_size, initialize_states=False
308
+ ):
309
+ assert self.layer_idx is not None
310
+ if self.layer_idx not in inference_params.key_value_memory_dict:
311
+ batch_shape = (batch_size,)
312
+ conv_state = torch.zeros(
313
+ batch_size,
314
+ self.d_model * self.expand,
315
+ self.d_conv,
316
+ device=self.conv1d.weight.device,
317
+ dtype=self.conv1d.weight.dtype,
318
+ )
319
+ ssm_state = torch.zeros(
320
+ batch_size,
321
+ self.d_model * self.expand,
322
+ self.d_state,
323
+ device=self.dt_proj.weight.device,
324
+ dtype=self.dt_proj.weight.dtype,
325
+ # dtype=torch.float32,
326
+ )
327
+ inference_params.key_value_memory_dict[self.layer_idx] = (
328
+ conv_state,
329
+ ssm_state,
330
+ )
331
+ else:
332
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
333
+ self.layer_idx
334
+ ]
335
+ # TODO: What if batch size changes between generation, and we reuse the same states?
336
+ if initialize_states:
337
+ conv_state.zero_()
338
+ ssm_state.zero_()
339
+ return conv_state, ssm_state
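For comparison with Mamba2, a sketch of the original Mamba (v1) mixer on random input; it falls back to the non-fused conv path if causal_conv1d is missing, but the selective-scan op still requires a CUDA device, so treat this as a GPU-only example with import paths matching this build.

import torch
from mamba_ssm.modules.mamba_simple import Mamba

layer = Mamba(d_model=256, d_state=16, d_conv=4, expand=2, layer_idx=0, device="cuda")
x = torch.randn(2, 64, 256, device="cuda")
y = layer(x)  # (2, 64, 256)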
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mha.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ try:
11
+ from flash_attn import flash_attn_with_kvcache
12
+ except ImportError:
13
+ flash_attn_with_kvcache = None
14
+
15
+ try:
16
+ from flash_attn.layers.rotary import RotaryEmbedding
17
+ except ImportError:
18
+ RotaryEmbedding = None
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn, causal_conv1d_update = None, None
24
+
25
+
26
+ def _update_kv_cache(kv, inference_params, layer_idx):
27
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
28
+ # Pre-allocate memory for key-values for inference.
29
+ num_heads, head_dim = kv.shape[-2:]
30
+ assert layer_idx in inference_params.key_value_memory_dict
31
+ kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
32
+ # Adjust key and value for inference
33
+ batch_start = inference_params.batch_size_offset
34
+ batch_end = batch_start + kv.shape[0]
35
+ sequence_start = inference_params.seqlen_offset
36
+ sequence_end = sequence_start + kv.shape[1]
37
+ assert batch_end <= kv_cache.shape[0]
38
+ assert sequence_end <= kv_cache.shape[1]
39
+ assert kv_cache is not None
40
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
41
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
42
+
43
+
44
+ class MHA(nn.Module):
45
+ """Multi-head self-attention and cross-attention"""
46
+
47
+ def __init__(
48
+ self,
49
+ embed_dim,
50
+ num_heads,
51
+ num_heads_kv=None,
52
+ head_dim=None, # If None, use embed_dim // num_heads
53
+ mlp_dim=0,
54
+ qkv_proj_bias=True,
55
+ out_proj_bias=True,
56
+ softmax_scale=None,
57
+ causal=False,
58
+ layer_idx=None,
59
+ d_conv=0,
60
+ rotary_emb_dim=0,
61
+ rotary_emb_base=10000.0,
62
+ rotary_emb_interleaved=False,
63
+ device=None,
64
+ dtype=None,
65
+ ) -> None:
66
+ """
67
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
68
+ return_residual: whether to return the input x along with the output. This is for
69
+ performance reason: for post-norm architecture, returning the input allows us
70
+ to fuse the backward of nn.Linear with the residual connection.
71
+ """
72
+ factory_kwargs = {"device": device, "dtype": dtype}
73
+ super().__init__()
74
+ self.embed_dim = embed_dim
75
+ self.layer_idx = layer_idx
76
+ self.d_conv = d_conv
77
+ self.rotary_emb_dim = rotary_emb_dim
78
+ self.softmax_scale = softmax_scale
79
+ self.causal = causal
80
+
81
+ self.num_heads = num_heads
82
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
83
+ assert (
84
+ self.num_heads % self.num_heads_kv == 0
85
+ ), "num_heads must be divisible by num_heads_kv"
86
+ if head_dim is None:
87
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
88
+ self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
89
+ self.mlp_dim = math.ceil(mlp_dim / 256) * 256
90
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
91
+ out_dim = self.head_dim * self.num_heads
92
+
93
+ if self.rotary_emb_dim > 0:
94
+ assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
95
+ self.rotary_emb = RotaryEmbedding(
96
+ self.rotary_emb_dim,
97
+ base=rotary_emb_base,
98
+ interleaved=rotary_emb_interleaved,
99
+ device=device,
100
+ )
101
+
102
+ self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
103
+ if self.d_conv > 0:
104
+ self.conv1d = nn.Conv1d(
105
+ qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
106
+ **factory_kwargs
107
+ )
108
+ self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
109
+
110
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
111
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
112
+ device = self.out_proj.weight.device
113
+ if self.d_conv > 0:
114
+ conv_state = torch.zeros(
115
+ batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
116
+ )
117
+ else:
118
+ conv_state = None
119
+ kv_cache = torch.empty(
120
+ batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
121
+ )
122
+ return kv_cache, conv_state
123
+
124
+ def _update_kv_cache(self, kv, inference_params):
125
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
126
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
127
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
128
+
129
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
130
+ """
131
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
132
+ q: (batch_size, seqlen_q, nheads, head_dim)
133
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
134
+ """
135
+ assert inference_params is not None and inference_params.seqlen_offset > 0
136
+ if self.rotary_emb_dim > 0:
137
+ self.rotary_emb._update_cos_sin_cache(
138
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
139
+ )
140
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
141
+ else:
142
+ rotary_cos, rotary_sin = None, None
143
+ batch = q.shape[0]
144
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
145
+ kv_cache = kv_cache[:batch]
146
+ cache_seqlens = (
147
+ inference_params.lengths_per_sample[:batch]
148
+ if inference_params.lengths_per_sample is not None
149
+ else inference_params.seqlen_offset
150
+ )
151
+ assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
152
+ context = flash_attn_with_kvcache(
153
+ q,
154
+ kv_cache[:, :, 0],
155
+ kv_cache[:, :, 1],
156
+ kv[:, :, 0],
157
+ kv[:, :, 1],
158
+ rotary_cos=rotary_cos,
159
+ rotary_sin=rotary_sin,
160
+ cache_seqlens=cache_seqlens,
161
+ softmax_scale=self.softmax_scale,
162
+ causal=self.causal,
163
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
164
+ )
165
+ return context
166
+
167
+ def _update_kvcache_attention(self, q, kv, inference_params):
168
+ """Write kv to inference_params, then do attention"""
169
+ if (
170
+ inference_params.seqlen_offset == 0
171
+ or flash_attn_with_kvcache is None
172
+ ):
173
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
174
+ kv = self._update_kv_cache(kv, inference_params)
175
+ k, v = kv.unbind(dim=-3)
176
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
177
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
178
+ return F.scaled_dot_product_attention(
179
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
180
+ ).transpose(1, 2)
181
+ else:
182
+ batch = q.shape[0]
183
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184
+ kv_cache = kv_cache[:batch]
185
+ cache_seqlens = (
186
+ inference_params.lengths_per_sample[:batch]
187
+ if inference_params.lengths_per_sample is not None
188
+ else inference_params.seqlen_offset
189
+ )
190
+ return flash_attn_with_kvcache(
191
+ q,
192
+ kv_cache[:, :, 0],
193
+ kv_cache[:, :, 1],
194
+ kv[:, :, 0],
195
+ kv[:, :, 1],
196
+ cache_seqlens=cache_seqlens,
197
+ softmax_scale=self.softmax_scale,
198
+ causal=self.causal,
199
+ )
200
+
201
+ def forward(self, x, inference_params=None):
202
+ """
203
+ Arguments:
204
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
205
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
206
+ is the sum of the sequence lengths in the batch.
207
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
208
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
209
+ """
210
+ if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
211
+ inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
212
+ x.shape[0], inference_params.max_seqlen, dtype=x.dtype
213
+ )
214
+ seqlen_offset = (
215
+ 0
216
+ if inference_params is None
217
+ else (
218
+ inference_params.lengths_per_sample
219
+ if inference_params.lengths_per_sample is not None
220
+ else inference_params.seqlen_offset
221
+ )
222
+ )
223
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
224
+ qkv = self.in_proj(x)
225
+ if self.mlp_dim > 0:
226
+ qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
227
+ x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
228
+ x_mlp = x_mlp_up * F.silu(x_mlp_gate)
229
+ if self.d_conv > 0:
230
+ # TODO: the inference code for conv1d is messy and should be cleaned up
231
+ if (inference_params is None or inference_params.seqlen_offset == 0):
232
+ if causal_conv1d_fn is None:
233
+ qkv = rearrange(
234
+ self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
235
+ ).contiguous()
236
+ else:
237
+ qkv = causal_conv1d_fn(
238
+ qkv.transpose(1, 2),
239
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
240
+ self.conv1d.bias
241
+ ).transpose(1, 2)
242
+ if inference_params is not None:
243
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
244
+ # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
245
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
246
+ qkv_t = rearrange(qkv, "b l d -> b d l")
247
+ conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
248
+ else:
249
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
250
+ assert qkv.shape[1] == 1, "Only decoding with 1 token at a time is supported for now"
251
+ qkv = qkv.squeeze(1)
252
+ # Conv step
253
+ if causal_conv1d_update is None:
254
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
255
+ conv_state[:, :, -1] = qkv
256
+ qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
257
+ if self.conv1d.bias is not None:
258
+ qkv = qkv + self.conv1d.bias
259
+ else:
260
+ qkv = causal_conv1d_update(
261
+ qkv,
262
+ conv_state,
263
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
264
+ self.conv1d.bias
265
+ )
266
+ qkv = qkv.unsqueeze(1)
267
+ q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
268
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
269
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
270
+ if (
271
+ inference_params is None
272
+ or inference_params.seqlen_offset == 0
273
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
274
+ ):
275
+ if self.rotary_emb_dim > 0:
276
+ q, kv = self.rotary_emb(
277
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
278
+ )
279
+ if inference_params is None:
280
+ k, v = kv.unbind(dim=-3)
281
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
282
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
283
+ context = F.scaled_dot_product_attention(
284
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
285
+ ).transpose(1, 2)
286
+ else:
287
+ context = self._update_kvcache_attention(q, kv, inference_params)
288
+ else:
289
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
290
+ context = rearrange(context, "... h d -> ... (h d)")
291
+ if self.mlp_dim > 0:
292
+ context = torch.cat([context, x_mlp], dim=-1)
293
+ out = self.out_proj(context)
294
+ return out
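
Note: the two fallback paths in the forward pass above (used when the fused flash-attn / causal-conv1d kernels are not installed) are easy to verify in isolation. The sketch below is illustrative only — the shapes, head counts, and width are made up, not taken from any shipped config — and assumes plain PyTorch plus einops.

import torch
import torch.nn.functional as F
from einops import rearrange

# 1) Attention fallback: expand the KV heads to match the query heads (grouped-query
#    attention) and call PyTorch's built-in scaled_dot_product_attention.
batch, seqlen, nheads, nheads_kv, head_dim = 2, 16, 8, 2, 64
q = torch.randn(batch, seqlen, nheads, head_dim)
kv = torch.randn(batch, seqlen, 2, nheads_kv, head_dim)    # packed (K, V)
k, v = kv.unbind(dim=-3)
k = torch.repeat_interleave(k, dim=2, repeats=nheads // nheads_kv)
v = torch.repeat_interleave(v, dim=2, repeats=nheads // nheads_kv)
# SDPA expects (batch, nheads, seqlen, head_dim), hence the transposes.
context = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)                                          # back to (batch, seqlen, nheads, head_dim)

# 2) Decode-time conv fallback: the rolling state holds the last d_conv inputs per
#    channel; shift left, write the new token, then dot with the depthwise weights.
batch, dim, d_conv = 2, 192, 4
conv_state = torch.zeros(batch, dim, d_conv)               # (B, D, W)
weight = torch.randn(dim, 1, d_conv)                       # depthwise Conv1d weight (D, 1, W)
x_new = torch.randn(batch, dim)                            # current token's features (B, D)
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
conv_state[:, :, -1] = x_new
out = torch.sum(conv_state * rearrange(weight, "d 1 w -> d w"), dim=-1)  # (B, D)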
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mlp.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class GatedMLP(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features,
10
+ hidden_features=None,
11
+ out_features=None,
12
+ activation=F.silu,
13
+ bias=False,
14
+ multiple_of=128,
15
+ device=None,
16
+ dtype=None,
17
+ ):
18
+ factory_kwargs = {"device": device, "dtype": dtype}
19
+ super().__init__()
20
+ out_features = out_features if out_features is not None else in_features
21
+ hidden_features = (
22
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
23
+ )
24
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
25
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
26
+ self.activation = activation
27
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
28
+
29
+ def forward(self, x):
30
+ y = self.fc1(x)
31
+ y, gate = y.chunk(2, dim=-1)
32
+ y = y * self.activation(gate)
33
+ y = self.fc2(y)
34
+ return y
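
A quick sanity check of the sizing logic above (a minimal sketch; the import path simply mirrors the file location in this build, and the numbers are illustrative): with in_features=1024 and the defaults, hidden_features starts at int(8 * 1024 / 3) = 2730 and is rounded up to the next multiple of 128, giving 2816, so fc1 is Linear(1024, 5632) and fc2 is Linear(2816, 1024).

import torch
from mamba_ssm.modules.mlp import GatedMLP  # path as laid out in this build

mlp = GatedMLP(in_features=1024)   # hidden_features -> 2816 (= ceil(2730 / 128) * 128)
x = torch.randn(2, 16, 1024)       # (batch, seqlen, in_features)
y = mlp(x)                         # fc1 -> split into (value, gate) -> value * silu(gate) -> fc2
assert y.shape == x.shape          # out_features defaults to in_features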