Build
This view is limited to 50 files because it contains too many changes.
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py +14 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_bft6nicqkg6ni.abi3.so +3 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py +9 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/distributed_utils.py +144 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +326 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/config_mamba.py +18 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +338 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/block.py +107 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2.py +502 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2_simple.py +229 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba_simple.py +339 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mha.py +294 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mlp.py +34 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/ssd_minimal.py +111 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/__init__.py +0 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py +659 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/__init__.py +0 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/k_activations.py +169 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py +1166 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layernorm_gated.py +437 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py +389 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/softplus.py +15 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_bmm.py +262 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py +2012 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py +1884 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_state_passing.py +348 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/__init__.py +0 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/generation.py +390 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/hf.py +23 -0
- build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/torch.py +21 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/__init__.py +14 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_nmrmresto7zfi.abi3.so +3 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_ops.py +9 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/__init__.py +0 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/distributed_utils.py +144 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py +326 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/__init__.py +0 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/config_mamba.py +18 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py +338 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/__init__.py +0 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/block.py +107 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba2.py +502 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba2_simple.py +229 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba_simple.py +339 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mha.py +294 -0
- build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mlp.py +34 -0
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/__init__.py
ADDED
@@ -0,0 +1,14 @@
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
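For orientation, a minimal usage sketch of the API exported above (sizes are hypothetical; the selective-scan kernels require a CUDA device and the compiled extension shipped in this build):

import torch
from mamba_ssm import Mamba

# Hypothetical sizes; the mixer maps (batch, seqlen, d_model) to the same shape.
layer = Mamba(d_model=256, d_state=16, d_conv=4, expand=2).to("cuda")
x = torch.randn(2, 64, 256, device="cuda")
y = layer(x)
print(y.shape)  # torch.Size([2, 64, 256])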
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_mamba_ssm_bft6nicqkg6ni.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c0e8bc801359703c8d092b7c8c9906bd59c083d94e6778b621ba709d79fff5a0
size 258973648
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/_ops.py
ADDED
@@ -0,0 +1,9 @@
import torch
from . import _mamba_ssm_bft6nicqkg6ni
ops = torch.ops._mamba_ssm_bft6nicqkg6ni

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_mamba_ssm_bft6nicqkg6ni::{op_name}"
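The helper above only builds qualified op names for the compiled extension; a hedged sketch of how it might be used (the op name "selective_scan_fwd" is an assumption for illustration, not something this file defines):

from mamba_ssm._ops import ops, add_op_namespace_prefix

# Fully qualified name, e.g. for torch.library registration or schema lookup.
qualified = add_op_namespace_prefix("selective_scan_fwd")  # assumed op name
print(qualified)  # "_mamba_ssm_bft6nicqkg6ni::selective_scan_fwd"

# The same kernel would be callable as an attribute of the `ops` namespace object,
# with arguments depending on the schema registered by the extension:
# ops.selective_scan_fwd(...)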
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/distributed_utils.py
ADDED
@@ -0,0 +1,144 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce scatter the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    pamams_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(pamams_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
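A brief, hypothetical sketch of how these helpers are meant to be used (assumes the script is launched with torchrun so a default process group can be initialized and each rank owns one GPU):

import torch
import torch.distributed as dist
from mamba_ssm.distributed.distributed_utils import all_gather, reduce_scatter, get_dim_for_local_rank

dist.init_process_group(backend="nccl")
group = dist.group.WORLD
world = dist.get_world_size(group)

# Uneven split helper: 80 channels in multiples of 8 over 4 ranks -> ranks 0-1 get 24, ranks 2-3 get 16.
local_dim = get_dim_for_local_rank(dim=80, world_size=4, local_rank=dist.get_rank(group) % 4, multiple_of=8)

# Autograd-aware collectives along dim 0 (the sequence-parallel dimension).
x = torch.randn(8, 128, device="cuda", requires_grad=True)  # per-rank shard
full = all_gather(x, group)            # (8 * world, 128); backward does a reduce_scatter
shard = reduce_scatter(full, group)    # back to (8, 128); backward does an all_gather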
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py
ADDED
@@ -0,0 +1,326 @@
# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from ..utils.torch import custom_bwd, custom_fwd

from einops import rearrange

from ..distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)


class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = (
                bias.to(dtype=torch.get_autocast_gpu_dtype())
                if bias is not None
                else None
            )
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(
                    grad_input, process_group, async_op=True
                )
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(
                f"out_features ({out_features}) must be a multiple of {multiple_of}"
            )
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features,
            local_multiple * multiple_of,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of {multiple_of}"
            )
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings,
                embed_dim,
                process_group=process_group,
                **factory_kwargs,
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return (
            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
        )
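To make the intended pairing concrete, a minimal sketch of a 2-way tensor-parallel MLP built from these modules (hypothetical sizes; assumes an initialized process group, e.g. via torchrun, and dimensions divisible by the world size):

import torch
import torch.distributed as dist
import torch.nn.functional as F
from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear

group = dist.group.WORLD
d_model, d_ffn = 1024, 4096

# Column-parallel up-projection followed by row-parallel down-projection.
fc1 = ColumnParallelLinear(d_model, d_ffn, process_group=group, sequence_parallel=True).cuda()
fc2 = RowParallelLinear(d_ffn, d_model, process_group=group, sequence_parallel=True).cuda()

# With sequence_parallel=True the input is this rank's shard of the flattened (batch * seqlen)
# dimension; fc1 all-gathers it before the matmul and fc2 reduce-scatters the result back.
x_local = torch.randn(512 // dist.get_world_size(group), d_model, device="cuda")
y_local = fc2(F.silu(fc1(x_local)))  # same shape as x_local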
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/config_mamba.py
ADDED
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field


@dataclass
class MambaConfig:

    d_model: int = 2560
    d_intermediate: int = 0
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
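The defaults above describe a large (roughly 2.8B-parameter) layout; a smaller, purely illustrative configuration can be built the same way:

from mamba_ssm.models.config_mamba import MambaConfig

# Hypothetical small config for experimentation; "layer": "Mamba2" selects the Mamba2 mixer
# in create_block (see mixer_seq_simple.py below), otherwise Mamba1 is used.
cfg = MambaConfig(
    d_model=512,
    n_layer=12,
    vocab_size=32000,
    ssm_cfg={"layer": "Mamba2"},
)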
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py
ADDED
@@ -0,0 +1,338 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(
                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
            )
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=(
                    1 if d_intermediate == 0 else 2
                ),  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states,
                residual,
                inference_params=inference_params,
                **mixer_kwargs,
            )
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **mixer_kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(
            input_ids, inference_params=inference_params, **mixer_kwargs
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=4)
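A hedged end-to-end sketch of the model class defined above (the checkpoint name is an assumption for illustration; loading requires a CUDA device plus the Triton/CUDA kernels from this build):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device="cuda")
out = model(input_ids, num_last_tokens=1)  # only keep logits for the final position
print(out.logits.shape)  # (1, 1, padded vocab size)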
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/block.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
from typing import Optional

import torch
from torch import nn, Tensor

from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
        **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm),
            )
        hidden_states = self.mixer(
            hidden_states, inference_params=inference_params, **mixer_kwargs
        )

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )
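A minimal sketch of the Add -> LN -> Mixer calling convention described in the docstring (hypothetical sizes; requires a CUDA device and the Triton layer-norm kernels imported above):

from functools import partial
import torch
import torch.nn as nn
from mamba_ssm.modules.block import Block
from mamba_ssm.modules.mamba_simple import Mamba

dim = 256
block = Block(dim, mixer_cls=partial(Mamba, layer_idx=0), mlp_cls=nn.Identity).to("cuda")

x = torch.randn(2, 64, dim, device="cuda")
hidden, residual = block(x)                 # first block: residual starts as None
hidden, residual = block(hidden, residual)  # later blocks thread the residual through
                                            # (the same block is reused here only to show the call)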
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2.py
ADDED
@@ -0,0 +1,502 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
except ImportError:
    causal_conv1d_varlen_states = None

try:
    from ..ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated

from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from ..distributed.distributed_utils import all_reduce, reduce_scatter

from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

from huggingface_hub import PyTorchModelHubMixin


class Mamba2(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        d_model,
        d_state=128,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=64,
        d_ssm=None,  # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
        ngroups=1,
        A_init_range=(1, 16),
        D_has_hdim=False,
        rmsnorm=True,
        norm_before_gate=False,
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        process_group=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.world_size = 1 if process_group is None else process_group.size()
        self.local_rank = 0 if process_group is None else process_group.rank()
        self.d_inner = (self.expand * self.d_model) // self.world_size
        assert self.d_inner * self.world_size == self.expand * self.d_model
        self.headdim = headdim
        self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
        assert ngroups % self.world_size == 0
        self.ngroups = ngroups // self.world_size
        assert self.d_ssm % self.headdim == 0
        self.nheads = self.d_ssm // self.headdim
        self.D_has_hdim = D_has_hdim
        self.rmsnorm = rmsnorm
        self.norm_before_gate = norm_before_gate
        self.dt_limit = dt_limit
        self.activation = "silu"
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        if self.process_group is None:
            self.in_proj = nn.Linear(
                self.d_model, d_in_proj, bias=bias, **factory_kwargs
            )
        else:
            self.in_proj = ColumnParallelLinear(
                self.d_model,
                d_in_proj * self.world_size,
                bias=bias,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
            *A_init_range
        )
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(
            torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
        )
        self.D._no_weight_decay = True

        if self.rmsnorm:
            assert RMSNormGated is not None
            self.norm = RMSNormGated(
                self.d_ssm,
                eps=1e-5,
                norm_before_gate=self.norm_before_gate,
                group_size=self.d_ssm // ngroups,
                **factory_kwargs,
            )

        if self.process_group is None:
            self.out_proj = nn.Linear(
                self.d_inner, self.d_model, bias=bias, **factory_kwargs
            )
        else:
            self.out_proj = RowParallelLinear(
                self.d_inner * self.world_size,
                self.d_model,
                bias=bias,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

    def forward(
        self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
    ):
        """
        u: (batch, seqlen, hidden_dim) if seqlen=None.
        If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
        split u during sequence parallel, we split the batch * seqlen dimension
        (in case batch is small).
        Returns: same shape as u
        """
        seqlen_og = seqlen
        if seqlen is None:
            batch, seqlen, dim = u.shape
        else:
            batch_seqlen, dim = u.shape
            batch = batch_seqlen // seqlen

        conv_state, ssm_state = None, None
        if inference_params is not None:
            inference_batch = (
                cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
            )
            conv_state, ssm_state = self._get_states_from_cache(
                inference_params, inference_batch
            )
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(u, conv_state, ssm_state)
                return out

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj) or (B * L, d_in_proj)
        if seqlen_og is not None:
            zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)
        dt_limit_kwargs = (
            {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
        )
        if self.use_mem_eff_path and inference_params is None:
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=(
                    rearrange(self.D, "(h p) -> h p", p=self.headdim)
                    if self.D_has_hdim
                    else self.D
                ),
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
                rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=None if self.D_has_hdim else self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=self.norm_before_gate,
                **dt_limit_kwargs,
            )
            if seqlen_og is not None:
                out = rearrange(out, "b l d -> (b l) d")
            if self.process_group is not None:
                reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
                out = reduce_fn(out, self.process_group)
        else:
            d_mlp = (
                zxbcdt.shape[-1]
                - 2 * self.d_ssm
                - 2 * self.ngroups * self.d_state
                - self.nheads
            ) // 2
            z0, x0, z, xBC, dt = torch.split(
                zxbcdt,
                [
                    d_mlp,
                    d_mlp,
                    self.d_ssm,
                    self.d_ssm + 2 * self.ngroups * self.d_state,
                    self.nheads,
                ],
                dim=-1,
            )
            if conv_state is not None:
                if cu_seqlens is None:
                    # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                    xBC_t = rearrange(xBC, "b l d -> b d l")
                    conv_state.copy_(
                        F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
                    )  # Update state (B D W)
                else:
                    assert (
                        causal_conv1d_varlen_states is not None
                    ), "varlen inference requires causal_conv1d package"
                    assert (
                        batch == 1
                    ), "varlen inference only supports batch dimension 1"
                    conv_varlen_states = causal_conv1d_varlen_states(
                        xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
                    )
                    conv_state.copy_(conv_varlen_states)
            assert self.activation in ["silu", "swish"]
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                assert (
                    seq_idx is None
                ), "varlen conv1d requires the causal_conv1d package"
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
                        :, : -(self.d_conv - 1)
                    ]
                )  # (B, L, self.d_ssm + 2 * ngroups * d_state)
            else:
                xBC = causal_conv1d_fn(
                    xBC.transpose(1, 2),
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=seq_idx,
                ).transpose(1, 2)
            x, B, C = torch.split(
                xBC,
                [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
                dim=-1,
            )
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=(
                    rearrange(self.D, "(h p) -> h p", p=self.headdim)
                    if self.D_has_hdim
                    else self.D
                ),
                z=(
                    rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
                    if not self.rmsnorm
                    else None
                ),
                dt_bias=self.dt_bias,
                dt_softplus=True,
                seq_idx=seq_idx,
325 |
+
cu_seqlens=cu_seqlens,
|
326 |
+
**dt_limit_kwargs,
|
327 |
+
return_final_states=ssm_state is not None,
|
328 |
+
return_varlen_states=cu_seqlens is not None
|
329 |
+
and inference_params is not None,
|
330 |
+
)
|
331 |
+
if ssm_state is not None:
|
332 |
+
y, last_state, *rest = y
|
333 |
+
if cu_seqlens is None:
|
334 |
+
ssm_state.copy_(last_state)
|
335 |
+
else:
|
336 |
+
varlen_states = rest[0]
|
337 |
+
ssm_state.copy_(varlen_states)
|
338 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
339 |
+
if self.rmsnorm:
|
340 |
+
y = self.norm(y, z)
|
341 |
+
if d_mlp > 0:
|
342 |
+
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
343 |
+
if seqlen_og is not None:
|
344 |
+
y = rearrange(y, "b l d -> (b l) d")
|
345 |
+
out = self.out_proj(y)
|
346 |
+
return out
|
347 |
+
|
348 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
349 |
+
dtype = hidden_states.dtype
|
350 |
+
assert (
|
351 |
+
hidden_states.shape[1] == 1
|
352 |
+
), "Only support decoding with 1 token at a time for now"
|
353 |
+
zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
354 |
+
d_mlp = (
|
355 |
+
zxbcdt.shape[-1]
|
356 |
+
- 2 * self.d_ssm
|
357 |
+
- 2 * self.ngroups * self.d_state
|
358 |
+
- self.nheads
|
359 |
+
) // 2
|
360 |
+
z0, x0, z, xBC, dt = torch.split(
|
361 |
+
zxbcdt,
|
362 |
+
[
|
363 |
+
d_mlp,
|
364 |
+
d_mlp,
|
365 |
+
self.d_ssm,
|
366 |
+
self.d_ssm + 2 * self.ngroups * self.d_state,
|
367 |
+
self.nheads,
|
368 |
+
],
|
369 |
+
dim=-1,
|
370 |
+
)
|
371 |
+
|
372 |
+
# Conv step
|
373 |
+
if causal_conv1d_update is None:
|
374 |
+
conv_state.copy_(
|
375 |
+
torch.roll(conv_state, shifts=-1, dims=-1)
|
376 |
+
) # Update state (B D W)
|
377 |
+
conv_state[:, :, -1] = xBC
|
378 |
+
xBC = torch.sum(
|
379 |
+
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
380 |
+
) # (B D)
|
381 |
+
if self.conv1d.bias is not None:
|
382 |
+
xBC = xBC + self.conv1d.bias
|
383 |
+
xBC = self.act(xBC).to(dtype=dtype)
|
384 |
+
else:
|
385 |
+
xBC = causal_conv1d_update(
|
386 |
+
xBC,
|
387 |
+
conv_state,
|
388 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
389 |
+
self.conv1d.bias,
|
390 |
+
self.activation,
|
391 |
+
)
|
392 |
+
|
393 |
+
x, B, C = torch.split(
|
394 |
+
xBC,
|
395 |
+
[self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
|
396 |
+
dim=-1,
|
397 |
+
)
|
398 |
+
A = -torch.exp(self.A_log.float()) # (nheads,)
|
399 |
+
|
400 |
+
# SSM step
|
401 |
+
if selective_state_update is None:
|
402 |
+
assert (
|
403 |
+
self.ngroups == 1
|
404 |
+
), "Only support ngroups=1 for this inference code path"
|
405 |
+
# Discretize A and B
|
406 |
+
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
|
407 |
+
dA = torch.exp(dt * A) # (batch, nheads)
|
408 |
+
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
409 |
+
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
|
410 |
+
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
|
411 |
+
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
|
412 |
+
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
|
413 |
+
y = rearrange(y, "b h p -> b (h p)")
|
414 |
+
if not self.rmsnorm:
|
415 |
+
y = y * self.act(z) # (B D)
|
416 |
+
else:
|
417 |
+
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
|
418 |
+
dtype=torch.float32
|
419 |
+
)
|
420 |
+
dt = repeat(dt, "b h -> b h p", p=self.headdim)
|
421 |
+
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
|
422 |
+
D = repeat(self.D, "h -> h p", p=self.headdim)
|
423 |
+
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
|
424 |
+
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
|
425 |
+
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
426 |
+
if not self.rmsnorm:
|
427 |
+
z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
|
428 |
+
y = selective_state_update(
|
429 |
+
ssm_state,
|
430 |
+
x_reshaped,
|
431 |
+
dt,
|
432 |
+
A,
|
433 |
+
B,
|
434 |
+
C,
|
435 |
+
D,
|
436 |
+
z=z if not self.rmsnorm else None,
|
437 |
+
dt_bias=dt_bias,
|
438 |
+
dt_softplus=True,
|
439 |
+
)
|
440 |
+
y = rearrange(y, "b h p -> b (h p)")
|
441 |
+
if self.rmsnorm:
|
442 |
+
y = self.norm(y, z)
|
443 |
+
if d_mlp > 0:
|
444 |
+
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
445 |
+
out = self.out_proj(y)
|
446 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
447 |
+
|
448 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
449 |
+
device = self.out_proj.weight.device
|
450 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
451 |
+
conv_state = torch.zeros(
|
452 |
+
batch_size,
|
453 |
+
self.d_conv,
|
454 |
+
self.conv1d.weight.shape[0],
|
455 |
+
device=device,
|
456 |
+
dtype=conv_dtype,
|
457 |
+
).transpose(1, 2)
|
458 |
+
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
|
459 |
+
ssm_state = torch.zeros(
|
460 |
+
batch_size,
|
461 |
+
self.nheads,
|
462 |
+
self.headdim,
|
463 |
+
self.d_state,
|
464 |
+
device=device,
|
465 |
+
dtype=ssm_dtype,
|
466 |
+
)
|
467 |
+
return conv_state, ssm_state
|
468 |
+
|
469 |
+
def _get_states_from_cache(
|
470 |
+
self, inference_params, batch_size, initialize_states=False
|
471 |
+
):
|
472 |
+
assert self.layer_idx is not None
|
473 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
474 |
+
batch_shape = (batch_size,)
|
475 |
+
conv_state = torch.zeros(
|
476 |
+
batch_size,
|
477 |
+
self.d_conv,
|
478 |
+
self.conv1d.weight.shape[0],
|
479 |
+
device=self.conv1d.weight.device,
|
480 |
+
dtype=self.conv1d.weight.dtype,
|
481 |
+
).transpose(1, 2)
|
482 |
+
ssm_state = torch.zeros(
|
483 |
+
batch_size,
|
484 |
+
self.nheads,
|
485 |
+
self.headdim,
|
486 |
+
self.d_state,
|
487 |
+
device=self.in_proj.weight.device,
|
488 |
+
dtype=self.in_proj.weight.dtype,
|
489 |
+
)
|
490 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (
|
491 |
+
conv_state,
|
492 |
+
ssm_state,
|
493 |
+
)
|
494 |
+
else:
|
495 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[
|
496 |
+
self.layer_idx
|
497 |
+
]
|
498 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
499 |
+
if initialize_states:
|
500 |
+
conv_state.zero_()
|
501 |
+
ssm_state.zero_()
|
502 |
+
return conv_state, ssm_state
|
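For orientation, here is a minimal usage sketch of the Mamba2 block above. It is not part of the diff; it assumes the built wheel is importable as `mamba_ssm` under the module path shown in this file listing and that a CUDA device with the compiled Triton/CUDA kernels is available, since both the fused and non-fused forward paths dispatch to them.

# Hypothetical usage sketch (assumes CUDA + the kernels built in this commit).
import torch
from mamba_ssm.modules.mamba2 import Mamba2

block = Mamba2(d_model=256, d_state=64, headdim=64, layer_idx=0).to("cuda", dtype=torch.bfloat16)
u = torch.randn(2, 128, 256, device="cuda", dtype=torch.bfloat16)  # (batch, seqlen, d_model)
y = block(u)  # same shape as u: (2, 128, 256)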
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba2_simple.py
ADDED
@@ -0,0 +1,229 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None

try:
    from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
    RMSNormGated, LayerNorm = None, None

from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined


class Mamba2Simple(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=128,
        ngroups=1,
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        learnable_init_states=False,
        activation="swish",
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
        # self.conv1d.weight._no_weight_decay = True

        if self.learnable_init_states:
            self.init_states = nn.Parameter(
                torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
            )
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
            *A_init_range
        )
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        # Extra normalization layer right before output projection
        assert RMSNormGated is not None
        self.norm = RMSNormGated(
            self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
        )

        self.out_proj = nn.Linear(
            self.d_inner, self.d_model, bias=bias, **factory_kwargs
        )

    def forward(self, u, seq_idx=None):
        """
        u: (B, L, D)
        Returns: same shape as u
        """
        batch, seqlen, dim = u.shape

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        initial_states = (
            repeat(self.init_states, "... -> b ...", b=batch)
            if self.learnable_init_states
            else None
        )
        dt_limit_kwargs = (
            {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
        )

        if self.use_mem_eff_path:
            # Fully fused path
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.eps,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=False,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
        else:
            z, xBC, dt = torch.split(
                zxbcdt,
                [
                    self.d_inner,
                    self.d_inner + 2 * self.ngroups * self.d_state,
                    self.nheads,
                ],
                dim=-1,
            )
            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)
            assert self.activation in ["silu", "swish"]

            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
                )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                xBC = xBC[:, :seqlen, :]
            else:
                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)

            # Split into 3 main branches: X, B, C
            # These correspond to V, K, Q respectively in the SSM/attention duality
            x, B, C = torch.split(
                xBC,
                [
                    self.d_inner,
                    self.ngroups * self.d_state,
                    self.ngroups * self.d_state,
                ],
                dim=-1,
            )
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                seq_idx=seq_idx,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            y = rearrange(y, "b l h p -> b l (h p)")

            # Multiply "gate" branch and apply extra normalization layer
            y = self.norm(y, z)
            out = self.out_proj(y)
        return out
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mamba_simple.py
ADDED
@@ -0,0 +1,339 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from ..ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class Mamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        self.in_proj = nn.Linear(
            self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
        )

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner, bias=True, **factory_kwargs
        )

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(
            self.d_inner, self.d_model, bias=bias, **factory_kwargs
        )

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if (
            self.use_fast_path
            and causal_conv1d_fn is not None
            and inference_params is None
        ):  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(
                    F.pad(x, (self.d_conv - x.shape[-1], 0))
                )  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(
                x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
            )
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert (
            hidden_states.shape[1] == 1
        ), "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(
                torch.roll(conv_state, shifts=-1, dims=-1)
            )  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(
                conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
            )  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state,
                x,
                dt,
                A,
                B,
                C,
                self.D,
                z=z,
                dt_bias=self.dt_proj.bias,
                dt_softplus=True,
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size,
            self.d_model * self.expand,
            self.d_conv,
            device=device,
            dtype=conv_dtype,
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size,
            self.d_model * self.expand,
            self.d_state,
            device=device,
            dtype=ssm_dtype,
        )
        return conv_state, ssm_state

    def _get_states_from_cache(
        self, inference_params, batch_size, initialize_states=False
    ):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (
                conv_state,
                ssm_state,
            )
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[
                self.layer_idx
            ]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
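A minimal usage sketch of the Mamba (v1) block above, for orientation only: it assumes a CUDA device and the selective-scan extension built in this commit, since both the fused and fallback paths call into it.

# Hypothetical usage sketch (assumes CUDA + the built selective-scan extension).
import torch
from mamba_ssm.modules.mamba_simple import Mamba

layer = Mamba(d_model=256, d_state=16, d_conv=4, expand=2, layer_idx=0).to("cuda")
x = torch.randn(2, 64, 256, device="cuda")  # (batch, seqlen, d_model)
y = layer(x)                                # (2, 64, 256), same shape as the input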
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mha.py
ADDED
@@ -0,0 +1,294 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    from flash_attn import flash_attn_with_kvcache
except ImportError:
    flash_attn_with_kvcache = None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    assert layer_idx in inference_params.key_value_memory_dict
    kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        head_dim=None,  # If None, use embed_dim // num_heads
        mlp_dim=0,
        qkv_proj_bias=True,
        out_proj_bias=True,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        d_conv=0,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_interleaved=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.layer_idx = layer_idx
        self.d_conv = d_conv
        self.rotary_emb_dim = rotary_emb_dim
        self.softmax_scale = softmax_scale
        self.causal = causal

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        if head_dim is None:
            assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
        self.mlp_dim = math.ceil(mlp_dim / 256) * 256
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        out_dim = self.head_dim * self.num_heads

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.d_conv > 0:
            self.conv1d = nn.Conv1d(
                qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
                **factory_kwargs
            )
        self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if self.d_conv > 0:
            conv_state = torch.zeros(
                batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
            )
        else:
            conv_state = None
        kv_cache = torch.empty(
            batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
        )
        return kv_cache, conv_state

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        if self.rotary_emb_dim > 0:
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
        kv_cache = kv_cache[:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.softmax_scale,
            causal=self.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            k, v = kv.unbind(dim=-3)
            k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
            v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
            return F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
            ).transpose(1, 2)
        else:
            batch = q.shape[0]
            kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
            kv_cache = kv_cache[:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.softmax_scale,
                causal=self.causal,
            )

    def forward(self, x, inference_params=None):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the is the sum of the sequence lengths in the batch.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
            inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
                x.shape[0], inference_params.max_seqlen, dtype=x.dtype
            )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        qkv = self.in_proj(x)
        if self.mlp_dim > 0:
            qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
            x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
            x_mlp = x_mlp_up * F.silu(x_mlp_gate)
        if self.d_conv > 0:
            # The inference code for conv1d is pretty messy, should clean it up
            if (inference_params is None or inference_params.seqlen_offset == 0):
                if causal_conv1d_fn is None:
                    qkv = rearrange(
                        self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
                    ).contiguous()
                else:
                    qkv = causal_conv1d_fn(
                        qkv.transpose(1, 2),
                        rearrange(self.conv1d.weight, "d 1 w -> d w"),
                        self.conv1d.bias
                    ).transpose(1, 2)
                if inference_params is not None:
                    _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
                    # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                    qkv_t = rearrange(qkv, "b l d -> b d l")
                    conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0)))  # Update state (B D W)
            else:
                _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
                assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
                qkv = qkv.squeeze(1)
                # Conv step
                if causal_conv1d_update is None:
                    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
                    conv_state[:, :, -1] = qkv
                    qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
                    if self.conv1d.bias is not None:
                        qkv = qkv + self.conv1d.bias
                else:
                    qkv = causal_conv1d_update(
                        qkv,
                        conv_state,
                        rearrange(self.conv1d.weight, "d 1 w -> d w"),
                        self.conv1d.bias
                    )
                qkv = qkv.unsqueeze(1)
        q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
        q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
        kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
        if (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
        ):
            if self.rotary_emb_dim > 0:
                q, kv = self.rotary_emb(
                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None:
                k, v = kv.unbind(dim=-3)
                k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
                v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
                context = F.scaled_dot_product_attention(
                    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
                ).transpose(1, 2)
            else:
                context = self._update_kvcache_attention(q, kv, inference_params)
        else:
            context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "... h d -> ... (h d)")
        if self.mlp_dim > 0:
            context = torch.cat([context, x_mlp], dim=-1)
        out = self.out_proj(context)
        return out
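A minimal sketch of the MHA module above, for orientation: with d_conv=0, rotary_emb_dim=0 and no inference_params, forward() falls back to F.scaled_dot_product_attention, so flash-attn and causal-conv1d are optional. It assumes the package in this listing is importable (importing it may still pull in the Triton-backed submodules).

# Hypothetical usage sketch for the MHA (GQA) module defined above.
import torch
from mamba_ssm.modules.mha import MHA

attn = MHA(embed_dim=256, num_heads=8, num_heads_kv=2, causal=True, layer_idx=0)
x = torch.randn(2, 32, 256)  # (batch, seqlen, embed_dim)
out = attn(x)                # (2, 32, 256)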
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/mlp.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F


class GatedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y, gate = y.chunk(2, dim=-1)
        y = y * self.activation(gate)
        y = self.fc2(y)
        return y
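A minimal sketch of the GatedMLP above (pure PyTorch, assuming only that the module is importable from the built package): the hidden width defaults to 8/3 of the input width, rounded up to a multiple of 128.

# Hypothetical usage sketch for GatedMLP.
import torch
from mamba_ssm.modules.mlp import GatedMLP

mlp = GatedMLP(in_features=256)  # hidden_features -> int(8 * 256 / 3) = 682, rounded up to 768
x = torch.randn(4, 256)
print(mlp(x).shape)              # torch.Size([4, 256])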
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/modules/ssd_minimal.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) 2024, Albert Gu and Tri Dao.
"""Minimal implementation of SSD.

This is the same as Listing 1 from the paper.
"""

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from ..ops.triton.ssd_combined import mamba_chunk_scan_combined


def segsum_unstable(x):
    """Naive segment sum calculation."""
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def segsum(x):
    """More stable segment sum calculation."""
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = [
        rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
    ]

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state


# Simple test
def test_correctness():
    torch.manual_seed(42)

    ## Dimensions
    # Denoted (B, T, Q, D, P) in the paper
    batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
    nheads = dim // headdim  # (H) in the paper
    ngroups = 1  # (G) in the paper
    dstate = 64  # (N) in the paper
    dtype = torch.float32
    device = "cuda"

    x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
    dt = F.softplus(
        torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
    ).requires_grad_()
    A = (
        -torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
    ).requires_grad_()
    B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    D = torch.randn(nheads, dtype=dtype, device=device)

    # Comparing fused version and minimal version
    y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
    y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)
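The minimal SSD scan above is plain PyTorch plus einops, so it can be exercised without the fused kernel. A sketch on random data, using the per-head B/C shapes from the docstring (importing the module still assumes the package and its Triton dependency are installed):

# Sketch exercising ssd_minimal_discrete with shapes from its docstring.
import torch
import torch.nn.functional as F
from mamba_ssm.modules.ssd_minimal import ssd_minimal_discrete

batch, seqlen, nheads, headdim, dstate, block_len = 1, 256, 4, 32, 16, 64
X = torch.randn(batch, seqlen, nheads, headdim)
dt = F.softplus(torch.randn(batch, seqlen, nheads) - 4)
A = -torch.exp(torch.rand(nheads))
B = torch.randn(batch, seqlen, nheads, dstate)
C = torch.randn(batch, seqlen, nheads, dstate)
# Fold dt into X and A as in test_correctness above, then run the chunked recurrence.
Y, final_state = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, B, C, block_len)
print(Y.shape, final_state.shape)  # (1, 256, 4, 32) and (1, 4, 32, 16)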
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/selective_scan_interface.py
ADDED
@@ -0,0 +1,659 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from ..utils.torch import custom_fwd, custom_bwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

from .triton.layer_norm import _layer_norm_fwd

from .._ops import ops


class SelectiveScanFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        u,
        delta,
        A,
        B,
        C,
        D=None,
        z=None,
        delta_bias=None,
        delta_softplus=False,
        return_last_state=False,
    ):
        if u.stride(-1) != 1:
            u = u.contiguous()
        if delta.stride(-1) != 1:
            delta = delta.contiguous()
        if D is not None:
            D = D.contiguous()
        if B.stride(-1) != 1:
            B = B.contiguous()
        if C.stride(-1) != 1:
            C = C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True
        out, x, *rest = ops.selective_scan_fwd(
            u, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
        if not ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
            return out if not return_last_state else (out, last_state)
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
            out_z = rest[0]
            return out_z if not return_last_state else (out_z, last_state)

    @staticmethod
    def backward(ctx, dout, *args):
        if not ctx.has_z:
            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
            z = None
            out = None
        else:
            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
            u,
            delta,
            A,
            B,
            C,
            D,
            z,
            delta_bias,
            dout,
            x,
            out,
            None,
            ctx.delta_softplus,
            False,  # option to recompute out_z, not used here
        )
        dz = rest[0] if ctx.has_z else None
        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        return (
            du,
            ddelta,
            dA,
            dB,
            dC,
            dD if D is not None else None,
            dz,
            ddelta_bias if delta_bias is not None else None,
            None,
            None,
        )


def rms_norm_forward(
    x,
    weight,
    bias,
    eps=1e-6,
    is_rms_norm=True,
):
    # x (b l) d
    if x.stride(-1) != 1:
        x = x.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y = _layer_norm_fwd(
        x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
    )[0]
    # y (b l) d
    return y


def selective_scan_fn(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(
        u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
    )


def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(
                rearrange(B.float(), "... (L two) -> ... L two", two=2)
            )
        if is_variable_C:
            C = torch.view_as_complex(
                rearrange(C.float(), "... (L two) -> ... L two", two=2)
            )
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)


class MambaInnerFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B=None,
        C=None,
        D=None,
        delta_bias=None,
        B_proj_bias=None,
        C_proj_bias=None,
        delta_softplus=True,
        checkpoint_lvl=1,
        b_rms_weight=None,
        c_rms_weight=None,
        dt_rms_weight=None,
        b_c_dt_rms_eps=1e-6,
    ):
        """
        xz: (batch, dim, seqlen)
        """
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(
                dtype=torch.get_autocast_gpu_dtype()
            )
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (
                out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                if out_proj_bias is not None
                else None
            )
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(
            rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
        )  # (bl d)
        delta = rearrange(
            delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
        )
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(
                    B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(
                    C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
                ).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()

        if b_rms_weight is not None:
            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if c_rms_weight is not None:
            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
        if dt_rms_weight is not None:
            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
            delta = rms_norm_forward(
                delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
            )
            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()

        out, scan_intermediates, out_z = ops.selective_scan_fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.b_rms_weight = b_rms_weight
        ctx.c_rms_weight = c_rms_weight
        ctx.dt_rms_weight = dt_rms_weight
        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
        if (
            checkpoint_lvl >= 1
        ):  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        )
        return F.linear(
            rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
        )

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        # dout: (batch, seqlen, dim)
        assert (
            causal_conv1d_cuda is not None
        ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (
            xz,
            conv1d_weight,
            conv1d_bias,
            x_dbl,
            x_proj_weight,
            delta_proj_weight,
            out_proj_weight,
            conv1d_out,
            delta,
            A,
            B,
            C,
            D,
            delta_bias,
            scan_intermediates,
            b_rms_weight,
            c_rms_weight,
            dt_rms_weight,
            out,
        ) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(
                delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
            )
            if dt_rms_weight is not None:
                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
                delta = rms_norm_forward(
                    delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
                )
                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
            if b_rms_weight is not None:
                # Recompute & RMSNorm B
                B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            if c_rms_weight is not None:
                # Recompute & RMSNorm C
                C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
                C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()

        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
            ops.selective_scan_bwd(
                conv1d_out,
                delta,
                A,
                B,
                C,
                D,
                z,
                delta_bias,
                dout_y,
                scan_intermediates,
                out,
                dz,
                ctx.delta_softplus,
                True,  # option to recompute out_z
            )
        )
        dout_proj_weight = torch.einsum(
            "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
        )
        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(
                    dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(
                    dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
                ).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum(
            "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
        )
        dconv1d_out = torch.addmm(
            dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
        )
        dconv1d_out = rearrange(
            dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
        )
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x,
            conv1d_weight,
            conv1d_bias,
            dconv1d_out,
            None,
            None,
            None,
            dx,
            False,
            True,
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        return (
            dxz,
            dconv1d_weight,
            dconv1d_bias,
            dx_proj_weight,
            ddelta_proj_weight,
            dout_proj_weight,
            dout_proj_bias,
            dA,
            dB,
            dC,
            dD,
            ddelta_bias if delta_bias is not None else None,
            # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
            dB_proj_bias,
            dC_proj_bias,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def mamba_inner_fn(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
    checkpoint_lvl=1,
    b_rms_weight=None,
    c_rms_weight=None,
    dt_rms_weight=None,
    b_c_dt_rms_eps=1e-6,
):
    return MambaInnerFn.apply(
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
        checkpoint_lvl,
        b_rms_weight,
        c_rms_weight,
        dt_rms_weight,
        b_c_dt_rms_eps,
    )


def mamba_inner_ref(
    xz,
    conv1d_weight,
    conv1d_bias,
    x_proj_weight,
    delta_proj_weight,
    out_proj_weight,
    out_proj_bias,
    A,
    B=None,
    C=None,
    D=None,
    delta_bias=None,
    B_proj_bias=None,
    C_proj_bias=None,
    delta_softplus=True,
):
    assert (
        causal_conv1d_fn is not None
    ), "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(
        x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
    )
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(
                B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    if C is None:  # variable B
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(
                C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
            ).contiguous()
    y = selective_scan_fn(
        x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
    )
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
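For orientation, the docstring shapes above translate into a call like the following. This is a hedged usage sketch, not code from this build: the tensor sizes, dtype, tolerances, and the assumption of a CUDA device with the compiled extension importable as mamba_ssm are all illustrative choices.

    import torch
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref

    batch, dim, dstate, seqlen = 2, 64, 16, 128
    device, dtype = "cuda", torch.float32
    u = torch.randn(batch, dim, seqlen, device=device, dtype=dtype)       # r(B D L)
    delta = torch.rand(batch, dim, seqlen, device=device, dtype=dtype)    # r(B D L)
    A = -torch.rand(dim, dstate, device=device, dtype=torch.float32)      # r(D N)
    B = torch.randn(batch, dstate, seqlen, device=device, dtype=dtype)    # variable B: r(B N L)
    C = torch.randn(batch, dstate, seqlen, device=device, dtype=dtype)    # variable C: r(B N L)
    D = torch.randn(dim, device=device, dtype=torch.float32)              # r(D)

    # Fused CUDA path vs. the eager reference defined in this file
    out, last_state = selective_scan_fn(
        u, delta, A, B, C, D, delta_softplus=True, return_last_state=True
    )
    out_ref = selective_scan_ref(u, delta, A, B, C, D, delta_softplus=True)
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)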
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/k_activations.py
ADDED
@@ -0,0 +1,169 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import torch

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.jit
def _swiglu_fwd_kernel(
    X,
    Y,
    OUT,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_out_row,
    ncols,
    BLOCK_N: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    OUT += row * stride_out_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    out = x * tl.sigmoid(x) * y
    tl.store(OUT + cols, out, mask=cols < ncols)


def _swiglu_fwd(xy, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    if out is None:
        out = torch.empty_like(x)
    else:
        out = out.reshape(-1, out.shape[-1])
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
    return out.reshape(*batch_shape, out.shape[-1])


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
@triton.jit
def _swiglu_bwd_kernel(
    X,
    Y,
    DOUT,
    OUT,
    DX,
    DY,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dout_row,
    stride_out_row,
    stride_dx_row,
    stride_dy_row,
    ncols,
    BLOCK_N: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    DOUT += row * stride_dout_row
    if RECOMPUTE_OUTPUT:
        OUT += row * stride_out_row
    DX += row * stride_dx_row
    DY += row * stride_dy_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
    x_sigmoid = tl.sigmoid(x)
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    tl.store(DX + cols, dx, mask=cols < ncols)
    tl.store(DY + cols, dy, mask=cols < ncols)
    if RECOMPUTE_OUTPUT:
        out = x * x_sigmoid * y
        tl.store(OUT + cols, out, mask=cols < ncols)


def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    dout = dout.reshape(-1, dout.shape[-1])
    assert dout.shape == x.shape
    if dxy is None:
        dxy = torch.empty_like(xy)
    else:
        dxy = dxy.reshape(-1, dxy.shape[-1])
        assert dxy.shape == xy.shape
    dx, dy = dxy.chunk(2, dim=-1)
    assert dx.stride(-1) == 1
    assert dy.stride(-1) == 1
    if recompute_output:
        if out is None:
            out = torch.empty_like(x)
        else:
            out = out.reshape(-1, out.shape[-1])
            assert out.shape == x.shape
        assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
                                 x.stride(0), y.stride(0), dout.stride(0),
                                 out.stride(0) if recompute_output else 0,
                                 dx.stride(0), dy.stride(0),
                                 N)
    if not recompute_output:
        return dxy.reshape(*batch_shape, dxy.shape[-1])
    else:
        return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])


class SwiGLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, xy):
        ctx.save_for_backward(xy)
        return _swiglu_fwd(xy)

    @staticmethod
    def backward(ctx, dout):
        xy, = ctx.saved_tensors
        return _swiglu_bwd(xy, dout)


swiglu = SwiGLU.apply
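The swiglu function exported above takes a single tensor whose last dimension packs the x and y halves and returns silu(x) * y. A hedged usage sketch follows; the shapes, tolerances, and the assumption of a CUDA device with this build importable as mamba_ssm are illustrative, not taken from the repository.

    import torch
    import torch.nn.functional as F
    from mamba_ssm.ops.triton.k_activations import swiglu

    # Last dimension holds the x half followed by the y half
    xy = torch.randn(4, 1024, 2 * 512, device="cuda", dtype=torch.float32)

    out = swiglu(xy)                 # fused Triton kernel
    x, y = xy.chunk(2, dim=-1)
    out_ref = F.silu(x) * y          # eager-mode equivalent
    torch.testing.assert_close(out, out_ref, rtol=1e-3, atol=1e-3)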
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layer_norm.py
ADDED
@@ -0,0 +1,1166 @@
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# Implement dropout + residual + layer_norm / rms_norm.
|
3 |
+
|
4 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
5 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
6 |
+
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
7 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
8 |
+
|
9 |
+
import math
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from ...utils.torch import custom_bwd, custom_fwd
|
15 |
+
|
16 |
+
import triton
|
17 |
+
import triton.language as tl
|
18 |
+
|
19 |
+
|
20 |
+
def layer_norm_ref(
|
21 |
+
x,
|
22 |
+
weight,
|
23 |
+
bias,
|
24 |
+
residual=None,
|
25 |
+
x1=None,
|
26 |
+
weight1=None,
|
27 |
+
bias1=None,
|
28 |
+
eps=1e-6,
|
29 |
+
dropout_p=0.0,
|
30 |
+
rowscale=None,
|
31 |
+
prenorm=False,
|
32 |
+
dropout_mask=None,
|
33 |
+
dropout_mask1=None,
|
34 |
+
upcast=False,
|
35 |
+
):
|
36 |
+
dtype = x.dtype
|
37 |
+
if upcast:
|
38 |
+
x = x.float()
|
39 |
+
weight = weight.float()
|
40 |
+
bias = bias.float() if bias is not None else None
|
41 |
+
residual = residual.float() if residual is not None else residual
|
42 |
+
x1 = x1.float() if x1 is not None else None
|
43 |
+
weight1 = weight1.float() if weight1 is not None else None
|
44 |
+
bias1 = bias1.float() if bias1 is not None else None
|
45 |
+
if x1 is not None:
|
46 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
47 |
+
if rowscale is not None:
|
48 |
+
x = x * rowscale[..., None]
|
49 |
+
if dropout_p > 0.0:
|
50 |
+
if dropout_mask is not None:
|
51 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
52 |
+
else:
|
53 |
+
x = F.dropout(x, p=dropout_p)
|
54 |
+
if x1 is not None:
|
55 |
+
if dropout_mask1 is not None:
|
56 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
57 |
+
else:
|
58 |
+
x1 = F.dropout(x1, p=dropout_p)
|
59 |
+
if x1 is not None:
|
60 |
+
x = x + x1
|
61 |
+
if residual is not None:
|
62 |
+
x = (x + residual).to(x.dtype)
|
63 |
+
out = F.layer_norm(
|
64 |
+
x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
|
65 |
+
).to(dtype)
|
66 |
+
if weight1 is None:
|
67 |
+
return out if not prenorm else (out, x)
|
68 |
+
else:
|
69 |
+
out1 = F.layer_norm(
|
70 |
+
x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
|
71 |
+
).to(dtype)
|
72 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
73 |
+
|
74 |
+
|
75 |
+
def rms_norm_ref(
|
76 |
+
x,
|
77 |
+
weight,
|
78 |
+
bias,
|
79 |
+
residual=None,
|
80 |
+
x1=None,
|
81 |
+
weight1=None,
|
82 |
+
bias1=None,
|
83 |
+
eps=1e-6,
|
84 |
+
dropout_p=0.0,
|
85 |
+
rowscale=None,
|
86 |
+
prenorm=False,
|
87 |
+
dropout_mask=None,
|
88 |
+
dropout_mask1=None,
|
89 |
+
upcast=False,
|
90 |
+
):
|
91 |
+
dtype = x.dtype
|
92 |
+
if upcast:
|
93 |
+
x = x.float()
|
94 |
+
weight = weight.float()
|
95 |
+
bias = bias.float() if bias is not None else None
|
96 |
+
residual = residual.float() if residual is not None else residual
|
97 |
+
x1 = x1.float() if x1 is not None else None
|
98 |
+
weight1 = weight1.float() if weight1 is not None else None
|
99 |
+
bias1 = bias1.float() if bias1 is not None else None
|
100 |
+
if x1 is not None:
|
101 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
102 |
+
if rowscale is not None:
|
103 |
+
x = x * rowscale[..., None]
|
104 |
+
if dropout_p > 0.0:
|
105 |
+
if dropout_mask is not None:
|
106 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
107 |
+
else:
|
108 |
+
x = F.dropout(x, p=dropout_p)
|
109 |
+
if x1 is not None:
|
110 |
+
if dropout_mask1 is not None:
|
111 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
112 |
+
else:
|
113 |
+
x1 = F.dropout(x1, p=dropout_p)
|
114 |
+
if x1 is not None:
|
115 |
+
x = x + x1
|
116 |
+
if residual is not None:
|
117 |
+
x = (x + residual).to(x.dtype)
|
118 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
119 |
+
out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
|
120 |
+
dtype
|
121 |
+
)
|
122 |
+
if weight1 is None:
|
123 |
+
return out if not prenorm else (out, x)
|
124 |
+
else:
|
125 |
+
out1 = (
|
126 |
+
(x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
|
127 |
+
).to(dtype)
|
128 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
129 |
+
|
130 |
+
|
131 |
+
def config_prune(configs):
|
132 |
+
|
133 |
+
if torch.version.hip:
|
134 |
+
try:
|
135 |
+
# set warp size based on gcn architecure
|
136 |
+
gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
|
137 |
+
if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
|
138 |
+
# radeon
|
139 |
+
warp_size = 32
|
140 |
+
else:
|
141 |
+
# instinct
|
142 |
+
warp_size = 64
|
143 |
+
except AttributeError as e:
|
144 |
+
# fall back to crude method to set warp size
|
145 |
+
device_name = torch.cuda.get_device_properties(0).name
|
146 |
+
if "instinct" in device_name.lower():
|
147 |
+
warp_size = 64
|
148 |
+
else:
|
149 |
+
warp_size = 32
|
150 |
+
warnings.warn(
|
151 |
+
f"{e}, warp size set to {warp_size} based on device name: {device_name}",
|
152 |
+
UserWarning,
|
153 |
+
)
|
154 |
+
|
155 |
+
else:
|
156 |
+
# cuda
|
157 |
+
warp_size = 32
|
158 |
+
|
159 |
+
max_block_sz = 1024
|
160 |
+
max_num_warps = max_block_sz // warp_size
|
161 |
+
pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
|
162 |
+
return pruned_configs
|
163 |
+
|
164 |
+
|
165 |
+
configs_autotune = [
|
166 |
+
triton.Config({}, num_warps=1),
|
167 |
+
triton.Config({}, num_warps=2),
|
168 |
+
triton.Config({}, num_warps=4),
|
169 |
+
triton.Config({}, num_warps=8),
|
170 |
+
triton.Config({}, num_warps=16),
|
171 |
+
triton.Config({}, num_warps=32),
|
172 |
+
]
|
173 |
+
|
174 |
+
pruned_configs_autotune = config_prune(configs_autotune)
|
175 |
+
|
176 |
+
|
177 |
+
@triton.autotune(
|
178 |
+
configs=pruned_configs_autotune,
|
179 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
180 |
+
)
|
181 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
182 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
183 |
+
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
|
184 |
+
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
|
185 |
+
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
|
186 |
+
@triton.jit
|
187 |
+
def _layer_norm_fwd_1pass_kernel(
|
188 |
+
X, # pointer to the input
|
189 |
+
Y, # pointer to the output
|
190 |
+
W, # pointer to the weights
|
191 |
+
B, # pointer to the biases
|
192 |
+
RESIDUAL, # pointer to the residual
|
193 |
+
X1,
|
194 |
+
W1,
|
195 |
+
B1,
|
196 |
+
Y1,
|
197 |
+
RESIDUAL_OUT, # pointer to the residual
|
198 |
+
ROWSCALE,
|
199 |
+
SEEDS, # Dropout seeds for each row
|
200 |
+
DROPOUT_MASK,
|
201 |
+
Mean, # pointer to the mean
|
202 |
+
Rstd, # pointer to the 1/std
|
203 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
204 |
+
stride_y_row,
|
205 |
+
stride_res_row,
|
206 |
+
stride_res_out_row,
|
207 |
+
stride_x1_row,
|
208 |
+
stride_y1_row,
|
209 |
+
M, # number of rows in X
|
210 |
+
N, # number of columns in X
|
211 |
+
eps, # epsilon to avoid division by zero
|
212 |
+
dropout_p, # Dropout probability
|
213 |
+
IS_RMS_NORM: tl.constexpr,
|
214 |
+
BLOCK_N: tl.constexpr,
|
215 |
+
HAS_RESIDUAL: tl.constexpr,
|
216 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
217 |
+
HAS_BIAS: tl.constexpr,
|
218 |
+
HAS_DROPOUT: tl.constexpr,
|
219 |
+
STORE_DROPOUT_MASK: tl.constexpr,
|
220 |
+
HAS_ROWSCALE: tl.constexpr,
|
221 |
+
HAS_X1: tl.constexpr,
|
222 |
+
HAS_W1: tl.constexpr,
|
223 |
+
HAS_B1: tl.constexpr,
|
224 |
+
):
|
225 |
+
# Map the program id to the row of X and Y it should compute.
|
226 |
+
row = tl.program_id(0)
|
227 |
+
X += row * stride_x_row
|
228 |
+
Y += row * stride_y_row
|
229 |
+
if HAS_RESIDUAL:
|
230 |
+
RESIDUAL += row * stride_res_row
|
231 |
+
if STORE_RESIDUAL_OUT:
|
232 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
233 |
+
if HAS_X1:
|
234 |
+
X1 += row * stride_x1_row
|
235 |
+
if HAS_W1:
|
236 |
+
Y1 += row * stride_y1_row
|
237 |
+
# Compute mean and variance
|
238 |
+
cols = tl.arange(0, BLOCK_N)
|
239 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
240 |
+
if HAS_ROWSCALE:
|
241 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
242 |
+
x *= rowscale
|
243 |
+
if HAS_DROPOUT:
|
244 |
+
# Compute dropout mask
|
245 |
+
# 7 rounds is good enough, and reduces register pressure
|
246 |
+
keep_mask = (
|
247 |
+
tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
248 |
+
)
|
249 |
+
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
|
250 |
+
if STORE_DROPOUT_MASK:
|
251 |
+
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
|
252 |
+
if HAS_X1:
|
253 |
+
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
|
254 |
+
if HAS_ROWSCALE:
|
255 |
+
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
|
256 |
+
x1 *= rowscale
|
257 |
+
if HAS_DROPOUT:
|
258 |
+
# Compute dropout mask
|
259 |
+
# 7 rounds is good enough, and reduces register pressure
|
260 |
+
keep_mask = (
|
261 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
|
262 |
+
> dropout_p
|
263 |
+
)
|
264 |
+
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
|
265 |
+
if STORE_DROPOUT_MASK:
|
266 |
+
tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
|
267 |
+
x += x1
|
268 |
+
if HAS_RESIDUAL:
|
269 |
+
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
270 |
+
x += residual
|
271 |
+
if STORE_RESIDUAL_OUT:
|
272 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
273 |
+
if not IS_RMS_NORM:
|
274 |
+
mean = tl.sum(x, axis=0) / N
|
275 |
+
tl.store(Mean + row, mean)
|
276 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
277 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
278 |
+
else:
|
279 |
+
xbar = tl.where(cols < N, x, 0.0)
|
280 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
281 |
+
rstd = 1 / tl.sqrt(var + eps)
|
282 |
+
tl.store(Rstd + row, rstd)
|
283 |
+
# Normalize and apply linear transformation
|
284 |
+
mask = cols < N
|
285 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
286 |
+
if HAS_BIAS:
|
287 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
288 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
289 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
290 |
+
# Write output
|
291 |
+
tl.store(Y + cols, y, mask=mask)
|
292 |
+
if HAS_W1:
|
293 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
294 |
+
if HAS_B1:
|
295 |
+
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
|
296 |
+
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
|
297 |
+
tl.store(Y1 + cols, y1, mask=mask)
|
298 |
+
|
299 |
+
|
300 |
+
def _layer_norm_fwd(
|
301 |
+
x,
|
302 |
+
weight,
|
303 |
+
bias,
|
304 |
+
eps,
|
305 |
+
residual=None,
|
306 |
+
x1=None,
|
307 |
+
weight1=None,
|
308 |
+
bias1=None,
|
309 |
+
dropout_p=0.0,
|
310 |
+
rowscale=None,
|
311 |
+
out_dtype=None,
|
312 |
+
residual_dtype=None,
|
313 |
+
is_rms_norm=False,
|
314 |
+
return_dropout_mask=False,
|
315 |
+
):
|
316 |
+
if residual is not None:
|
317 |
+
residual_dtype = residual.dtype
|
318 |
+
M, N = x.shape
|
319 |
+
assert x.stride(-1) == 1
|
320 |
+
if residual is not None:
|
321 |
+
assert residual.stride(-1) == 1
|
322 |
+
assert residual.shape == (M, N)
|
323 |
+
assert weight.shape == (N,)
|
324 |
+
assert weight.stride(-1) == 1
|
325 |
+
if bias is not None:
|
326 |
+
assert bias.stride(-1) == 1
|
327 |
+
assert bias.shape == (N,)
|
328 |
+
if x1 is not None:
|
329 |
+
assert x1.shape == x.shape
|
330 |
+
assert rowscale is None
|
331 |
+
assert x1.stride(-1) == 1
|
332 |
+
if weight1 is not None:
|
333 |
+
assert weight1.shape == (N,)
|
334 |
+
assert weight1.stride(-1) == 1
|
335 |
+
if bias1 is not None:
|
336 |
+
assert bias1.shape == (N,)
|
337 |
+
assert bias1.stride(-1) == 1
|
338 |
+
if rowscale is not None:
|
339 |
+
assert rowscale.is_contiguous()
|
340 |
+
assert rowscale.shape == (M,)
|
341 |
+
# allocate output
|
342 |
+
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
343 |
+
assert y.stride(-1) == 1
|
344 |
+
if weight1 is not None:
|
345 |
+
y1 = torch.empty_like(y)
|
346 |
+
assert y1.stride(-1) == 1
|
347 |
+
else:
|
348 |
+
y1 = None
|
349 |
+
if (
|
350 |
+
residual is not None
|
351 |
+
or (residual_dtype is not None and residual_dtype != x.dtype)
|
352 |
+
or dropout_p > 0.0
|
353 |
+
or rowscale is not None
|
354 |
+
or x1 is not None
|
355 |
+
):
|
356 |
+
residual_out = torch.empty(
|
357 |
+
M,
|
358 |
+
N,
|
359 |
+
device=x.device,
|
360 |
+
dtype=residual_dtype if residual_dtype is not None else x.dtype,
|
361 |
+
)
|
362 |
+
assert residual_out.stride(-1) == 1
|
363 |
+
else:
|
364 |
+
residual_out = None
|
365 |
+
mean = (
|
366 |
+
torch.empty((M,), dtype=torch.float32, device=x.device)
|
367 |
+
if not is_rms_norm
|
368 |
+
else None
|
369 |
+
)
|
370 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
371 |
+
if dropout_p > 0.0:
|
372 |
+
seeds = torch.randint(
|
373 |
+
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
seeds = None
|
377 |
+
if return_dropout_mask and dropout_p > 0.0:
|
378 |
+
dropout_mask = torch.empty(
|
379 |
+
M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
|
380 |
+
)
|
381 |
+
else:
|
382 |
+
dropout_mask = None
|
383 |
+
# Less than 64KB per feature: enqueue fused kernel
|
384 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
385 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
386 |
+
if N > BLOCK_N:
|
387 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
388 |
+
with torch.cuda.device(x.device.index):
|
389 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
390 |
+
x,
|
391 |
+
y,
|
392 |
+
weight,
|
393 |
+
bias,
|
394 |
+
residual,
|
395 |
+
x1,
|
396 |
+
weight1,
|
397 |
+
bias1,
|
398 |
+
y1,
|
399 |
+
residual_out,
|
400 |
+
rowscale,
|
401 |
+
seeds,
|
402 |
+
dropout_mask,
|
403 |
+
mean,
|
404 |
+
rstd,
|
405 |
+
x.stride(0),
|
406 |
+
y.stride(0),
|
407 |
+
residual.stride(0) if residual is not None else 0,
|
408 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
409 |
+
x1.stride(0) if x1 is not None else 0,
|
410 |
+
y1.stride(0) if y1 is not None else 0,
|
411 |
+
M,
|
412 |
+
N,
|
413 |
+
eps,
|
414 |
+
dropout_p,
|
415 |
+
is_rms_norm,
|
416 |
+
BLOCK_N,
|
417 |
+
residual is not None,
|
418 |
+
residual_out is not None,
|
419 |
+
bias is not None,
|
420 |
+
dropout_p > 0.0,
|
421 |
+
dropout_mask is not None,
|
422 |
+
rowscale is not None,
|
423 |
+
)
|
424 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
|
425 |
+
if dropout_mask is not None and x1 is not None:
|
426 |
+
dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
|
427 |
+
else:
|
428 |
+
dropout_mask1 = None
|
429 |
+
return (
|
430 |
+
y,
|
431 |
+
y1,
|
432 |
+
mean,
|
433 |
+
rstd,
|
434 |
+
residual_out if residual_out is not None else x,
|
435 |
+
seeds,
|
436 |
+
dropout_mask,
|
437 |
+
dropout_mask1,
|
438 |
+
)
|
439 |
+
|
440 |
+
|
441 |
+
@triton.autotune(
|
442 |
+
configs=pruned_configs_autotune,
|
443 |
+
key=[
|
444 |
+
"N",
|
445 |
+
"HAS_DRESIDUAL",
|
446 |
+
"STORE_DRESIDUAL",
|
447 |
+
"IS_RMS_NORM",
|
448 |
+
"HAS_BIAS",
|
449 |
+
"HAS_DROPOUT",
|
450 |
+
],
|
451 |
+
)
|
452 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
453 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
454 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
455 |
+
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
|
456 |
+
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
|
457 |
+
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
|
458 |
+
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
|
459 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
460 |
+
@triton.jit
|
461 |
+
def _layer_norm_bwd_kernel(
|
462 |
+
X, # pointer to the input
|
463 |
+
W, # pointer to the weights
|
464 |
+
B, # pointer to the biases
|
465 |
+
Y, # pointer to the output to be recomputed
|
466 |
+
DY, # pointer to the output gradient
|
467 |
+
DX, # pointer to the input gradient
|
468 |
+
DW, # pointer to the partial sum of weights gradient
|
469 |
+
DB, # pointer to the partial sum of biases gradient
|
470 |
+
DRESIDUAL,
|
471 |
+
W1,
|
472 |
+
DY1,
|
473 |
+
DX1,
|
474 |
+
DW1,
|
475 |
+
DB1,
|
476 |
+
DRESIDUAL_IN,
|
477 |
+
ROWSCALE,
|
478 |
+
SEEDS,
|
479 |
+
Mean, # pointer to the mean
|
480 |
+
Rstd, # pointer to the 1/std
|
481 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
482 |
+
stride_y_row,
|
483 |
+
stride_dy_row,
|
484 |
+
stride_dx_row,
|
485 |
+
stride_dres_row,
|
486 |
+
stride_dy1_row,
|
487 |
+
stride_dx1_row,
|
488 |
+
stride_dres_in_row,
|
489 |
+
M, # number of rows in X
|
490 |
+
N, # number of columns in X
|
491 |
+
eps, # epsilon to avoid division by zero
|
492 |
+
dropout_p,
|
493 |
+
rows_per_program,
|
494 |
+
IS_RMS_NORM: tl.constexpr,
|
495 |
+
BLOCK_N: tl.constexpr,
|
496 |
+
HAS_DRESIDUAL: tl.constexpr,
|
497 |
+
STORE_DRESIDUAL: tl.constexpr,
|
498 |
+
HAS_BIAS: tl.constexpr,
|
499 |
+
HAS_DROPOUT: tl.constexpr,
|
500 |
+
HAS_ROWSCALE: tl.constexpr,
|
501 |
+
HAS_DY1: tl.constexpr,
|
502 |
+
HAS_DX1: tl.constexpr,
|
503 |
+
HAS_B1: tl.constexpr,
|
504 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
505 |
+
):
|
506 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
507 |
+
row_block_id = tl.program_id(0)
|
508 |
+
row_start = row_block_id * rows_per_program
|
509 |
+
# Do not early exit if row_start >= M, because we need to write DW and DB
|
510 |
+
cols = tl.arange(0, BLOCK_N)
|
511 |
+
mask = cols < N
|
512 |
+
X += row_start * stride_x_row
|
513 |
+
if HAS_DRESIDUAL:
|
514 |
+
DRESIDUAL += row_start * stride_dres_row
|
515 |
+
if STORE_DRESIDUAL:
|
516 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
517 |
+
DY += row_start * stride_dy_row
|
518 |
+
DX += row_start * stride_dx_row
|
519 |
+
if HAS_DY1:
|
520 |
+
DY1 += row_start * stride_dy1_row
|
521 |
+
if HAS_DX1:
|
522 |
+
DX1 += row_start * stride_dx1_row
|
523 |
+
if RECOMPUTE_OUTPUT:
|
524 |
+
Y += row_start * stride_y_row
|
525 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
526 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
527 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
528 |
+
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
                    > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = (
                tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
                > dropout_p
            )
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = (
        torch.empty(M, N, dtype=dy.dtype, device=dy.device)
        if recompute_output
        else None
    )
    if recompute_output:
        assert (
            weight1 is None
        ), "recompute_output is not supported with parallel LayerNorm"

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )

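For orientation: `_layer_norm_bwd` reduces the weight and bias gradients in two stages. Each kernel program (one per SM) accumulates a partial `dw`/`db` for its block of rows into `_dw[row_block_id]`, and the host then sums over the program axis with `_dw.sum(0)`. The snippet below is a minimal plain-PyTorch sketch of that split-then-sum pattern; it is an illustration only, not part of the file, and the sizes and names are made up.

# Plain-PyTorch sketch of the split-then-sum reduction used for _dw / _db above.
# Illustration only; M, N, num_programs and the tensors are made-up examples.
import torch

M, N, num_programs = 1024, 512, 108       # 108 is a typical SM count (e.g. A100)
dy = torch.randn(M, N)
xhat = torch.randn(M, N)
rows_per_program = -(-M // num_programs)  # ceil division, as in the launch code
owner = torch.arange(M) // rows_per_program
partial_dw = torch.zeros(num_programs, N)
partial_dw.index_add_(0, owner, dy * xhat)   # per-program accumulation of dw
dw = partial_dw.sum(0)                       # host-side reduction, like _dw.sum(0)
assert torch.allclose(dw, (dy * xhat).sum(0), atol=1e-3)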
class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
            _layer_norm_fwd(
                x,
                weight,
                bias,
                eps,
                residual,
                x1,
                weight1,
                bias1,
                dropout_p=dropout_p,
                rowscale=rowscale,
                residual_dtype=residual_dtype,
                is_rms_norm=is_rms_norm,
                return_dropout_mask=return_dropout_mask,
            )
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = (
            residual_out.reshape(x_shape_og) if residual_out is not None else None
        )
        dropout_mask = (
            dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        )
        dropout_mask1 = (
            dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        )
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )


def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=(
                None
                if not torch.is_autocast_enabled()
                else torch.get_autocast_gpu_dtype()
            ),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = (
            torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        )
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(
            residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/layernorm_gated.py
ADDED
@@ -0,0 +1,437 @@
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange


def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
    dtype = x.dtype
    N = x.shape[-1]
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        z = z.float() if z is not None else z
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    if group_size is None:
        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    else:
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias
    if z is not None and norm_before_gate:
        out *= F.silu(z)
    return out.to(dtype)


@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
                                           x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
                                           M, group_size, eps,
                                           BLOCK_N=BLOCK_N,
                                           NORM_BEFORE_GATE=norm_before_gate,
                                           IS_RMS_NORM=is_rms_norm,
                                           num_warps=num_warps)
    return out, mean, rstd



@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DZ,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_z_row,
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dz_row,
    stride_dw_row,
    stride_db_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    rows_per_program,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    group = tl.program_id(1)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row + group * N
    if HAS_Z:
        Z += row_start * stride_z_row + group * N
        DZ += row_start * stride_dz_row + group * N
    DY += row_start * stride_dy_row + group * N
    DX += row_start * stride_dx_row + group * N
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
        B += group * N
        b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        if HAS_Z and not NORM_BEFORE_GATE:
            z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
            x_og = x
            x = x_og * z * tl.sigmoid(z)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.)
        if HAS_Z and NORM_BEFORE_GATE:
            z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
            z_sigmoid = tl.sigmoid(z)
            y = xhat * w + b if HAS_BIAS else xhat * w
            if RECOMPUTE_OUTPUT:
                tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
            dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
            tl.store(DZ + cols, dz, mask=mask)
            dy *= z * z_sigmoid
        else:
            if RECOMPUTE_OUTPUT:
                y = xhat * w + b if HAS_BIAS else xhat * w
                tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        c1 = tl.sum(xhat * wdy, axis=0) / N
        if not IS_RMS_NORM:
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            dx = (wdy - xhat * c1) * rstd
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_Z and not NORM_BEFORE_GATE:
            z_sigmoid = tl.sigmoid(z)
            dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
            tl.store(DZ + cols, dz, mask=mask)
            dx *= z * z_sigmoid
        # Write dx
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_Z:
            Z += stride_z_row
            DZ += stride_dz_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
    tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)


def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
                    norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    dx = torch.empty_like(x)
    if dz is not None:
        assert z is not None
        assert dz.shape == z.shape
        assert dz.stride(-1) == 1
    else:
        dz = torch.empty_like(z) if z is not None else None
    if recompute_output:
        if out is None:
            out = torch.empty_like(x)
        assert out.shape == x.shape

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
    # would limit the occupancy.
    nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
    _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
    _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
    rows_per_program = math.ceil(M / nrow_groups)
    grid = (nrow_groups, ngroups)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
                                     dy, dx, _dw, _db, dz, mean, rstd,
                                     x.stride(0),
                                     z.stride(0) if z is not None else 0,
                                     0 if not recompute_output else out.stride(0),
                                     dy.stride(0), dx.stride(0),
                                     dz.stride(0) if dz is not None else 0,
                                     _dw.stride(0),
                                     _db.stride(0) if _db is not None else 0,
                                     M, group_size, eps,
                                     rows_per_program,
                                     BLOCK_N=BLOCK_N,
                                     NORM_BEFORE_GATE=norm_before_gate,
                                     IS_RMS_NORM=is_rms_norm,
                                     num_warps=num_warps)
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)


class LayerNormFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
                is_rms_norm=False):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """

        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.group_size = group_size
        ctx.norm_before_gate = norm_before_gate
        ctx.is_rms_norm = is_rms_norm
        return y.reshape(x_shape_og)

    @staticmethod
    def backward(ctx, dy):
        x, weight, bias, mean, rstd, z = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
                                         ctx.norm_before_gate, ctx.is_rms_norm)
        return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None


def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)


def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)


class LayerNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """
        return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
                            norm_before_gate=self.norm_before_gate)


class RMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """
        return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
                          norm_before_gate=self.norm_before_gate)
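A small equivalence sketch for the gated, grouped RMSNorm defined above: the fused Triton path should match the eager `rms_norm_ref` for the same gating order and group size. This is an illustration only; it assumes a CUDA device and the `mamba_ssm.ops.triton.layernorm_gated` import path implied by this build's directory layout.

import torch
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm, rms_norm_ref

x = torch.randn(4, 512, device="cuda")
z = torch.randn_like(x)                      # gate branch, applied as silu(z)
norm = RMSNorm(512, group_size=128, norm_before_gate=True, device="cuda")

y_fused = norm(x, z=z)                       # Triton path: norm(x) * silu(z), per 128-wide group
y_ref = rms_norm_ref(x, norm.weight, norm.bias, z=z, eps=norm.eps,
                     group_size=128, norm_before_gate=True)
print(torch.allclose(y_fused, y_ref, atol=1e-4, rtol=1e-3))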
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/selective_state_update.py
ADDED
@@ -0,0 +1,389 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat

from .softplus import softplus


@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
    {
        "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
        is not None
    }
)
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    state_batch_indices_ptr,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_STATE_BATCH_INDICES: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)

    if HAS_STATE_BATCH_INDICES:
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr)
        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
    else:
        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head

    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (
        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
    )
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (
        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
    )
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(
        state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
    )
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, softplus(dt), dt)
        A = tl.load(
            A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, softplus(dt), dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    if not TIE_HDIM:
        dB = B[None, :] * dt[:, None]
    else:
        dB = B * dt  # vector of size (dstate,)
    state = state * dA + dB * x[:, None]
    tl.store(
        state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
    )
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)


def selective_state_update(
    state,
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    state_batch_indices=None,
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    _, nheads, dim, dstate = state.shape
    batch = x.shape[0]
    if x.shape != (batch, nheads, dim):
        print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    if state_batch_indices is not None:
        assert state_batch_indices.shape == (batch,)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
    z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = (
        (32, 4)
        if dstate <= 16
        else (
            (16, 4)
            if dstate <= 32
            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
        )
    )
    tie_hdim = (
        A.stride(-1) == 0
        and A.stride(-2) == 0
        and dt.stride(-1) == 0
        and dt_bias.stride(-1) == 0
    )
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            state_batch_indices,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            *(D.stride(0), D.stride(1)) if D is not None else 0,
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out


def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(
        rearrange(dt, "b h d -> b h d 1") * A
    )  # (batch, nheads, dim, dstate)
    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
        B, "b h n -> b h 1 n"
    )  # (batch, nheads, dim, dstate)
    state.copy_(
        state * dA + dB * rearrange(x, "b h d -> b h d 1")
    )  # (batch, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
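A single decoding-step sketch for `selective_state_update`, illustration only; it assumes a CUDA device and the `mamba_ssm.ops.triton.selective_state_update` import path implied by this build's layout. Shapes follow the docstring above: the kernel updates `state` in place via state = state * exp(dt * A) + (dt * B) * x and returns the C-weighted sum over the state dimension, which the eager reference reproduces.

import torch
from mamba_ssm.ops.triton.selective_state_update import (
    selective_state_update,
    selective_state_update_ref,
)

batch, nheads, dim, dstate, ngroups = 2, 8, 64, 16, 1
state = torch.randn(batch, nheads, dim, dstate, device="cuda")
x = torch.randn(batch, nheads, dim, device="cuda")
dt = torch.rand(batch, nheads, dim, device="cuda")
A = -torch.rand(nheads, dim, dstate, device="cuda")
B = torch.randn(batch, ngroups, dstate, device="cuda")
C = torch.randn(batch, ngroups, dstate, device="cuda")

state_ref = state.clone()
out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, dt_softplus=True)
print(torch.allclose(out, out_ref, atol=1e-3),
      torch.allclose(state, state_ref, atol=1e-3))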
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/softplus.py
ADDED
@@ -0,0 +1,15 @@
import triton
import triton.language as tl
from packaging import version

TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")


if TRITON3:
    @triton.jit
    def softplus(dt):
        return tl.math.log(tl.math.exp(dt) + 1)
else:
    @triton.jit
    def softplus(dt):
        return tl.math.log1p(tl.exp(dt))
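Both branches above compute softplus(dt) = log(1 + exp(dt)); the version check only selects the math intrinsic available in the installed Triton. Callers in this package guard large inputs with `tl.where(dt <= 20.0, softplus(dt), dt)`, which matches the thresholded eager form sketched here (plain PyTorch, illustration only):

import torch

dt = torch.linspace(-30.0, 30.0, steps=7)
guarded = torch.where(dt <= 20.0, torch.log1p(torch.exp(dt)), dt)
print(torch.allclose(guarded, torch.nn.functional.softplus(dt, threshold=20)))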
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_bmm.py
ADDED
@@ -0,0 +1,262 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat


def init_to_zero(names):
    return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['chunk_size', 'K', 'IS_CAUSAL'],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, out_ptr, seq_idx_ptr,
    # Matrix dimensions
    seqlen, chunk_size, K, ngroups,
    stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
    stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    dot_dtype: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_ch = tl.program_id(axis=2)
    pid_c = pid_ch // ngroups
    pid_h = pid_ch - pid_c * ngroups
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    if IS_CAUSAL:
        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
            return
    a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_SEQ_IDX:
        chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
    out = acc.to(out_ptr.dtype.element_ty)

    out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
|
103 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
|
104 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
|
105 |
+
],
|
106 |
+
key=['chunk_size', 'K'],
|
107 |
+
)
|
108 |
+
@triton.jit
|
109 |
+
def _bmm_chunk_bwd_kernel(
|
110 |
+
# Pointers to matrices
|
111 |
+
a_ptr, dout_ptr, db_ptr, res_ptr,
|
112 |
+
# Matrix dimensions
|
113 |
+
seqlen, chunk_size, K, ngroups,
|
114 |
+
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
|
115 |
+
stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
|
116 |
+
stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
|
117 |
+
stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
|
118 |
+
# Meta-parameters
|
119 |
+
dot_dtype: tl.constexpr,
|
120 |
+
HAS_RESIDUAL: tl.constexpr,
|
121 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
|
122 |
+
):
|
123 |
+
pid_b = tl.program_id(axis=1)
|
124 |
+
pid_ch = tl.program_id(axis=2)
|
125 |
+
pid_c = pid_ch // ngroups
|
126 |
+
pid_h = pid_ch - pid_c * ngroups
|
127 |
+
num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
|
128 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
129 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
130 |
+
|
131 |
+
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
132 |
+
dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
|
133 |
+
|
134 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
135 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
136 |
+
offs_cs = tl.arange(0, BLOCK_SIZE_CS)
|
137 |
+
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
|
138 |
+
a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
|
139 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
140 |
+
|
141 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
142 |
+
for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
|
143 |
+
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
|
144 |
+
a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
|
145 |
+
acc += tl.dot(dout, a)
|
146 |
+
dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
|
147 |
+
a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
|
148 |
+
|
149 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
150 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
151 |
+
if HAS_RESIDUAL:
|
152 |
+
res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
|
153 |
+
res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
|
154 |
+
res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
|
155 |
+
acc += res
|
156 |
+
db = acc.to(db_ptr.dtype.element_ty)
|
157 |
+
|
158 |
+
db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
|
159 |
+
db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
|
160 |
+
tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
|
161 |
+
|
162 |
+
|
163 |
+
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
|
164 |
+
"""
|
165 |
+
Argument:
|
166 |
+
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
167 |
+
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
168 |
+
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
|
169 |
+
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
170 |
+
guaranteed to be correct.
|
171 |
+
Return:
|
172 |
+
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
173 |
+
"""
|
174 |
+
# Check constraints.
|
175 |
+
has_groups = a.dim() == 4
|
176 |
+
if not has_groups:
|
177 |
+
batch, seqlen, k = a.shape
|
178 |
+
else:
|
179 |
+
batch, seqlen, ngroups, k = a.shape
|
180 |
+
assert b.shape == a.shape
|
181 |
+
if seq_idx is not None:
|
182 |
+
assert seq_idx.shape == (batch, seqlen)
|
183 |
+
if a.stride(-1) != 1 and a.stride(1) != 1:
|
184 |
+
a = a.contiguous()
|
185 |
+
if b.stride(-1) != 1 and b.stride(1) != 1:
|
186 |
+
b = b.contiguous()
|
187 |
+
nchunks = math.ceil(seqlen / chunk_size)
|
188 |
+
# Allocates output.
|
189 |
+
out_dtype = a.dtype if output_dtype is None else output_dtype
|
190 |
+
out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
|
191 |
+
device=a.device, dtype=out_dtype)
|
192 |
+
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
|
193 |
+
(tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
|
194 |
+
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
|
195 |
+
batch, nchunks if not has_groups else nchunks * ngroups)
|
196 |
+
with torch.cuda.device(a.device.index):
|
197 |
+
_bmm_chunk_fwd_kernel[grid](
|
198 |
+
a, b, out, seq_idx,
|
199 |
+
seqlen, chunk_size, k, ngroups if has_groups else 1,
|
200 |
+
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
|
201 |
+
b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
|
202 |
+
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
|
203 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
204 |
+
causal,
|
205 |
+
dot_dtype,
|
206 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
207 |
+
)
|
208 |
+
return out
|
209 |
+
|
210 |
+
|
211 |
+
def _bmm_chunk_bwd(a, dout, residual=None, out=None):
|
212 |
+
"""
|
213 |
+
Argument:
|
214 |
+
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
215 |
+
dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
216 |
+
residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
217 |
+
Return:
|
218 |
+
out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
219 |
+
|
220 |
+
If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
|
221 |
+
zeroed out before calling this function.
|
222 |
+
"""
|
223 |
+
# Check constraints.
|
224 |
+
has_groups = a.dim() == 4
|
225 |
+
if not has_groups:
|
226 |
+
batch, seqlen, k = a.shape
|
227 |
+
else:
|
228 |
+
batch, seqlen, ngroups, k = a.shape
|
229 |
+
nchunks, chunk_size = dout.shape[1], dout.shape[-1]
|
230 |
+
if a.stride(-1) != 1 and a.stride(-2) != 1:
|
231 |
+
a = a.contiguous()
|
232 |
+
if dout.stride(-1) != 1 and dout.stride(-2) != 1:
|
233 |
+
dout = dout.contiguous()
|
234 |
+
if residual is not None:
|
235 |
+
assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
|
236 |
+
if residual.stride(-1) != 1 and residual.stride(1) != 1:
|
237 |
+
residual = residual.contiguous()
|
238 |
+
# Allocates output.
|
239 |
+
if out is not None:
|
240 |
+
assert out.shape == a.shape
|
241 |
+
assert out.stride(-1) == 1 or out.stride(1) == 1
|
242 |
+
else:
|
243 |
+
out = torch.empty_like(a)
|
244 |
+
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
|
245 |
+
(tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
|
246 |
+
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
|
247 |
+
nchunks if not has_groups else nchunks * ngroups)
|
248 |
+
residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
|
249 |
+
residual.stride(-1))
|
250 |
+
if residual is not None else (0, 0, 0, 0))
|
251 |
+
with torch.cuda.device(a.device.index):
|
252 |
+
_bmm_chunk_bwd_kernel[grid](
|
253 |
+
a, dout, out, residual,
|
254 |
+
seqlen, chunk_size, k, ngroups if has_groups else 1,
|
255 |
+
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
|
256 |
+
dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
|
257 |
+
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
|
258 |
+
residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
|
259 |
+
dot_dtype,
|
260 |
+
HAS_RESIDUAL=residual is not None,
|
261 |
+
)
|
262 |
+
return out
|
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_scan.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_chunk_state.py
ADDED
@@ -0,0 +1,2012 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or 2.2.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import triton
|
11 |
+
import triton.language as tl
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
from .softplus import softplus
|
16 |
+
|
17 |
+
|
18 |
+
def init_to_zero(names):
|
19 |
+
return lambda nargs: [
|
20 |
+
nargs[name].zero_() for name in names if nargs[name] is not None
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
@triton.autotune(
|
25 |
+
configs=[
|
26 |
+
triton.Config({"BLOCK_SIZE_H": 1}),
|
27 |
+
triton.Config({"BLOCK_SIZE_H": 2}),
|
28 |
+
triton.Config({"BLOCK_SIZE_H": 4}),
|
29 |
+
triton.Config({"BLOCK_SIZE_H": 8}),
|
30 |
+
triton.Config({"BLOCK_SIZE_H": 16}),
|
31 |
+
triton.Config({"BLOCK_SIZE_H": 32}),
|
32 |
+
triton.Config({"BLOCK_SIZE_H": 64}),
|
33 |
+
],
|
34 |
+
key=["chunk_size", "nheads"],
|
35 |
+
)
|
36 |
+
@triton.jit
|
37 |
+
def _chunk_cumsum_fwd_kernel(
|
38 |
+
# Pointers to matrices
|
39 |
+
dt_ptr,
|
40 |
+
A_ptr,
|
41 |
+
dt_bias_ptr,
|
42 |
+
dt_out_ptr,
|
43 |
+
dA_cumsum_ptr,
|
44 |
+
# Matrix dimension
|
45 |
+
batch,
|
46 |
+
seqlen,
|
47 |
+
nheads,
|
48 |
+
chunk_size,
|
49 |
+
dt_min,
|
50 |
+
dt_max,
|
51 |
+
# Strides
|
52 |
+
stride_dt_batch,
|
53 |
+
stride_dt_seqlen,
|
54 |
+
stride_dt_head,
|
55 |
+
stride_A_head,
|
56 |
+
stride_dt_bias_head,
|
57 |
+
stride_dt_out_batch,
|
58 |
+
stride_dt_out_chunk,
|
59 |
+
stride_dt_out_head,
|
60 |
+
stride_dt_out_csize,
|
61 |
+
stride_dA_cs_batch,
|
62 |
+
stride_dA_cs_chunk,
|
63 |
+
stride_dA_cs_head,
|
64 |
+
stride_dA_cs_csize,
|
65 |
+
# Meta-parameters
|
66 |
+
DT_SOFTPLUS: tl.constexpr,
|
67 |
+
HAS_DT_BIAS: tl.constexpr,
|
68 |
+
BLOCK_SIZE_H: tl.constexpr,
|
69 |
+
BLOCK_SIZE_CHUNK: tl.constexpr,
|
70 |
+
):
|
71 |
+
pid_b = tl.program_id(axis=0)
|
72 |
+
pid_c = tl.program_id(axis=1)
|
73 |
+
pid_h = tl.program_id(axis=2)
|
74 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
75 |
+
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
76 |
+
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
77 |
+
|
78 |
+
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
79 |
+
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
80 |
+
dt_ptrs = dt_ptr + (
|
81 |
+
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
82 |
+
)
|
83 |
+
A_ptrs = A_ptr + offs_h * stride_A_head
|
84 |
+
dt_out_ptrs = dt_out_ptr + (
|
85 |
+
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
|
86 |
+
)
|
87 |
+
dA_cs_ptrs = dA_cumsum_ptr + (
|
88 |
+
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
|
89 |
+
)
|
90 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
91 |
+
|
92 |
+
dt = tl.load(
|
93 |
+
dt_ptrs,
|
94 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
95 |
+
other=0.0,
|
96 |
+
).to(tl.float32)
|
97 |
+
if HAS_DT_BIAS:
|
98 |
+
dt_bias = tl.load(
|
99 |
+
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
100 |
+
).to(tl.float32)
|
101 |
+
dt += dt_bias[:, None]
|
102 |
+
if DT_SOFTPLUS:
|
103 |
+
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
104 |
+
# As of Triton 2.2.0, tl.clamp is not available yet
|
105 |
+
# dt = tl.clamp(dt, dt_min, dt_max)
|
106 |
+
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
107 |
+
dt = tl.where(
|
108 |
+
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
109 |
+
)
|
110 |
+
tl.store(
|
111 |
+
dt_out_ptrs,
|
112 |
+
dt,
|
113 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
114 |
+
)
|
115 |
+
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
116 |
+
dA = dt * A[:, None]
|
117 |
+
dA_cs = tl.cumsum(dA, axis=1)
|
118 |
+
tl.store(
|
119 |
+
dA_cs_ptrs,
|
120 |
+
dA_cs,
|
121 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
@triton.autotune(
|
126 |
+
configs=[
|
127 |
+
triton.Config(
|
128 |
+
{"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
129 |
+
),
|
130 |
+
triton.Config(
|
131 |
+
{"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
132 |
+
),
|
133 |
+
triton.Config(
|
134 |
+
{"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
135 |
+
),
|
136 |
+
triton.Config(
|
137 |
+
{"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
138 |
+
),
|
139 |
+
triton.Config(
|
140 |
+
{"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
141 |
+
),
|
142 |
+
triton.Config(
|
143 |
+
{"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
144 |
+
),
|
145 |
+
triton.Config(
|
146 |
+
{"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
147 |
+
),
|
148 |
+
],
|
149 |
+
key=["chunk_size", "nheads"],
|
150 |
+
)
|
151 |
+
@triton.jit
|
152 |
+
def _chunk_cumsum_bwd_kernel(
|
153 |
+
# Pointers to matrices
|
154 |
+
ddA_ptr,
|
155 |
+
ddt_out_ptr,
|
156 |
+
dt_ptr,
|
157 |
+
A_ptr,
|
158 |
+
dt_bias_ptr,
|
159 |
+
ddt_ptr,
|
160 |
+
dA_ptr,
|
161 |
+
ddt_bias_ptr,
|
162 |
+
# Matrix dimensions
|
163 |
+
batch,
|
164 |
+
seqlen,
|
165 |
+
nheads,
|
166 |
+
chunk_size,
|
167 |
+
dt_min,
|
168 |
+
dt_max,
|
169 |
+
# Strides
|
170 |
+
stride_ddA_batch,
|
171 |
+
stride_ddA_chunk,
|
172 |
+
stride_ddA_head,
|
173 |
+
stride_ddA_csize,
|
174 |
+
stride_ddt_out_batch,
|
175 |
+
stride_ddt_out_chunk,
|
176 |
+
stride_ddt_out_head,
|
177 |
+
stride_ddt_out_csize,
|
178 |
+
stride_dt_batch,
|
179 |
+
stride_dt_seqlen,
|
180 |
+
stride_dt_head,
|
181 |
+
stride_A_head,
|
182 |
+
stride_dt_bias_head,
|
183 |
+
stride_ddt_batch,
|
184 |
+
stride_ddt_seqlen,
|
185 |
+
stride_ddt_head,
|
186 |
+
stride_dA_head,
|
187 |
+
stride_ddt_bias_head,
|
188 |
+
# Meta-parameters
|
189 |
+
DT_SOFTPLUS: tl.constexpr,
|
190 |
+
HAS_DT_BIAS: tl.constexpr,
|
191 |
+
BLOCK_SIZE_H: tl.constexpr,
|
192 |
+
BLOCK_SIZE_CHUNK: tl.constexpr,
|
193 |
+
):
|
194 |
+
pid_b = tl.program_id(axis=0)
|
195 |
+
pid_c = tl.program_id(axis=1)
|
196 |
+
pid_h = tl.program_id(axis=2)
|
197 |
+
ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
|
198 |
+
ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
|
199 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
200 |
+
ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
|
201 |
+
|
202 |
+
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
203 |
+
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
204 |
+
ddt_out_ptrs = ddt_out_ptr + (
|
205 |
+
offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
|
206 |
+
)
|
207 |
+
ddA_ptrs = ddA_ptr + (
|
208 |
+
offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
|
209 |
+
)
|
210 |
+
dt_ptrs = dt_ptr + (
|
211 |
+
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
212 |
+
)
|
213 |
+
ddt_ptrs = ddt_ptr + (
|
214 |
+
offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
|
215 |
+
)
|
216 |
+
A_ptrs = A_ptr + offs_h * stride_A_head
|
217 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
218 |
+
|
219 |
+
ddA = tl.load(
|
220 |
+
ddA_ptrs,
|
221 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
222 |
+
other=0.0,
|
223 |
+
).to(tl.float32)
|
224 |
+
ddt_out = tl.load(
|
225 |
+
ddt_out_ptrs,
|
226 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
227 |
+
other=0.0,
|
228 |
+
).to(tl.float32)
|
229 |
+
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
230 |
+
ddt = ddA * A[:, None] + ddt_out
|
231 |
+
dt = tl.load(
|
232 |
+
dt_ptrs,
|
233 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
234 |
+
other=0.0,
|
235 |
+
).to(tl.float32)
|
236 |
+
if HAS_DT_BIAS:
|
237 |
+
dt_bias = tl.load(
|
238 |
+
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
239 |
+
).to(tl.float32)
|
240 |
+
dt += dt_bias[:, None]
|
241 |
+
if DT_SOFTPLUS:
|
242 |
+
dt_presoftplus = dt
|
243 |
+
dt = tl.where(dt <= 20.0, softplus(dt), ddt)
|
244 |
+
clamp_mask = (dt < dt_min) | (dt > dt_max)
|
245 |
+
# As of Triton 2.2.0, tl.clamp is not available yet
|
246 |
+
# dt = tl.clamp(dt, dt_min, dt_max)
|
247 |
+
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
248 |
+
dt = tl.where(
|
249 |
+
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
250 |
+
)
|
251 |
+
ddt = tl.where(
|
252 |
+
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
|
253 |
+
)
|
254 |
+
ddt = tl.where(clamp_mask, 0.0, ddt)
|
255 |
+
if DT_SOFTPLUS:
|
256 |
+
ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
|
257 |
+
tl.store(
|
258 |
+
ddt_ptrs,
|
259 |
+
ddt,
|
260 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
261 |
+
)
|
262 |
+
dA = tl.sum(ddA * dt, axis=1)
|
263 |
+
tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
|
264 |
+
if HAS_DT_BIAS:
|
265 |
+
ddt_bias = tl.sum(ddt, axis=1)
|
266 |
+
tl.atomic_add(
|
267 |
+
ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
|
268 |
+
)
|
269 |
+
|
270 |
+
|
271 |
+
@triton.autotune(
|
272 |
+
configs=[
|
273 |
+
triton.Config(
|
274 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
275 |
+
num_stages=3,
|
276 |
+
num_warps=8,
|
277 |
+
),
|
278 |
+
triton.Config(
|
279 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
280 |
+
num_stages=4,
|
281 |
+
num_warps=4,
|
282 |
+
),
|
283 |
+
triton.Config(
|
284 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
285 |
+
num_stages=4,
|
286 |
+
num_warps=4,
|
287 |
+
),
|
288 |
+
triton.Config(
|
289 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
290 |
+
num_stages=4,
|
291 |
+
num_warps=4,
|
292 |
+
),
|
293 |
+
triton.Config(
|
294 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
295 |
+
num_stages=4,
|
296 |
+
num_warps=4,
|
297 |
+
),
|
298 |
+
triton.Config(
|
299 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
300 |
+
num_stages=4,
|
301 |
+
num_warps=4,
|
302 |
+
),
|
303 |
+
triton.Config(
|
304 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
305 |
+
num_stages=5,
|
306 |
+
num_warps=2,
|
307 |
+
),
|
308 |
+
triton.Config(
|
309 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
310 |
+
num_stages=5,
|
311 |
+
num_warps=2,
|
312 |
+
),
|
313 |
+
triton.Config(
|
314 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
315 |
+
num_stages=4,
|
316 |
+
num_warps=2,
|
317 |
+
),
|
318 |
+
],
|
319 |
+
key=["hdim", "dstate", "chunk_size"],
|
320 |
+
)
|
321 |
+
@triton.jit
|
322 |
+
def _chunk_state_fwd_kernel(
|
323 |
+
# Pointers to matrices
|
324 |
+
x_ptr,
|
325 |
+
b_ptr,
|
326 |
+
states_ptr,
|
327 |
+
dt_ptr,
|
328 |
+
dA_cumsum_ptr,
|
329 |
+
seq_idx_ptr,
|
330 |
+
# Matrix dimensions
|
331 |
+
hdim,
|
332 |
+
dstate,
|
333 |
+
chunk_size,
|
334 |
+
batch,
|
335 |
+
seqlen,
|
336 |
+
nheads_ngroups_ratio,
|
337 |
+
# Strides
|
338 |
+
stride_x_batch,
|
339 |
+
stride_x_seqlen,
|
340 |
+
stride_x_head,
|
341 |
+
stride_x_hdim,
|
342 |
+
stride_b_batch,
|
343 |
+
stride_b_seqlen,
|
344 |
+
stride_b_head,
|
345 |
+
stride_b_dstate,
|
346 |
+
stride_states_batch,
|
347 |
+
stride_states_chunk,
|
348 |
+
stride_states_head,
|
349 |
+
stride_states_hdim,
|
350 |
+
stride_states_dstate,
|
351 |
+
stride_dt_batch,
|
352 |
+
stride_dt_chunk,
|
353 |
+
stride_dt_head,
|
354 |
+
stride_dt_csize,
|
355 |
+
stride_dA_cs_batch,
|
356 |
+
stride_dA_cs_chunk,
|
357 |
+
stride_dA_cs_head,
|
358 |
+
stride_dA_cs_csize,
|
359 |
+
stride_seq_idx_batch,
|
360 |
+
stride_seq_idx_seqlen,
|
361 |
+
# Meta-parameters
|
362 |
+
HAS_SEQ_IDX: tl.constexpr,
|
363 |
+
BLOCK_SIZE_M: tl.constexpr,
|
364 |
+
BLOCK_SIZE_N: tl.constexpr,
|
365 |
+
BLOCK_SIZE_K: tl.constexpr,
|
366 |
+
):
|
367 |
+
pid_bc = tl.program_id(axis=1)
|
368 |
+
pid_c = pid_bc // batch
|
369 |
+
pid_b = pid_bc - pid_c * batch
|
370 |
+
pid_h = tl.program_id(axis=2)
|
371 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
372 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
373 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
374 |
+
b_ptr += (
|
375 |
+
pid_b * stride_b_batch
|
376 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
377 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
378 |
+
)
|
379 |
+
x_ptr += (
|
380 |
+
pid_b * stride_x_batch
|
381 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
382 |
+
+ pid_h * stride_x_head
|
383 |
+
)
|
384 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
385 |
+
dA_cumsum_ptr += (
|
386 |
+
pid_b * stride_dA_cs_batch
|
387 |
+
+ pid_c * stride_dA_cs_chunk
|
388 |
+
+ pid_h * stride_dA_cs_head
|
389 |
+
)
|
390 |
+
if HAS_SEQ_IDX:
|
391 |
+
seq_idx_ptr += (
|
392 |
+
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
393 |
+
)
|
394 |
+
|
395 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
396 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
397 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
398 |
+
x_ptrs = x_ptr + (
|
399 |
+
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
400 |
+
)
|
401 |
+
b_ptrs = b_ptr + (
|
402 |
+
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
403 |
+
)
|
404 |
+
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
405 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
406 |
+
tl.float32
|
407 |
+
)
|
408 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
409 |
+
if HAS_SEQ_IDX:
|
410 |
+
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
411 |
+
|
412 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
413 |
+
if HAS_SEQ_IDX:
|
414 |
+
seq_idx_last = tl.load(
|
415 |
+
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
416 |
+
)
|
417 |
+
|
418 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
419 |
+
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
420 |
+
x = tl.load(
|
421 |
+
x_ptrs,
|
422 |
+
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
|
423 |
+
other=0.0,
|
424 |
+
)
|
425 |
+
b = tl.load(
|
426 |
+
b_ptrs,
|
427 |
+
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
|
428 |
+
other=0.0,
|
429 |
+
).to(tl.float32)
|
430 |
+
dA_cs_k = tl.load(
|
431 |
+
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
432 |
+
).to(tl.float32)
|
433 |
+
if HAS_SEQ_IDX:
|
434 |
+
seq_idx_k = tl.load(
|
435 |
+
seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
|
436 |
+
)
|
437 |
+
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
438 |
+
tl.float32
|
439 |
+
)
|
440 |
+
if not HAS_SEQ_IDX:
|
441 |
+
scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
|
442 |
+
else:
|
443 |
+
scale = tl.where(
|
444 |
+
seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
|
445 |
+
)
|
446 |
+
b *= scale[:, None]
|
447 |
+
b = b.to(x_ptr.dtype.element_ty)
|
448 |
+
acc += tl.dot(x, b)
|
449 |
+
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
450 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
451 |
+
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
452 |
+
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
453 |
+
if HAS_SEQ_IDX:
|
454 |
+
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
455 |
+
states = acc.to(states_ptr.dtype.element_ty)
|
456 |
+
|
457 |
+
states_ptr += (
|
458 |
+
pid_b * stride_states_batch
|
459 |
+
+ pid_c * stride_states_chunk
|
460 |
+
+ pid_h * stride_states_head
|
461 |
+
)
|
462 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
463 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
464 |
+
states_ptrs = states_ptr + (
|
465 |
+
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
466 |
+
)
|
467 |
+
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
468 |
+
tl.store(states_ptrs, states, mask=c_mask)
|
469 |
+
|
470 |
+
|
471 |
+
@triton.autotune(
|
472 |
+
configs=[
|
473 |
+
triton.Config(
|
474 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
475 |
+
num_stages=3,
|
476 |
+
num_warps=8,
|
477 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
478 |
+
),
|
479 |
+
triton.Config(
|
480 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
481 |
+
num_stages=4,
|
482 |
+
num_warps=4,
|
483 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
484 |
+
),
|
485 |
+
triton.Config(
|
486 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
487 |
+
num_stages=4,
|
488 |
+
num_warps=4,
|
489 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
490 |
+
),
|
491 |
+
triton.Config(
|
492 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
493 |
+
num_stages=4,
|
494 |
+
num_warps=4,
|
495 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
496 |
+
),
|
497 |
+
triton.Config(
|
498 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
499 |
+
num_stages=4,
|
500 |
+
num_warps=4,
|
501 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
502 |
+
),
|
503 |
+
triton.Config(
|
504 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
505 |
+
num_stages=4,
|
506 |
+
num_warps=4,
|
507 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
508 |
+
),
|
509 |
+
triton.Config(
|
510 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
511 |
+
num_stages=5,
|
512 |
+
num_warps=4,
|
513 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
514 |
+
),
|
515 |
+
triton.Config(
|
516 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
517 |
+
num_stages=5,
|
518 |
+
num_warps=4,
|
519 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
520 |
+
),
|
521 |
+
triton.Config(
|
522 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
523 |
+
num_stages=4,
|
524 |
+
num_warps=4,
|
525 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
526 |
+
),
|
527 |
+
],
|
528 |
+
key=["chunk_size", "hdim", "dstate"],
|
529 |
+
)
|
530 |
+
@triton.jit
|
531 |
+
def _chunk_state_bwd_dx_kernel(
|
532 |
+
# Pointers to matrices
|
533 |
+
x_ptr,
|
534 |
+
b_ptr,
|
535 |
+
dstates_ptr,
|
536 |
+
dt_ptr,
|
537 |
+
dA_cumsum_ptr,
|
538 |
+
dx_ptr,
|
539 |
+
ddt_ptr,
|
540 |
+
ddA_cumsum_ptr,
|
541 |
+
# Matrix dimensions
|
542 |
+
chunk_size,
|
543 |
+
hdim,
|
544 |
+
dstate,
|
545 |
+
batch,
|
546 |
+
seqlen,
|
547 |
+
nheads_ngroups_ratio,
|
548 |
+
# Strides
|
549 |
+
stride_x_batch,
|
550 |
+
stride_x_seqlen,
|
551 |
+
stride_x_head,
|
552 |
+
stride_x_hdim,
|
553 |
+
stride_b_batch,
|
554 |
+
stride_b_seqlen,
|
555 |
+
stride_b_head,
|
556 |
+
stride_b_dstate,
|
557 |
+
stride_dstates_batch,
|
558 |
+
stride_dstates_chunk,
|
559 |
+
stride_states_head,
|
560 |
+
stride_states_hdim,
|
561 |
+
stride_states_dstate,
|
562 |
+
stride_dt_batch,
|
563 |
+
stride_dt_chunk,
|
564 |
+
stride_dt_head,
|
565 |
+
stride_dt_csize,
|
566 |
+
stride_dA_cs_batch,
|
567 |
+
stride_dA_cs_chunk,
|
568 |
+
stride_dA_cs_head,
|
569 |
+
stride_dA_cs_csize,
|
570 |
+
stride_dx_batch,
|
571 |
+
stride_dx_seqlen,
|
572 |
+
stride_dx_head,
|
573 |
+
stride_dx_hdim,
|
574 |
+
stride_ddt_batch,
|
575 |
+
stride_ddt_chunk,
|
576 |
+
stride_ddt_head,
|
577 |
+
stride_ddt_csize,
|
578 |
+
stride_ddA_cs_batch,
|
579 |
+
stride_ddA_cs_chunk,
|
580 |
+
stride_ddA_cs_head,
|
581 |
+
stride_ddA_cs_csize,
|
582 |
+
# Meta-parameters
|
583 |
+
BLOCK_SIZE_M: tl.constexpr,
|
584 |
+
BLOCK_SIZE_N: tl.constexpr,
|
585 |
+
BLOCK_SIZE_K: tl.constexpr,
|
586 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
587 |
+
):
|
588 |
+
pid_bc = tl.program_id(axis=1)
|
589 |
+
pid_c = pid_bc // batch
|
590 |
+
pid_b = pid_bc - pid_c * batch
|
591 |
+
pid_h = tl.program_id(axis=2)
|
592 |
+
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
593 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
594 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
595 |
+
x_ptr += (
|
596 |
+
pid_b * stride_x_batch
|
597 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
598 |
+
+ pid_h * stride_x_head
|
599 |
+
)
|
600 |
+
b_ptr += (
|
601 |
+
pid_b * stride_b_batch
|
602 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
603 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
604 |
+
)
|
605 |
+
dstates_ptr += (
|
606 |
+
pid_b * stride_dstates_batch
|
607 |
+
+ pid_c * stride_dstates_chunk
|
608 |
+
+ pid_h * stride_states_head
|
609 |
+
)
|
610 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
611 |
+
ddt_ptr += (
|
612 |
+
pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
613 |
+
)
|
614 |
+
ddA_cumsum_ptr += (
|
615 |
+
pid_b * stride_ddA_cs_batch
|
616 |
+
+ pid_c * stride_ddA_cs_chunk
|
617 |
+
+ pid_h * stride_ddA_cs_head
|
618 |
+
)
|
619 |
+
dA_cumsum_ptr += (
|
620 |
+
pid_b * stride_dA_cs_batch
|
621 |
+
+ pid_c * stride_dA_cs_chunk
|
622 |
+
+ pid_h * stride_dA_cs_head
|
623 |
+
)
|
624 |
+
|
625 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
626 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
627 |
+
|
628 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
629 |
+
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
630 |
+
offs_k = tl.arange(
|
631 |
+
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
632 |
+
)
|
633 |
+
b_ptrs = b_ptr + (
|
634 |
+
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
635 |
+
)
|
636 |
+
dstates_ptrs = dstates_ptr + (
|
637 |
+
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
638 |
+
)
|
639 |
+
if BLOCK_SIZE_DSTATE <= 128:
|
640 |
+
b = tl.load(
|
641 |
+
b_ptrs,
|
642 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
643 |
+
other=0.0,
|
644 |
+
)
|
645 |
+
dstates = tl.load(
|
646 |
+
dstates_ptrs,
|
647 |
+
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
648 |
+
other=0.0,
|
649 |
+
)
|
650 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
651 |
+
acc = tl.dot(b, dstates)
|
652 |
+
else:
|
653 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
654 |
+
for k in range(0, dstate, BLOCK_SIZE_K):
|
655 |
+
b = tl.load(
|
656 |
+
b_ptrs,
|
657 |
+
mask=(offs_m[:, None] < chunk_size_limit)
|
658 |
+
& (offs_k[None, :] < dstate - k),
|
659 |
+
other=0.0,
|
660 |
+
)
|
661 |
+
dstates = tl.load(
|
662 |
+
dstates_ptrs,
|
663 |
+
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
664 |
+
other=0.0,
|
665 |
+
)
|
666 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
667 |
+
acc += tl.dot(b, dstates)
|
668 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
669 |
+
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
670 |
+
|
671 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
672 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
673 |
+
|
674 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
675 |
+
tl.float32
|
676 |
+
)
|
677 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
678 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
679 |
+
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
680 |
+
tl.float32
|
681 |
+
)
|
682 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
683 |
+
acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
|
684 |
+
|
685 |
+
x_ptrs = x_ptr + (
|
686 |
+
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
687 |
+
)
|
688 |
+
x = tl.load(
|
689 |
+
x_ptrs,
|
690 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
691 |
+
other=0.0,
|
692 |
+
).to(tl.float32)
|
693 |
+
ddt = tl.sum(acc * x, axis=1)
|
694 |
+
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
695 |
+
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
696 |
+
ddA_cs = -(ddt * dt_m)
|
697 |
+
ddA_cs_last = -tl.sum(ddA_cs)
|
698 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
699 |
+
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
700 |
+
tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
|
701 |
+
|
702 |
+
dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
|
703 |
+
dx_ptr += (
|
704 |
+
pid_b * stride_dx_batch
|
705 |
+
+ pid_c * chunk_size * stride_dx_seqlen
|
706 |
+
+ pid_h * stride_dx_head
|
707 |
+
)
|
708 |
+
dx_ptrs = dx_ptr + (
|
709 |
+
offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
|
710 |
+
)
|
711 |
+
tl.store(
|
712 |
+
dx_ptrs,
|
713 |
+
dx,
|
714 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
715 |
+
)
|
716 |
+
|
717 |
+
|
718 |
+
@triton.autotune(
|
719 |
+
configs=[
|
720 |
+
triton.Config(
|
721 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
|
722 |
+
num_stages=3,
|
723 |
+
num_warps=4,
|
724 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
725 |
+
),
|
726 |
+
triton.Config(
|
727 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
|
728 |
+
num_stages=3,
|
729 |
+
num_warps=4,
|
730 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
731 |
+
),
|
732 |
+
triton.Config(
|
733 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
|
734 |
+
num_stages=3,
|
735 |
+
num_warps=4,
|
736 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
737 |
+
),
|
738 |
+
triton.Config(
|
739 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
|
740 |
+
num_stages=3,
|
741 |
+
num_warps=4,
|
742 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
743 |
+
),
|
744 |
+
triton.Config(
|
745 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
|
746 |
+
num_stages=3,
|
747 |
+
num_warps=4,
|
748 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
749 |
+
),
|
750 |
+
triton.Config(
|
751 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
|
752 |
+
num_stages=3,
|
753 |
+
num_warps=4,
|
754 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
755 |
+
),
|
756 |
+
triton.Config(
|
757 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
|
758 |
+
num_stages=3,
|
759 |
+
num_warps=4,
|
760 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
761 |
+
),
|
762 |
+
triton.Config(
|
763 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
|
764 |
+
num_stages=3,
|
765 |
+
num_warps=4,
|
766 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
767 |
+
),
|
768 |
+
],
|
769 |
+
key=["chunk_size", "dstate", "hdim"],
|
770 |
+
)
|
771 |
+
@triton.jit
|
772 |
+
def _chunk_state_bwd_db_kernel(
|
773 |
+
# Pointers to matrices
|
774 |
+
x_ptr,
|
775 |
+
dstates_ptr,
|
776 |
+
b_ptr,
|
777 |
+
dt_ptr,
|
778 |
+
dA_cumsum_ptr,
|
779 |
+
seq_idx_ptr,
|
780 |
+
db_ptr,
|
781 |
+
ddA_cumsum_ptr,
|
782 |
+
# Matrix dimensions
|
783 |
+
chunk_size,
|
784 |
+
dstate,
|
785 |
+
hdim,
|
786 |
+
batch,
|
787 |
+
seqlen,
|
788 |
+
nheads,
|
789 |
+
nheads_per_program,
|
790 |
+
ngroups,
|
791 |
+
# Strides
|
792 |
+
stride_x_batch,
|
793 |
+
stride_x_seqlen,
|
794 |
+
stride_x_head,
|
795 |
+
stride_x_hdim,
|
796 |
+
stride_dstates_batch,
|
797 |
+
stride_dstates_chunk,
|
798 |
+
stride_states_head,
|
799 |
+
stride_states_hdim,
|
800 |
+
stride_states_dstate,
|
801 |
+
stride_b_batch,
|
802 |
+
stride_b_seqlen,
|
803 |
+
stride_b_head,
|
804 |
+
stride_b_dstate,
|
805 |
+
stride_dt_batch,
|
806 |
+
stride_dt_chunk,
|
807 |
+
stride_dt_head,
|
808 |
+
stride_dt_csize,
|
809 |
+
stride_dA_cs_batch,
|
810 |
+
stride_dA_cs_chunk,
|
811 |
+
stride_dA_cs_head,
|
812 |
+
stride_dA_cs_csize,
|
813 |
+
stride_seq_idx_batch,
|
814 |
+
stride_seq_idx_seqlen,
|
815 |
+
stride_db_batch,
|
816 |
+
stride_db_seqlen,
|
817 |
+
stride_db_split,
|
818 |
+
stride_db_group,
|
819 |
+
stride_db_dstate,
|
820 |
+
stride_ddA_cs_batch,
|
821 |
+
stride_ddA_cs_chunk,
|
822 |
+
stride_ddA_cs_head,
|
823 |
+
stride_ddA_cs_csize,
|
824 |
+
# Meta-parameters
|
825 |
+
HAS_DDA_CS: tl.constexpr,
|
826 |
+
HAS_SEQ_IDX: tl.constexpr,
|
827 |
+
BLOCK_SIZE_M: tl.constexpr,
|
828 |
+
BLOCK_SIZE_N: tl.constexpr,
|
829 |
+
BLOCK_SIZE_K: tl.constexpr,
|
830 |
+
):
|
831 |
+
pid_bc = tl.program_id(axis=1)
|
832 |
+
pid_c = pid_bc // batch
|
833 |
+
pid_b = pid_bc - pid_c * batch
|
834 |
+
pid_sg = tl.program_id(axis=2)
|
835 |
+
pid_s = pid_sg // ngroups
|
836 |
+
pid_g = pid_sg - pid_s * ngroups
|
837 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
838 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
839 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
840 |
+
x_ptr += (
|
841 |
+
pid_b * stride_x_batch
|
842 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
843 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
|
844 |
+
)
|
845 |
+
db_ptr += (
|
846 |
+
pid_b * stride_db_batch
|
847 |
+
+ pid_c * chunk_size * stride_db_seqlen
|
848 |
+
+ pid_g * stride_db_group
|
849 |
+
+ pid_s * stride_db_split
|
850 |
+
)
|
851 |
+
dstates_ptr += (
|
852 |
+
pid_b * stride_dstates_batch
|
853 |
+
+ pid_c * stride_dstates_chunk
|
854 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
|
855 |
+
* stride_states_head
|
856 |
+
)
|
857 |
+
dt_ptr += (
|
858 |
+
pid_b * stride_dt_batch
|
859 |
+
+ pid_c * stride_dt_chunk
|
860 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
|
861 |
+
)
|
862 |
+
dA_cumsum_ptr += (
|
863 |
+
pid_b * stride_dA_cs_batch
|
864 |
+
+ pid_c * stride_dA_cs_chunk
|
865 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
|
866 |
+
)
|
867 |
+
if HAS_DDA_CS:
|
868 |
+
b_ptr += (
|
869 |
+
pid_b * stride_b_batch
|
870 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
871 |
+
+ pid_g * stride_b_head
|
872 |
+
)
|
873 |
+
ddA_cumsum_ptr += (
|
874 |
+
pid_b * stride_ddA_cs_batch
|
875 |
+
+ pid_c * stride_ddA_cs_chunk
|
876 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
|
877 |
+
* stride_ddA_cs_head
|
878 |
+
)
|
879 |
+
if HAS_SEQ_IDX:
|
880 |
+
seq_idx_ptr += (
|
881 |
+
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
882 |
+
)
|
883 |
+
|
884 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
885 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
886 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
887 |
+
x_ptrs = x_ptr + (
|
888 |
+
offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
|
889 |
+
)
|
890 |
+
dstates_ptrs = dstates_ptr + (
|
891 |
+
offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
|
892 |
+
)
|
893 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
894 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
895 |
+
if HAS_DDA_CS:
|
896 |
+
b_ptrs = b_ptr + (
|
897 |
+
offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
|
898 |
+
)
|
899 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
900 |
+
|
901 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
902 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
903 |
+
if HAS_DDA_CS:
|
904 |
+
b = tl.load(
|
905 |
+
b_ptrs,
|
906 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
|
907 |
+
other=0.0,
|
908 |
+
).to(tl.float32)
|
909 |
+
if HAS_SEQ_IDX:
|
910 |
+
seq_idx_m = tl.load(
|
911 |
+
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
912 |
+
mask=offs_m < chunk_size_limit,
|
913 |
+
other=-1,
|
914 |
+
)
|
915 |
+
seq_idx_last = tl.load(
|
916 |
+
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
917 |
+
)
|
918 |
+
nheads_iter = min(
|
919 |
+
nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
|
920 |
+
)
|
921 |
+
for h in range(nheads_iter):
|
922 |
+
x = tl.load(
|
923 |
+
x_ptrs,
|
924 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
|
925 |
+
other=0.0,
|
926 |
+
)
|
927 |
+
dstates = tl.load(
|
928 |
+
dstates_ptrs,
|
929 |
+
mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
|
930 |
+
other=0.0,
|
931 |
+
)
|
932 |
+
dstates = dstates.to(x_ptrs.dtype.element_ty)
|
933 |
+
db = tl.dot(x, dstates)
|
934 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
935 |
+
tl.float32
|
936 |
+
)
|
937 |
+
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
938 |
+
tl.float32
|
939 |
+
)
|
940 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
941 |
+
if not HAS_SEQ_IDX:
|
942 |
+
scale = tl.exp(dA_cs_last - dA_cs_m)
|
943 |
+
else:
|
944 |
+
scale = tl.where(
|
945 |
+
seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
|
946 |
+
)
|
947 |
+
db *= (scale * dt_m)[:, None]
|
948 |
+
if HAS_DDA_CS:
|
949 |
+
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
|
950 |
+
ddA_cs = tl.sum(db * b, axis=1)
|
951 |
+
tl.atomic_add(
|
952 |
+
ddA_cumsum_ptrs + stride_ddA_cs_csize,
|
953 |
+
ddA_cs,
|
954 |
+
mask=offs_m < chunk_size - 1,
|
955 |
+
)
|
956 |
+
acc += db
|
957 |
+
x_ptrs += stride_x_head
|
958 |
+
dstates_ptrs += stride_states_head
|
959 |
+
dt_ptrs += stride_dt_head
|
960 |
+
dA_cumsum_ptr += stride_dA_cs_head
|
961 |
+
dA_cumsum_ptrs += stride_dA_cs_head
|
962 |
+
if HAS_DDA_CS:
|
963 |
+
ddA_cumsum_ptrs += stride_ddA_cs_head
|
964 |
+
|
965 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
966 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
967 |
+
# if HAS_SEQ_IDX:
|
968 |
+
# seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
969 |
+
# seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
970 |
+
# acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
|
971 |
+
db_ptrs = db_ptr + (
|
972 |
+
offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
|
973 |
+
)
|
974 |
+
tl.store(
|
975 |
+
db_ptrs,
|
976 |
+
acc,
|
977 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
|
978 |
+
)
|
979 |
+
|
980 |
+
|
981 |
+
@triton.autotune(
|
982 |
+
configs=[
|
983 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
984 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
985 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
986 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
987 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
988 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
989 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
990 |
+
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
991 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
992 |
+
triton.Config(
|
993 |
+
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
994 |
+
num_stages=3,
|
995 |
+
num_warps=4,
|
996 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
997 |
+
),
|
998 |
+
triton.Config(
|
999 |
+
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
1000 |
+
num_stages=3,
|
1001 |
+
num_warps=4,
|
1002 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1003 |
+
),
|
1004 |
+
triton.Config(
|
1005 |
+
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1006 |
+
num_stages=3,
|
1007 |
+
num_warps=4,
|
1008 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1009 |
+
),
|
1010 |
+
triton.Config(
|
1011 |
+
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
1012 |
+
num_stages=3,
|
1013 |
+
num_warps=4,
|
1014 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1015 |
+
),
|
1016 |
+
triton.Config(
|
1017 |
+
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
1018 |
+
num_stages=4,
|
1019 |
+
num_warps=8,
|
1020 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1021 |
+
),
|
1022 |
+
triton.Config(
|
1023 |
+
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
1024 |
+
num_stages=4,
|
1025 |
+
num_warps=8,
|
1026 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1027 |
+
),
|
1028 |
+
triton.Config(
|
1029 |
+
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1030 |
+
num_stages=4,
|
1031 |
+
num_warps=8,
|
1032 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1033 |
+
),
|
1034 |
+
triton.Config(
|
1035 |
+
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
1036 |
+
num_stages=4,
|
1037 |
+
num_warps=8,
|
1038 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1039 |
+
),
|
1040 |
+
],
|
1041 |
+
key=["chunk_size", "hdim", "dstate"],
|
1042 |
+
)
|
1043 |
+
@triton.jit
|
1044 |
+
def _chunk_state_bwd_ddAcs_stable_kernel(
|
1045 |
+
# Pointers to matrices
|
1046 |
+
x_ptr,
|
1047 |
+
b_ptr,
|
1048 |
+
dstates_ptr,
|
1049 |
+
dt_ptr,
|
1050 |
+
dA_cumsum_ptr,
|
1051 |
+
seq_idx_ptr,
|
1052 |
+
ddA_cumsum_ptr,
|
1053 |
+
# Matrix dimensions
|
1054 |
+
chunk_size,
|
1055 |
+
hdim,
|
1056 |
+
dstate,
|
1057 |
+
batch,
|
1058 |
+
seqlen,
|
1059 |
+
nheads_ngroups_ratio,
|
1060 |
+
# Strides
|
1061 |
+
stride_x_batch,
|
1062 |
+
stride_x_seqlen,
|
1063 |
+
stride_x_head,
|
1064 |
+
stride_x_hdim,
|
1065 |
+
stride_b_batch,
|
1066 |
+
stride_b_seqlen,
|
1067 |
+
stride_b_head,
|
1068 |
+
stride_b_dstate,
|
1069 |
+
stride_dstates_batch,
|
1070 |
+
stride_dstates_chunk,
|
1071 |
+
stride_states_head,
|
1072 |
+
stride_states_hdim,
|
1073 |
+
stride_states_dstate,
|
1074 |
+
stride_dt_batch,
|
1075 |
+
stride_dt_chunk,
|
1076 |
+
stride_dt_head,
|
1077 |
+
stride_dt_csize,
|
1078 |
+
stride_dA_cs_batch,
|
1079 |
+
stride_dA_cs_chunk,
|
1080 |
+
stride_dA_cs_head,
|
1081 |
+
stride_dA_cs_csize,
|
1082 |
+
stride_seq_idx_batch,
|
1083 |
+
stride_seq_idx_seqlen,
|
1084 |
+
stride_ddA_cs_batch,
|
1085 |
+
stride_ddA_cs_chunk,
|
1086 |
+
stride_ddA_cs_head,
|
1087 |
+
stride_ddA_cs_csize,
|
1088 |
+
# Meta-parameters
|
1089 |
+
HAS_SEQ_IDX: tl.constexpr,
|
1090 |
+
BLOCK_SIZE_M: tl.constexpr,
|
1091 |
+
BLOCK_SIZE_N: tl.constexpr,
|
1092 |
+
BLOCK_SIZE_K: tl.constexpr,
|
1093 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
1094 |
+
):
|
1095 |
+
pid_bc = tl.program_id(axis=1)
|
1096 |
+
pid_c = pid_bc // batch
|
1097 |
+
pid_b = pid_bc - pid_c * batch
|
1098 |
+
pid_h = tl.program_id(axis=2)
|
1099 |
+
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
1100 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
1101 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
1102 |
+
x_ptr += (
|
1103 |
+
pid_b * stride_x_batch
|
1104 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
1105 |
+
+ pid_h * stride_x_head
|
1106 |
+
)
|
1107 |
+
b_ptr += (
|
1108 |
+
pid_b * stride_b_batch
|
1109 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
1110 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
1111 |
+
)
|
1112 |
+
dstates_ptr += (
|
1113 |
+
pid_b * stride_dstates_batch
|
1114 |
+
+ pid_c * stride_dstates_chunk
|
1115 |
+
+ pid_h * stride_states_head
|
1116 |
+
)
|
1117 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
1118 |
+
ddA_cumsum_ptr += (
|
1119 |
+
pid_b * stride_ddA_cs_batch
|
1120 |
+
+ pid_c * stride_ddA_cs_chunk
|
1121 |
+
+ pid_h * stride_ddA_cs_head
|
1122 |
+
)
|
1123 |
+
dA_cumsum_ptr += (
|
1124 |
+
pid_b * stride_dA_cs_batch
|
1125 |
+
+ pid_c * stride_dA_cs_chunk
|
1126 |
+
+ pid_h * stride_dA_cs_head
|
1127 |
+
)
|
1128 |
+
if HAS_SEQ_IDX:
|
1129 |
+
seq_idx_ptr += (
|
1130 |
+
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
1134 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
1135 |
+
|
1136 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
1137 |
+
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
1138 |
+
offs_k = tl.arange(
|
1139 |
+
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
1140 |
+
)
|
1141 |
+
b_ptrs = b_ptr + (
|
1142 |
+
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
1143 |
+
)
|
1144 |
+
dstates_ptrs = dstates_ptr + (
|
1145 |
+
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
1146 |
+
)
|
1147 |
+
if BLOCK_SIZE_DSTATE <= 128:
|
1148 |
+
b = tl.load(
|
1149 |
+
b_ptrs,
|
1150 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
1151 |
+
other=0.0,
|
1152 |
+
)
|
1153 |
+
dstates = tl.load(
|
1154 |
+
dstates_ptrs,
|
1155 |
+
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
1156 |
+
other=0.0,
|
1157 |
+
)
|
1158 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
1159 |
+
acc = tl.dot(b, dstates)
|
1160 |
+
else:
|
1161 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
1162 |
+
for k in range(0, dstate, BLOCK_SIZE_K):
|
1163 |
+
b = tl.load(
|
1164 |
+
b_ptrs,
|
1165 |
+
mask=(offs_m[:, None] < chunk_size_limit)
|
1166 |
+
& (offs_k[None, :] < dstate - k),
|
1167 |
+
other=0.0,
|
1168 |
+
)
|
1169 |
+
dstates = tl.load(
|
1170 |
+
dstates_ptrs,
|
1171 |
+
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
1172 |
+
other=0.0,
|
1173 |
+
)
|
1174 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
1175 |
+
acc += tl.dot(b, dstates)
|
1176 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
1177 |
+
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
1178 |
+
|
1179 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
1180 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
1181 |
+
|
1182 |
+
dA_cs_m = tl.load(
|
1183 |
+
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
|
1184 |
+
).to(tl.float32)
|
1185 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
1186 |
+
tl.float32
|
1187 |
+
)
|
1188 |
+
if not HAS_SEQ_IDX:
|
1189 |
+
scale = tl.exp(dA_cs_last - dA_cs_m)
|
1190 |
+
else:
|
1191 |
+
seq_idx_m = tl.load(
|
1192 |
+
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
1193 |
+
mask=offs_m < chunk_size_limit,
|
1194 |
+
other=-1,
|
1195 |
+
)
|
1196 |
+
seq_idx_last = tl.load(
|
1197 |
+
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
1198 |
+
)
|
1199 |
+
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
1200 |
+
acc *= scale[:, None]
|
1201 |
+
|
1202 |
+
x_ptrs = x_ptr + (
|
1203 |
+
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
1204 |
+
)
|
1205 |
+
x = tl.load(
|
1206 |
+
x_ptrs,
|
1207 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
1208 |
+
other=0.0,
|
1209 |
+
).to(tl.float32)
|
1210 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
1211 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
1212 |
+
ddt = tl.sum(acc * x, axis=1)
|
1213 |
+
# ddA_cs = -(ddt * dt_m)
|
1214 |
+
# Triton 2.2.0 errors if we have the cumsum here, so we just write it out
|
1215 |
+
# then call torch.cumsum outside this kernel.
|
1216 |
+
# ddA_cs = tl.cumsum(ddt * dt_m)
|
1217 |
+
ddA_cs = ddt * dt_m
|
1218 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
1219 |
+
# tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
1220 |
+
tl.atomic_add(
|
1221 |
+
ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
|
1222 |
+
)
|
1223 |
+
|
1224 |
+
|
1225 |
+
@triton.autotune(
|
1226 |
+
configs=[
|
1227 |
+
triton.Config(
|
1228 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
1229 |
+
num_stages=3,
|
1230 |
+
num_warps=8,
|
1231 |
+
),
|
1232 |
+
triton.Config(
|
1233 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
1234 |
+
num_stages=4,
|
1235 |
+
num_warps=4,
|
1236 |
+
),
|
1237 |
+
triton.Config(
|
1238 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
1239 |
+
num_stages=4,
|
1240 |
+
num_warps=4,
|
1241 |
+
),
|
1242 |
+
triton.Config(
|
1243 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1244 |
+
num_stages=4,
|
1245 |
+
num_warps=4,
|
1246 |
+
),
|
1247 |
+
triton.Config(
|
1248 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
1249 |
+
num_stages=4,
|
1250 |
+
num_warps=4,
|
1251 |
+
),
|
1252 |
+
triton.Config(
|
1253 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
1254 |
+
num_stages=4,
|
1255 |
+
num_warps=4,
|
1256 |
+
),
|
1257 |
+
triton.Config(
|
1258 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
1259 |
+
num_stages=5,
|
1260 |
+
num_warps=2,
|
1261 |
+
),
|
1262 |
+
triton.Config(
|
1263 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1264 |
+
num_stages=5,
|
1265 |
+
num_warps=2,
|
1266 |
+
),
|
1267 |
+
triton.Config(
|
1268 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1269 |
+
num_stages=4,
|
1270 |
+
num_warps=2,
|
1271 |
+
),
|
1272 |
+
],
|
1273 |
+
key=["hdim", "dstate", "chunk_size"],
|
1274 |
+
)
|
1275 |
+
@triton.jit
|
1276 |
+
def _chunk_state_varlen_kernel(
|
1277 |
+
# Pointers to matrices
|
1278 |
+
x_ptr,
|
1279 |
+
b_ptr,
|
1280 |
+
dt_ptr,
|
1281 |
+
dA_cumsum_ptr,
|
1282 |
+
chunk_states_ptr,
|
1283 |
+
cu_seqlens_ptr,
|
1284 |
+
states_ptr,
|
1285 |
+
# Matrix dimensions
|
1286 |
+
hdim,
|
1287 |
+
dstate,
|
1288 |
+
chunk_size,
|
1289 |
+
seqlen,
|
1290 |
+
nheads_ngroups_ratio,
|
1291 |
+
# Strides
|
1292 |
+
stride_x_seqlen,
|
1293 |
+
stride_x_head,
|
1294 |
+
stride_x_hdim,
|
1295 |
+
stride_b_seqlen,
|
1296 |
+
stride_b_head,
|
1297 |
+
stride_b_dstate,
|
1298 |
+
stride_dt_chunk,
|
1299 |
+
stride_dt_head,
|
1300 |
+
stride_dt_csize,
|
1301 |
+
stride_dA_cs_chunk,
|
1302 |
+
stride_dA_cs_head,
|
1303 |
+
stride_dA_cs_csize,
|
1304 |
+
stride_chunk_states_chunk,
|
1305 |
+
stride_chunk_states_head,
|
1306 |
+
stride_chunk_states_hdim,
|
1307 |
+
stride_chunk_states_dstate,
|
1308 |
+
stride_states_batch,
|
1309 |
+
stride_states_head,
|
1310 |
+
stride_states_hdim,
|
1311 |
+
stride_states_dstate,
|
1312 |
+
# Meta-parameters
|
1313 |
+
BLOCK_SIZE_M: tl.constexpr,
|
1314 |
+
BLOCK_SIZE_N: tl.constexpr,
|
1315 |
+
BLOCK_SIZE_K: tl.constexpr,
|
1316 |
+
):
|
1317 |
+
pid_b = tl.program_id(axis=1)
|
1318 |
+
pid_h = tl.program_id(axis=2)
|
1319 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
1320 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
1321 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
1322 |
+
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
|
1323 |
+
pid_c = (end_idx - 1) // chunk_size
|
1324 |
+
b_ptr += (
|
1325 |
+
pid_c * chunk_size * stride_b_seqlen
|
1326 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
1327 |
+
)
|
1328 |
+
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
1329 |
+
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
1330 |
+
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
1331 |
+
chunk_states_ptr += (
|
1332 |
+
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
|
1333 |
+
)
|
1334 |
+
|
1335 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
1336 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
1337 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
1338 |
+
x_ptrs = x_ptr + (
|
1339 |
+
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
1340 |
+
)
|
1341 |
+
b_ptrs = b_ptr + (
|
1342 |
+
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
1343 |
+
)
|
1344 |
+
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
1345 |
+
dA_cs_last = tl.load(
|
1346 |
+
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
1347 |
+
).to(tl.float32)
|
1348 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
1349 |
+
|
1350 |
+
chunk_size_limit = end_idx - pid_c * chunk_size
|
1351 |
+
start_idx = tl.load(cu_seqlens_ptr + pid_b)
|
1352 |
+
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
|
1353 |
+
|
1354 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
1355 |
+
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
1356 |
+
x = tl.load(
|
1357 |
+
x_ptrs,
|
1358 |
+
mask=(offs_m[:, None] < hdim)
|
1359 |
+
& (offs_k[None, :] < chunk_size_limit - k)
|
1360 |
+
& (offs_k[None, :] >= start_idx_cur - k),
|
1361 |
+
other=0.0,
|
1362 |
+
)
|
1363 |
+
b = tl.load(
|
1364 |
+
b_ptrs,
|
1365 |
+
mask=(offs_k[:, None] < chunk_size_limit - k)
|
1366 |
+
& (offs_n[None, :] < dstate)
|
1367 |
+
& (offs_k[:, None] >= start_idx_cur - k),
|
1368 |
+
other=0.0,
|
1369 |
+
).to(tl.float32)
|
1370 |
+
dA_cs_k = tl.load(
|
1371 |
+
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
1372 |
+
).to(tl.float32)
|
1373 |
+
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
1374 |
+
tl.float32
|
1375 |
+
)
|
1376 |
+
scale = tl.where(
|
1377 |
+
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
|
1378 |
+
tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
|
1379 |
+
0.0,
|
1380 |
+
)
|
1381 |
+
b *= scale[:, None]
|
1382 |
+
b = b.to(x_ptr.dtype.element_ty)
|
1383 |
+
acc += tl.dot(x, b)
|
1384 |
+
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
1385 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
1386 |
+
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
1387 |
+
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
1388 |
+
|
1389 |
+
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
|
1390 |
+
if start_idx < pid_c * chunk_size:
|
1391 |
+
chunk_states_ptrs = chunk_states_ptr + (
|
1392 |
+
offs_m[:, None] * stride_chunk_states_hdim
|
1393 |
+
+ offs_n[None, :] * stride_chunk_states_dstate
|
1394 |
+
)
|
1395 |
+
chunk_states = tl.load(
|
1396 |
+
chunk_states_ptrs,
|
1397 |
+
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
|
1398 |
+
other=0.0,
|
1399 |
+
).to(tl.float32)
|
1400 |
+
# scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
|
1401 |
+
scale = tl.exp(dA_cs_last)
|
1402 |
+
acc += chunk_states * scale
|
1403 |
+
|
1404 |
+
states = acc.to(states_ptr.dtype.element_ty)
|
1405 |
+
|
1406 |
+
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
1407 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
1408 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
1409 |
+
states_ptrs = states_ptr + (
|
1410 |
+
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
1411 |
+
)
|
1412 |
+
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
1413 |
+
tl.store(states_ptrs, states, mask=c_mask)
|
1414 |
+
|
1415 |
+
|
1416 |
+
def _chunk_cumsum_fwd(
    dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
):
    batch, seqlen, nheads = dt.shape
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
    nchunks = math.ceil(seqlen / chunk_size)
    dt_out = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    dA_cumsum = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
    )
    grid_chunk_cs = lambda META: (
        batch,
        nchunks,
        triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
    )
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_fwd_kernel[grid_chunk_cs](
            dt,
            A,
            dt_bias,
            dt_out,
            dA_cumsum,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_limit[0],
            dt_limit[1],
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            dt_bias.stride(0) if dt_bias is not None else 0,
            dt_out.stride(0),
            dt_out.stride(2),
            dt_out.stride(1),
            dt_out.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return dA_cumsum, dt_out


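# Illustrative usage sketch (added commentary, not part of the generated file):
# assuming a CUDA device is available, dt of shape (batch, seqlen, nheads) and
# A of shape (nheads,) yield per-chunk float32 buffers of shape
# (batch, nheads, nchunks, chunk_size), e.g.:
#
#     dt = torch.rand(2, 1024, 8, device="cuda")
#     A = -torch.rand(8, device="cuda")
#     dA_cumsum, dt_out = _chunk_cumsum_fwd(dt, A, chunk_size=256, dt_softplus=True)
#     assert dA_cumsum.shape == (2, 8, 4, 256) and dt_out.dtype == torch.float32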
def _chunk_cumsum_bwd(
|
1469 |
+
ddA,
|
1470 |
+
ddt_out,
|
1471 |
+
dt,
|
1472 |
+
A,
|
1473 |
+
dt_bias=None,
|
1474 |
+
dt_softplus=False,
|
1475 |
+
dt_limit=(0.0, float("inf")),
|
1476 |
+
ddt=None,
|
1477 |
+
):
|
1478 |
+
batch, seqlen, nheads = dt.shape
|
1479 |
+
_, _, nchunks, chunk_size = ddA.shape
|
1480 |
+
assert ddA.shape == (batch, nheads, nchunks, chunk_size)
|
1481 |
+
assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
|
1482 |
+
assert A.shape == (nheads,)
|
1483 |
+
if dt_bias is not None:
|
1484 |
+
assert dt_bias.shape == (nheads,)
|
1485 |
+
ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
|
1486 |
+
else:
|
1487 |
+
ddt_bias = None
|
1488 |
+
if ddt is not None:
|
1489 |
+
assert ddt.shape == dt.shape
|
1490 |
+
else:
|
1491 |
+
ddt = torch.empty_like(dt)
|
1492 |
+
dA = torch.empty_like(A, dtype=torch.float32)
|
1493 |
+
grid_chunk_cs = lambda META: (
|
1494 |
+
batch,
|
1495 |
+
nchunks,
|
1496 |
+
triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
|
1497 |
+
)
|
1498 |
+
with torch.cuda.device(dt.device.index):
|
1499 |
+
_chunk_cumsum_bwd_kernel[grid_chunk_cs](
|
1500 |
+
ddA,
|
1501 |
+
ddt_out,
|
1502 |
+
dt,
|
1503 |
+
A,
|
1504 |
+
dt_bias,
|
1505 |
+
ddt,
|
1506 |
+
dA,
|
1507 |
+
ddt_bias,
|
1508 |
+
batch,
|
1509 |
+
seqlen,
|
1510 |
+
nheads,
|
1511 |
+
chunk_size,
|
1512 |
+
dt_limit[0],
|
1513 |
+
dt_limit[1],
|
1514 |
+
ddA.stride(0),
|
1515 |
+
ddA.stride(2),
|
1516 |
+
ddA.stride(1),
|
1517 |
+
ddA.stride(3),
|
1518 |
+
ddt_out.stride(0),
|
1519 |
+
ddt_out.stride(2),
|
1520 |
+
ddt_out.stride(1),
|
1521 |
+
ddt_out.stride(3),
|
1522 |
+
dt.stride(0),
|
1523 |
+
dt.stride(1),
|
1524 |
+
dt.stride(2),
|
1525 |
+
A.stride(0),
|
1526 |
+
dt_bias.stride(0) if dt_bias is not None else 0,
|
1527 |
+
ddt.stride(0),
|
1528 |
+
ddt.stride(1),
|
1529 |
+
ddt.stride(2),
|
1530 |
+
dA.stride(0),
|
1531 |
+
ddt_bias.stride(0) if ddt_bias is not None else 0,
|
1532 |
+
dt_softplus,
|
1533 |
+
HAS_DT_BIAS=dt_bias is not None,
|
1534 |
+
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
1535 |
+
)
|
1536 |
+
return ddt, dA, ddt_bias
|
1537 |
+
|
1538 |
+
|
1539 |
+
def _chunk_state_fwd(
|
1540 |
+
B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
|
1541 |
+
):
|
1542 |
+
batch, seqlen, nheads, headdim = x.shape
|
1543 |
+
_, _, nchunks, chunk_size = dt.shape
|
1544 |
+
_, _, ngroups, dstate = B.shape
|
1545 |
+
assert nheads % ngroups == 0
|
1546 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1547 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1548 |
+
assert dA_cumsum.shape == dt.shape
|
1549 |
+
if seq_idx is not None:
|
1550 |
+
assert seq_idx.shape == (batch, seqlen)
|
1551 |
+
if states is not None:
|
1552 |
+
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
|
1553 |
+
else:
|
1554 |
+
states_dtype = torch.float32 if states_in_fp32 else B.dtype
|
1555 |
+
states = torch.empty(
|
1556 |
+
(batch, nchunks, nheads, headdim, dstate),
|
1557 |
+
device=x.device,
|
1558 |
+
dtype=states_dtype,
|
1559 |
+
)
|
1560 |
+
grid = lambda META: (
|
1561 |
+
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
1562 |
+
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
1563 |
+
batch * nchunks,
|
1564 |
+
nheads,
|
1565 |
+
)
|
1566 |
+
with torch.cuda.device(x.device.index):
|
1567 |
+
_chunk_state_fwd_kernel[grid](
|
1568 |
+
x,
|
1569 |
+
B,
|
1570 |
+
states,
|
1571 |
+
dt,
|
1572 |
+
dA_cumsum,
|
1573 |
+
seq_idx,
|
1574 |
+
headdim,
|
1575 |
+
dstate,
|
1576 |
+
chunk_size,
|
1577 |
+
batch,
|
1578 |
+
seqlen,
|
1579 |
+
nheads // ngroups,
|
1580 |
+
x.stride(0),
|
1581 |
+
x.stride(1),
|
1582 |
+
x.stride(2),
|
1583 |
+
x.stride(3),
|
1584 |
+
B.stride(0),
|
1585 |
+
B.stride(1),
|
1586 |
+
B.stride(2),
|
1587 |
+
B.stride(-1),
|
1588 |
+
states.stride(0),
|
1589 |
+
states.stride(1),
|
1590 |
+
states.stride(2),
|
1591 |
+
states.stride(3),
|
1592 |
+
states.stride(4),
|
1593 |
+
dt.stride(0),
|
1594 |
+
dt.stride(2),
|
1595 |
+
dt.stride(1),
|
1596 |
+
dt.stride(3),
|
1597 |
+
dA_cumsum.stride(0),
|
1598 |
+
dA_cumsum.stride(2),
|
1599 |
+
dA_cumsum.stride(1),
|
1600 |
+
dA_cumsum.stride(3),
|
1601 |
+
*(
|
1602 |
+
(seq_idx.stride(0), seq_idx.stride(1))
|
1603 |
+
if seq_idx is not None
|
1604 |
+
else (0, 0)
|
1605 |
+
),
|
1606 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
1607 |
+
)
|
1608 |
+
return states
|
1609 |
+
|
1610 |
+
|
1611 |
+
def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
|
1612 |
+
batch, seqlen, nheads, headdim = x.shape
|
1613 |
+
_, _, nchunks, chunk_size = dt.shape
|
1614 |
+
_, _, ngroups, dstate = B.shape
|
1615 |
+
assert nheads % ngroups == 0
|
1616 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1617 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1618 |
+
assert dA_cumsum.shape == dt.shape
|
1619 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
1620 |
+
if dx is not None:
|
1621 |
+
assert dx.shape == x.shape
|
1622 |
+
else:
|
1623 |
+
dx = torch.empty_like(x)
|
1624 |
+
ddt = torch.empty(
|
1625 |
+
batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
1626 |
+
)
|
1627 |
+
ddA_cumsum = torch.empty(
|
1628 |
+
batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
|
1629 |
+
)
|
1630 |
+
grid_dx = lambda META: (
|
1631 |
+
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
1632 |
+
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
1633 |
+
batch * nchunks,
|
1634 |
+
nheads,
|
1635 |
+
)
|
1636 |
+
with torch.cuda.device(x.device.index):
|
1637 |
+
_chunk_state_bwd_dx_kernel[grid_dx](
|
1638 |
+
x,
|
1639 |
+
B,
|
1640 |
+
dstates,
|
1641 |
+
dt,
|
1642 |
+
dA_cumsum,
|
1643 |
+
dx,
|
1644 |
+
ddt,
|
1645 |
+
ddA_cumsum,
|
1646 |
+
chunk_size,
|
1647 |
+
headdim,
|
1648 |
+
dstate,
|
1649 |
+
batch,
|
1650 |
+
seqlen,
|
1651 |
+
nheads // ngroups,
|
1652 |
+
x.stride(0),
|
1653 |
+
x.stride(1),
|
1654 |
+
x.stride(2),
|
1655 |
+
x.stride(3),
|
1656 |
+
B.stride(0),
|
1657 |
+
B.stride(1),
|
1658 |
+
B.stride(2),
|
1659 |
+
B.stride(-1),
|
1660 |
+
dstates.stride(0),
|
1661 |
+
dstates.stride(1),
|
1662 |
+
dstates.stride(2),
|
1663 |
+
dstates.stride(3),
|
1664 |
+
dstates.stride(4),
|
1665 |
+
dt.stride(0),
|
1666 |
+
dt.stride(2),
|
1667 |
+
dt.stride(1),
|
1668 |
+
dt.stride(3),
|
1669 |
+
dA_cumsum.stride(0),
|
1670 |
+
dA_cumsum.stride(2),
|
1671 |
+
dA_cumsum.stride(1),
|
1672 |
+
dA_cumsum.stride(3),
|
1673 |
+
dx.stride(0),
|
1674 |
+
dx.stride(1),
|
1675 |
+
dx.stride(2),
|
1676 |
+
dx.stride(3),
|
1677 |
+
ddt.stride(0),
|
1678 |
+
ddt.stride(2),
|
1679 |
+
ddt.stride(1),
|
1680 |
+
ddt.stride(3),
|
1681 |
+
ddA_cumsum.stride(0),
|
1682 |
+
ddA_cumsum.stride(2),
|
1683 |
+
ddA_cumsum.stride(1),
|
1684 |
+
ddA_cumsum.stride(3),
|
1685 |
+
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
1686 |
+
)
|
1687 |
+
return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
|
1688 |
+
|
1689 |
+
|
1690 |
+
def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
|
1691 |
+
batch, seqlen, nheads, headdim = x.shape
|
1692 |
+
_, _, nchunks, chunk_size = dt.shape
|
1693 |
+
dstate = dstates.shape[-1]
|
1694 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1695 |
+
assert dA_cumsum.shape == dt.shape
|
1696 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
1697 |
+
if seq_idx is not None:
|
1698 |
+
assert seq_idx.shape == (batch, seqlen)
|
1699 |
+
if B is not None:
|
1700 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1701 |
+
B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
|
1702 |
+
# Use torch.empty since the Triton kernel will call init_to_zero
|
1703 |
+
ddA_cumsum = torch.empty(
|
1704 |
+
batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
|
1705 |
+
)
|
1706 |
+
ddA_cumsum_strides = (
|
1707 |
+
ddA_cumsum.stride(0),
|
1708 |
+
ddA_cumsum.stride(2),
|
1709 |
+
ddA_cumsum.stride(1),
|
1710 |
+
ddA_cumsum.stride(3),
|
1711 |
+
)
|
1712 |
+
else:
|
1713 |
+
B_strides = (0, 0, 0, 0)
|
1714 |
+
ddA_cumsum = None
|
1715 |
+
ddA_cumsum_strides = (0, 0, 0, 0)
|
1716 |
+
nheads_ngroups_ratio = nheads // ngroups
|
1717 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
1718 |
+
nheads_per_program = max(
|
1719 |
+
min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
|
1720 |
+
)
|
1721 |
+
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
|
1722 |
+
dB = torch.empty(
|
1723 |
+
batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
|
1724 |
+
)
|
1725 |
+
grid_db = lambda META: (
|
1726 |
+
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
1727 |
+
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
1728 |
+
batch * nchunks,
|
1729 |
+
nsplits * ngroups,
|
1730 |
+
)
|
1731 |
+
with torch.cuda.device(x.device.index):
|
1732 |
+
_chunk_state_bwd_db_kernel[grid_db](
|
1733 |
+
x,
|
1734 |
+
dstates,
|
1735 |
+
B,
|
1736 |
+
dt,
|
1737 |
+
dA_cumsum,
|
1738 |
+
seq_idx,
|
1739 |
+
dB,
|
1740 |
+
ddA_cumsum,
|
1741 |
+
chunk_size,
|
1742 |
+
dstate,
|
1743 |
+
headdim,
|
1744 |
+
batch,
|
1745 |
+
seqlen,
|
1746 |
+
nheads,
|
1747 |
+
nheads_per_program,
|
1748 |
+
ngroups,
|
1749 |
+
x.stride(0),
|
1750 |
+
x.stride(1),
|
1751 |
+
x.stride(2),
|
1752 |
+
x.stride(3),
|
1753 |
+
dstates.stride(0),
|
1754 |
+
dstates.stride(1),
|
1755 |
+
dstates.stride(2),
|
1756 |
+
dstates.stride(3),
|
1757 |
+
dstates.stride(4),
|
1758 |
+
*B_strides,
|
1759 |
+
dt.stride(0),
|
1760 |
+
dt.stride(2),
|
1761 |
+
dt.stride(1),
|
1762 |
+
dt.stride(3),
|
1763 |
+
dA_cumsum.stride(0),
|
1764 |
+
dA_cumsum.stride(2),
|
1765 |
+
dA_cumsum.stride(1),
|
1766 |
+
dA_cumsum.stride(3),
|
1767 |
+
*(
|
1768 |
+
(seq_idx.stride(0), seq_idx.stride(1))
|
1769 |
+
if seq_idx is not None
|
1770 |
+
else (0, 0)
|
1771 |
+
),
|
1772 |
+
dB.stride(0),
|
1773 |
+
dB.stride(1),
|
1774 |
+
dB.stride(2),
|
1775 |
+
dB.stride(3),
|
1776 |
+
dB.stride(4),
|
1777 |
+
*ddA_cumsum_strides,
|
1778 |
+
HAS_DDA_CS=ddA_cumsum is not None,
|
1779 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
1780 |
+
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
|
1781 |
+
)
|
1782 |
+
dB = dB.sum(2)
|
1783 |
+
if ddA_cumsum is not None:
|
1784 |
+
# The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
|
1785 |
+
# to the state of the chunk.
|
1786 |
+
# torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
1787 |
+
# But it's easier to just do the cumsum for all elements, the result will be the same.
|
1788 |
+
torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
|
1789 |
+
return dB if B is None else (dB, ddA_cumsum)
|
1790 |
+
|
1791 |
+
|
1792 |
+
def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
|
1793 |
+
batch, seqlen, nheads, headdim = x.shape
|
1794 |
+
_, _, nchunks, chunk_size = dt.shape
|
1795 |
+
_, _, ngroups, dstate = B.shape
|
1796 |
+
assert nheads % ngroups == 0
|
1797 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1798 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1799 |
+
assert dA_cumsum.shape == dt.shape
|
1800 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
1801 |
+
if seq_idx is not None:
|
1802 |
+
assert seq_idx.shape == (batch, seqlen)
|
1803 |
+
# Use torch.empty since the Triton kernel will call init_to_zero
|
1804 |
+
ddA_cumsum = torch.empty(
|
1805 |
+
batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
|
1806 |
+
)
|
1807 |
+
grid_ddtcs = lambda META: (
|
1808 |
+
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
1809 |
+
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
1810 |
+
batch * nchunks,
|
1811 |
+
nheads,
|
1812 |
+
)
|
1813 |
+
with torch.cuda.device(x.device.index):
|
1814 |
+
_chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
|
1815 |
+
x,
|
1816 |
+
B,
|
1817 |
+
dstates,
|
1818 |
+
dt,
|
1819 |
+
dA_cumsum,
|
1820 |
+
seq_idx,
|
1821 |
+
ddA_cumsum,
|
1822 |
+
chunk_size,
|
1823 |
+
headdim,
|
1824 |
+
dstate,
|
1825 |
+
batch,
|
1826 |
+
seqlen,
|
1827 |
+
nheads // ngroups,
|
1828 |
+
x.stride(0),
|
1829 |
+
x.stride(1),
|
1830 |
+
x.stride(2),
|
1831 |
+
x.stride(3),
|
1832 |
+
B.stride(0),
|
1833 |
+
B.stride(1),
|
1834 |
+
B.stride(2),
|
1835 |
+
B.stride(-1),
|
1836 |
+
dstates.stride(0),
|
1837 |
+
dstates.stride(1),
|
1838 |
+
dstates.stride(2),
|
1839 |
+
dstates.stride(3),
|
1840 |
+
dstates.stride(4),
|
1841 |
+
dt.stride(0),
|
1842 |
+
dt.stride(2),
|
1843 |
+
dt.stride(1),
|
1844 |
+
dt.stride(3),
|
1845 |
+
dA_cumsum.stride(0),
|
1846 |
+
dA_cumsum.stride(2),
|
1847 |
+
dA_cumsum.stride(1),
|
1848 |
+
dA_cumsum.stride(3),
|
1849 |
+
*(
|
1850 |
+
(seq_idx.stride(0), seq_idx.stride(1))
|
1851 |
+
if seq_idx is not None
|
1852 |
+
else (0, 0)
|
1853 |
+
),
|
1854 |
+
ddA_cumsum.stride(0),
|
1855 |
+
ddA_cumsum.stride(2),
|
1856 |
+
ddA_cumsum.stride(1),
|
1857 |
+
ddA_cumsum.stride(3),
|
1858 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
1859 |
+
BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
|
1860 |
+
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
1861 |
+
)
|
1862 |
+
torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
1863 |
+
return ddA_cumsum
|
1864 |
+
|
1865 |
+
|
1866 |
+
def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    states = torch.empty(
        batch,
        nheads,
        headdim,
        dstate,
        dtype=chunk_states.dtype,
        device=chunk_states.device,
    )
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x,
            B,
            dt,
            dA_cumsum,
            chunk_states,
            cu_seqlens,
            states,
            headdim,
            dstate,
            chunk_size,
            total_seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            dt.stride(1),
            dt.stride(0),
            dt.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            chunk_states.stride(0),
            chunk_states.stride(1),
            chunk_states.stride(2),
            chunk_states.stride(3),
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
        )
    return states


class ChunkStateFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        assert seqlen <= nchunks * chunk_size
        _, _, ngroups, dstate = B.shape
        assert B.shape == (batch, seqlen, ngroups, dstate)
        assert dt.shape == (batch, nheads, nchunks, chunk_size)
        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
        if B.stride(-1) != 1:
            B = B.contiguous()
        if (
            x.stride(-1) != 1 and x.stride(1) != 1
        ):  # Either M or K dimension should be contiguous
            x = x.contiguous()
        states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
        ctx.save_for_backward(B, x, dt, dA_cumsum)
        return states

    @staticmethod
    def backward(ctx, dstates):
        B, x, dt, dA_cumsum = ctx.saved_tensors
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        _, _, ngroups, dstate = B.shape
        assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
        if dstates.stride(-1) != 1:
            dstates = dstates.contiguous()
        dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
        dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
        dB = dB.to(B.dtype)
        return dB, dx, ddt, ddA_cumsum, None


def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate)
    """
    return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)


def chunk_state_ref(B, x, dt, dA_cumsum):
    """
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
    Return:
        states: (batch, nchunks, nheads, headdim, dstate)
    """
    # Check constraints.
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen <= nchunks * chunk_size
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    ngroups = B.shape[2]
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    if seqlen < nchunks * chunk_size:
        x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
        B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
    x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
    B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
    decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
    return torch.einsum(
        "bclhn,bhcl,bhcl,bclhp->bchpn",
        B.to(x.dtype),
        decay_states.to(x.dtype),
        dt.to(x.dtype),
        x,
    )
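# Consistency sketch (added commentary, illustrative only): the Triton-backed
# chunk_state and the einsum reference chunk_state_ref above are expected to
# agree up to numerical tolerance.  Shapes follow the docstrings; a CUDA device
# is assumed.
#
#     B = torch.randn(2, 512, 1, 64, device="cuda")
#     x = torch.randn(2, 512, 8, 64, device="cuda")
#     dt = torch.rand(2, 8, 2, 256, device="cuda")
#     dA_cumsum = torch.cumsum(-torch.rand(2, 8, 2, 256, device="cuda"), dim=-1)
#     assert torch.allclose(
#         chunk_state(B, x, dt, dA_cumsum).float(),
#         chunk_state_ref(B, x, dt, dA_cumsum).float(),
#         atol=1e-3, rtol=1e-3,
#     )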
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_combined.py
ADDED
@@ -0,0 +1,1884 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

from typing import Optional

import math
from packaging import version

import torch
import torch.nn.functional as F
from torch import Tensor
from ...utils.torch import custom_bwd, custom_fwd

import triton
import triton.language as tl

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn, causal_conv1d_cuda = None, None

from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
from .ssd_chunk_state import chunk_state, chunk_state_ref
from .ssd_chunk_state import chunk_state_varlen
from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
from .ssd_state_passing import state_passing, state_passing_ref
from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
from .k_activations import _swiglu_fwd, _swiglu_bwd

TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")


def init_to_zero(names):
    return lambda nargs: [
        nargs[name].zero_() for name in names if nargs[name] is not None
    ]


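# Note (added commentary, not in the upstream source): init_to_zero is passed as the
# autotune config `pre_hook` below so that gradient buffers the kernels accumulate
# into with tl.atomic_add (e.g. the tensors behind "ddt_ptr" / "ddA_cumsum_ptr")
# are zeroed before each launch, since the host-side wrappers allocate them with
# torch.empty rather than torch.zeros.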
@triton.autotune(
|
53 |
+
configs=[
|
54 |
+
triton.Config(
|
55 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
56 |
+
num_stages=3,
|
57 |
+
num_warps=8,
|
58 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
59 |
+
),
|
60 |
+
triton.Config(
|
61 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
62 |
+
num_stages=4,
|
63 |
+
num_warps=4,
|
64 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
65 |
+
),
|
66 |
+
triton.Config(
|
67 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
68 |
+
num_stages=4,
|
69 |
+
num_warps=4,
|
70 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
71 |
+
),
|
72 |
+
triton.Config(
|
73 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
74 |
+
num_stages=4,
|
75 |
+
num_warps=4,
|
76 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
77 |
+
),
|
78 |
+
triton.Config(
|
79 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
80 |
+
num_stages=4,
|
81 |
+
num_warps=4,
|
82 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
83 |
+
),
|
84 |
+
triton.Config(
|
85 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
86 |
+
num_stages=4,
|
87 |
+
num_warps=4,
|
88 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
89 |
+
),
|
90 |
+
triton.Config(
|
91 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
92 |
+
num_stages=5,
|
93 |
+
num_warps=4,
|
94 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
95 |
+
),
|
96 |
+
triton.Config(
|
97 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
98 |
+
num_stages=5,
|
99 |
+
num_warps=4,
|
100 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
101 |
+
),
|
102 |
+
triton.Config(
|
103 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
104 |
+
num_stages=4,
|
105 |
+
num_warps=4,
|
106 |
+
pre_hook=init_to_zero(["ddt_ptr"]),
|
107 |
+
),
|
108 |
+
],
|
109 |
+
key=["chunk_size", "hdim", "dstate"],
|
110 |
+
)
|
111 |
+
@triton.jit
|
112 |
+
def _chunk_scan_chunk_state_bwd_dx_kernel(
|
113 |
+
# Pointers to matrices
|
114 |
+
x_ptr,
|
115 |
+
cb_ptr,
|
116 |
+
dout_ptr,
|
117 |
+
dt_ptr,
|
118 |
+
dA_cumsum_ptr,
|
119 |
+
seq_idx_ptr,
|
120 |
+
D_ptr,
|
121 |
+
b_ptr,
|
122 |
+
dstates_ptr,
|
123 |
+
dx_ptr,
|
124 |
+
ddt_ptr,
|
125 |
+
dD_ptr,
|
126 |
+
# Matrix dimensions
|
127 |
+
chunk_size,
|
128 |
+
hdim,
|
129 |
+
dstate,
|
130 |
+
batch,
|
131 |
+
seqlen,
|
132 |
+
nheads_ngroups_ratio,
|
133 |
+
# Strides
|
134 |
+
stride_x_batch,
|
135 |
+
stride_x_seqlen,
|
136 |
+
stride_x_head,
|
137 |
+
stride_x_hdim,
|
138 |
+
stride_cb_batch,
|
139 |
+
stride_cb_chunk,
|
140 |
+
stride_cb_head,
|
141 |
+
stride_cb_csize_m,
|
142 |
+
stride_cb_csize_k,
|
143 |
+
stride_dout_batch,
|
144 |
+
stride_dout_seqlen,
|
145 |
+
stride_dout_head,
|
146 |
+
stride_dout_hdim,
|
147 |
+
stride_dt_batch,
|
148 |
+
stride_dt_chunk,
|
149 |
+
stride_dt_head,
|
150 |
+
stride_dt_csize,
|
151 |
+
stride_dA_cs_batch,
|
152 |
+
stride_dA_cs_chunk,
|
153 |
+
stride_dA_cs_head,
|
154 |
+
stride_dA_cs_csize,
|
155 |
+
stride_seq_idx_batch,
|
156 |
+
stride_seq_idx_seqlen,
|
157 |
+
stride_D_head,
|
158 |
+
stride_b_batch,
|
159 |
+
stride_b_seqlen,
|
160 |
+
stride_b_head,
|
161 |
+
stride_b_dstate,
|
162 |
+
stride_dstates_batch,
|
163 |
+
stride_dstates_chunk,
|
164 |
+
stride_dstates_head,
|
165 |
+
stride_dstates_hdim,
|
166 |
+
stride_dstates_dstate,
|
167 |
+
stride_dx_batch,
|
168 |
+
stride_dx_seqlen,
|
169 |
+
stride_dx_head,
|
170 |
+
stride_dx_hdim,
|
171 |
+
stride_ddt_batch,
|
172 |
+
stride_ddt_chunk,
|
173 |
+
stride_ddt_head,
|
174 |
+
stride_ddt_csize,
|
175 |
+
stride_dD_batch,
|
176 |
+
stride_dD_chunk,
|
177 |
+
stride_dD_head,
|
178 |
+
stride_dD_csize,
|
179 |
+
stride_dD_hdim,
|
180 |
+
# Meta-parameters
|
181 |
+
HAS_D: tl.constexpr,
|
182 |
+
D_HAS_HDIM: tl.constexpr,
|
183 |
+
HAS_SEQ_IDX: tl.constexpr,
|
184 |
+
BLOCK_SIZE_M: tl.constexpr,
|
185 |
+
BLOCK_SIZE_N: tl.constexpr,
|
186 |
+
BLOCK_SIZE_K: tl.constexpr,
|
187 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
188 |
+
IS_TRITON_22: tl.constexpr,
|
189 |
+
):
|
190 |
+
pid_bc = tl.program_id(axis=1)
|
191 |
+
pid_c = pid_bc // batch
|
192 |
+
pid_b = pid_bc - pid_c * batch
|
193 |
+
pid_h = tl.program_id(axis=2)
|
194 |
+
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
195 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
196 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
197 |
+
x_ptr += (
|
198 |
+
pid_b * stride_x_batch
|
199 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
200 |
+
+ pid_h * stride_x_head
|
201 |
+
)
|
202 |
+
cb_ptr += (
|
203 |
+
pid_b * stride_cb_batch
|
204 |
+
+ pid_c * stride_cb_chunk
|
205 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_cb_head
|
206 |
+
)
|
207 |
+
dout_ptr += (
|
208 |
+
pid_b * stride_dout_batch
|
209 |
+
+ pid_c * chunk_size * stride_dout_seqlen
|
210 |
+
+ pid_h * stride_dout_head
|
211 |
+
)
|
212 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
213 |
+
ddt_ptr += (
|
214 |
+
pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
215 |
+
)
|
216 |
+
dA_cumsum_ptr += (
|
217 |
+
pid_b * stride_dA_cs_batch
|
218 |
+
+ pid_c * stride_dA_cs_chunk
|
219 |
+
+ pid_h * stride_dA_cs_head
|
220 |
+
)
|
221 |
+
b_ptr += (
|
222 |
+
pid_b * stride_b_batch
|
223 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
224 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
225 |
+
)
|
226 |
+
dstates_ptr += (
|
227 |
+
pid_b * stride_dstates_batch
|
228 |
+
+ pid_c * stride_dstates_chunk
|
229 |
+
+ pid_h * stride_dstates_head
|
230 |
+
)
|
231 |
+
if HAS_SEQ_IDX:
|
232 |
+
seq_idx_ptr += (
|
233 |
+
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
234 |
+
)
|
235 |
+
|
236 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
237 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
238 |
+
|
239 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
240 |
+
|
241 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
242 |
+
|
243 |
+
dA_cs_m = tl.load(
|
244 |
+
dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
|
245 |
+
mask=offs_m < chunk_size_limit,
|
246 |
+
other=0.0,
|
247 |
+
).to(tl.float32)
|
248 |
+
|
249 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
250 |
+
tl.float32
|
251 |
+
)
|
252 |
+
if not HAS_SEQ_IDX:
|
253 |
+
scale = tl.exp(dA_cs_last - dA_cs_m)
|
254 |
+
else:
|
255 |
+
seq_idx_m = tl.load(
|
256 |
+
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
257 |
+
mask=offs_m < chunk_size_limit,
|
258 |
+
other=-1,
|
259 |
+
)
|
260 |
+
seq_idx_last = tl.load(
|
261 |
+
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
262 |
+
)
|
263 |
+
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
264 |
+
# Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
265 |
+
# However, we're getting error with the Triton compiler 2.1.0 for that code path:
|
266 |
+
# Unexpected mma -> mma layout conversion
|
267 |
+
# Triton 2.2.0 fixes this
|
268 |
+
offs_dstate = tl.arange(
|
269 |
+
0,
|
270 |
+
(
|
271 |
+
BLOCK_SIZE_DSTATE
|
272 |
+
if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128
|
273 |
+
else BLOCK_SIZE_K
|
274 |
+
),
|
275 |
+
)
|
276 |
+
b_ptrs = b_ptr + (
|
277 |
+
offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate
|
278 |
+
)
|
279 |
+
dstates_ptrs = dstates_ptr + (
|
280 |
+
offs_n[None, :] * stride_dstates_hdim
|
281 |
+
+ offs_dstate[:, None] * stride_dstates_dstate
|
282 |
+
)
|
283 |
+
    if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(
            b_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate),
            other=0.0,
        )
        dstates = tl.load(
            dstates_ptrs,
            mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates) * scale[:, None]
    else:
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(
                b_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit)
                & (offs_dstate[None, :] < dstate - k),
                other=0.0,
            )
            dstates = tl.load(
                dstates_ptrs,
                mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
        acc *= scale[:, None]

    # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    # ddt = tl.sum(acc * x, axis=1) * dt_m
    # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (
        offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
    )
    dout_ptrs = dout_ptr + (
        offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
    )
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    K_MAX = chunk_size_limit
    K_MIN = pid_m * BLOCK_SIZE_M
    cb_ptrs += K_MIN * stride_cb_csize_k
    dout_ptrs += K_MIN * stride_dout_seqlen
    dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
    for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
        k = tl.multiple_of(k, BLOCK_SIZE_K)
        # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
        cb = tl.load(
            cb_ptrs,
            mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k),
            other=0.0,
        )
        dout = tl.load(
            dout_ptrs,
            mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(
            tl.float32
        )
        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
        # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
        # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
        # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
        # This will cause NaN in acc, and hence NaN in dx and ddt.
        mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
        cb = tl.where(mask, cb, 0.0)
        cb = cb.to(dout_ptr.dtype.element_ty)
        acc += tl.dot(cb, dout)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    dx = acc * dt_m[:, None]
    dx_ptr += (
        pid_b * stride_dx_batch
        + pid_c * chunk_size * stride_dx_seqlen
        + pid_h * stride_dx_head
    )
    dx_ptrs = dx_ptr + (
        offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
    )
    if HAS_D:
        dout_res_ptrs = dout_ptr + (
            offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim
        )
        dout_res = tl.load(
            dout_res_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
            other=0.0,
        ).to(tl.float32)
        if D_HAS_HDIM:
            D = tl.load(
                D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
            ).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        dx += dout_res * D
    tl.store(
        dx_ptrs,
        dx,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
    )

    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
    )
    x = tl.load(
        x_ptrs,
        mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
        other=0.0,
    ).to(tl.float32)
    if HAS_D:
        dD_ptr += (
            pid_b * stride_dD_batch
            + pid_c * stride_dD_chunk
            + pid_h * stride_dD_head
            + pid_m * stride_dD_csize
        )
        if D_HAS_HDIM:
            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
            dD = tl.sum(dout_res * x, axis=0)
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
        else:
            dD = tl.sum(dout_res * x)
            tl.store(dD_ptr, dD)
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)


def _chunk_scan_chunk_state_bwd_dx(
    x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert D.stride(-1) == 1
        BLOCK_SIZE_min = 32
        dD = torch.empty(
            triton.cdiv(chunk_size, BLOCK_SIZE_min),
            batch,
            nchunks,
            nheads,
            headdim if D.dim() == 2 else 1,
            device=D.device,
            dtype=torch.float32,
        )
    else:
        dD = None
    dD_strides = (
        (dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
        if D is not None
        else (0, 0, 0, 0, 0)
    )
    if dx is None:
        dx = torch.empty_like(x)
    else:
        assert dx.shape == x.shape
    ddt = torch.empty(
        batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32
    )
    grid_dx = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
            x,
            CB,
            dout,
            dt,
            dA_cumsum,
            seq_idx,
            D,
            B,
            dstates,
            dx,
            ddt,
            dD,
            chunk_size,
            headdim,
            dstate,
            batch,
            seqlen,
            nheads // ngroups,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            CB.stride(0),
            CB.stride(1),
            CB.stride(2),
            CB.stride(-1),
            CB.stride(-2),
            dout.stride(0),
            dout.stride(1),
            dout.stride(2),
            dout.stride(3),
            dt.stride(0),
            dt.stride(2),
            dt.stride(1),
            dt.stride(3),
            dA_cumsum.stride(0),
            dA_cumsum.stride(2),
            dA_cumsum.stride(1),
            dA_cumsum.stride(3),
            *(
                (seq_idx.stride(0), seq_idx.stride(1))
                if seq_idx is not None
                else (0, 0)
            ),
            D.stride(0) if D is not None else 0,
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(3),
            dstates.stride(0),
            dstates.stride(1),
            dstates.stride(2),
            dstates.stride(3),
            dstates.stride(4),
            dx.stride(0),
            dx.stride(1),
            dx.stride(2),
            dx.stride(3),
            ddt.stride(0),
            ddt.stride(2),
            ddt.stride(1),
            ddt.stride(3),
            dD_strides[1],
            dD_strides[2],
            dD_strides[3],
            dD_strides[0],
            dD_strides[4],
            D is not None,
            D.dim() == 2 if D is not None else True,
            HAS_SEQ_IDX=seq_idx is not None,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
            IS_TRITON_22=TRITON_22,
        )
    if D is not None:
        BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs[
            "BLOCK_SIZE_M"
        ]
        n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")
    return dx, ddt.to(dtype=dt.dtype), dD


def _mamba_chunk_scan_combined_fwd(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    seq_idx=None,
    cu_seqlens=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if (
        x.stride(-1) != 1 and x.stride(1) != 1
    ):  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if (
        z is not None and z.stride(-1) != 1 and z.stride(1) != 1
    ):  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
    # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
    )
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
    states, final_states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        initial_states=(
            rearrange(initial_states, "... p n -> ... (p n)")
            if initial_states is not None
            else None
        ),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
        out_dtype=C.dtype,
    )
    states, final_states = [
        rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
    ]
    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    out, out_x = _chunk_scan_fwd(
        CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx
    )
    if cu_seqlens is None:
        return out, out_x, dt, dA_cumsum, states, final_states
    else:
        assert (
            batch == 1
        ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
        varlen_states = chunk_state_varlen(
            B.squeeze(0),
            x.squeeze(0),
            dt.squeeze(0),
            dA_cumsum.squeeze(0),
            cu_seqlens,
            states.squeeze(0),
        )
        return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
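
For the variable-length path, `cu_seqlens` and `seq_idx` describe the same packing of several sequences into one batch row: `cu_seqlens` holds the cumulative sequence boundaries, while `seq_idx` tags every position with the index of the sequence it belongs to. The helper below is only an illustrative sketch (it is not part of this file, and the name is made up) of one way to derive `seq_idx` from `cu_seqlens` under that convention:

import torch

def seq_idx_from_cu_seqlens(cu_seqlens, dtype=torch.int32):
    # cu_seqlens: (num_sequences + 1,), e.g. tensor([0, 3, 7]) for lengths 3 and 4
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    idx = torch.arange(len(lengths), device=cu_seqlens.device, dtype=dtype)
    # Single batch row containing all packed sequences: [[0, 0, 0, 1, 1, 1, 1]]
    return torch.repeat_interleave(idx, lengths).unsqueeze(0)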


def _mamba_chunk_scan_combined_bwd(
    dout,
    x,
    dt,
    A,
    B,
    C,
    out,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    dfinal_states=None,
    seq_idx=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    dx=None,
    ddt=None,
    dB=None,
    dC=None,
    dz=None,
    recompute_output=False,
):
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch, seqlen, nheads, headdim = x.shape
    nchunks = math.ceil(seqlen / chunk_size)
    _, _, ngroups, dstate = B.shape
    assert dout.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert C.shape == B.shape
    assert out.shape == x.shape
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if dx is not None:
        assert dx.shape == x.shape
    if dB is not None:
        assert dB.shape == B.shape
        dB_given = dB
    else:
        dB_given = torch.empty_like(B)
    if dC is not None:
        assert dC.shape == C.shape
        dC_given = dC
    else:
        dC_given = torch.empty_like(C)
    if dz is not None:
        assert z is not None
        assert dz.shape == z.shape
    if ddt is not None:
        assert ddt.shape == dt.shape
        ddt_given = ddt
    else:
        ddt_given = torch.empty_like(dt)
    # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
    # "[CUDA]: invalid device context" (e.g. during the varlen test), and cloning makes it work. Idk why.
    dt_in = dt.clone()
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt_in,
        A,
        chunk_size,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    states, _ = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        initial_states=(
            rearrange(initial_states, "... p n -> ... (p n)")
            if initial_states is not None
            else None
        ),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
    )
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    if z is not None:
        dz, dout, dD, *rest = _chunk_scan_bwd_dz(
            x,
            z,
            out,
            dout,
            chunk_size=chunk_size,
            has_ddAcs=False,
            D=D,
            dz=dz,
            recompute_output=recompute_output,
        )
        outz = rest[0] if recompute_output else out
    else:
        dz = None
        outz = out
    dstates = _chunk_scan_bwd_dstates(
        C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype
    )
    # dstates has length nchunks, containing the gradient to initial states at index 0 and
    # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
    # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
    # will be used in matmul in the next kernels.
    dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        rearrange(dstates, "... p n -> ... (p n)"),
        dfinal_states=(
            rearrange(dfinal_states, "... p n -> ... (p n)")
            if dfinal_states is not None
            else None
        ),
        seq_idx=seq_idx,
        has_initial_states=initial_states is not None,
        dstates_dtype=x.dtype,
        states_dtype=x.dtype,
        chunk_size=chunk_size,
    )
    # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
    # gradient to the final states at index (nchunks - 1)
    # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
    # The final states is not stored.
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
    dinitial_states = (
        rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate)
        if dinitial_states is not None
        else None
    )
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(
        x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx
    )
    # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
    dB, ddA_next = _chunk_state_bwd_db(
        x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups
    )
    # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(
        states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups
    )
    # Computing ddA with the dcb kernel is much slower, so we're not using it for now
    dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
    dCB = dCB.to(CB.dtype)
    _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
    _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
    # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
    # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
    if z is None:
        dD = dD_from_x
    # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
    # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
    # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
    # be a lot of underflow.

    # This is already done as part of bwd_dC kernel
    # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
    ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
    ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
    # This is already done as part of bwd_dB kernel
    # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
    # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
    ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
    ddA += ddA_next + ddA_prev

    ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(
        ddA,
        ddt,
        dt_in,
        A,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        ddt=ddt_given,
    )

    # These 2 lines are just to test ddt and dA being computed by old code
    # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
    # ddt_given.copy_(ddt)

    return_vals = (
        dx,
        ddt_given,
        dA,
        dB_given,
        dC_given,
        dD,
        dz,
        ddt_bias,
        dinitial_states,
    )
    return return_vals if not recompute_output else (*return_vals, outz)
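
The `flip / cumsum / flip` pattern used above when forming `ddA_prev` is simply a reverse (suffix) cumulative sum. A small standalone illustration, not part of this file:

import torch

t = torch.tensor([1.0, 2.0, 3.0])
suffix = t.flip([-1]).cumsum(dim=-1).flip([-1])
# suffix == tensor([6., 5., 3.]): suffix[i] = t[i] + t[i+1] + ... + t[-1]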


def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
    """
    Argument:
        dout: (batch, seqlen, nheads, headdim)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        ddt: same shape as dt
        dA: same shape as A
    """
    import selective_scan

    batch, seqlen, nheads, headdim = x.shape
    chunk_size = dt.shape[-1]
    _, _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    x = rearrange(x, "b l h p -> b (h p) l")
    squeeze_dt = dt.dim() == 4
    if dt.dim() == 4:
        dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
    dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
    squeeze_A = A.dim() == 1
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")

    if x.stride(-1) != 1:
        x = x.contiguous()
    if dt.stride(-1) != 1:
        dt = dt.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    _, intermediate, *rest = selective_scan.fwd(
        x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False
    )
    if z is not None:
        out = rest[0]
    else:
        out = None

    dout = rearrange(dout, "b l h p -> b (h p) l")

    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
    # backward of selective_scan with the backward of chunk).
    # Here we just pass in None and dz will be allocated in the C++ code.
    _, ddt, dA, *rest = selective_scan.bwd(
        x,
        dt.to(dtype=x.dtype),
        A,
        B,
        C,
        D,
        z,
        None,
        dout,
        intermediate,
        out,
        None,
        False,
        False,  # option to recompute out_z, not used here
    )
    ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
    if squeeze_dt:
        ddt = ddt.float().sum(dim=2)
    if squeeze_A:
        dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
    return ddt, dA


class MambaChunkScanCombinedFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=None,
        z=None,
        dt_bias=None,
        initial_states=None,
        seq_idx=None,
        cu_seqlens=None,
        dt_softplus=False,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        return_varlen_states=False,
    ):
        ctx.dt_dtype = dt.dtype
        if not return_varlen_states:
            cu_seqlens = None
        else:
            assert (
                cu_seqlens is not None
            ), "cu_seqlens must be provided if return_varlen_states is True"
        out, out_x, dt_out, dA_cumsum, states, final_states, *rest = (
            _mamba_chunk_scan_combined_fwd(
                x,
                dt,
                A,
                B,
                C,
                chunk_size,
                D=D,
                z=z,
                dt_bias=dt_bias,
                initial_states=initial_states,
                seq_idx=seq_idx,
                cu_seqlens=cu_seqlens,
                dt_softplus=dt_softplus,
                dt_limit=dt_limit,
            )
        )
        ctx.save_for_backward(
            out if z is None else out_x,
            x,
            dt,
            dA_cumsum,
            A,
            B,
            C,
            D,
            z,
            dt_bias,
            initial_states,
            seq_idx,
        )
        ctx.dt_softplus = dt_softplus
        ctx.chunk_size = chunk_size
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.return_varlen_states = return_varlen_states
        if not return_varlen_states:
            return out if not return_final_states else (out, final_states)
        else:
            varlen_states = rest[0]
            return (
                (out, varlen_states)
                if not return_final_states
                else (out, final_states, varlen_states)
            )

    @staticmethod
    def backward(ctx, dout, *args):
        out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = (
            ctx.saved_tensors
        )
        assert (
            not ctx.return_varlen_states
        ), "return_varlen_states is not supported in backward"
        dfinal_states = args[0] if ctx.return_final_states else None
        dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = (
            _mamba_chunk_scan_combined_bwd(
                dout,
                x,
                dt,
                A,
                B,
                C,
                out,
                ctx.chunk_size,
                D=D,
                z=z,
                dt_bias=dt_bias,
                initial_states=initial_states,
                dfinal_states=dfinal_states,
                seq_idx=seq_idx,
                dt_softplus=ctx.dt_softplus,
                dt_limit=ctx.dt_limit,
            )
        )
        return (
            dx,
            ddt,
            dA,
            dB,
            dC,
            None,
            dD,
            dz,
            ddt_bias,
            dinitial_states,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def mamba_chunk_scan_combined(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    seq_idx=None,
    cu_seqlens=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
    return_varlen_states=False,
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
        dt_softplus: Whether to apply softplus to dt
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    return MambaChunkScanCombinedFn.apply(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D,
        z,
        dt_bias,
        initial_states,
        seq_idx,
        cu_seqlens,
        dt_softplus,
        dt_limit,
        return_final_states,
        return_varlen_states,
    )
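
A minimal usage sketch for `mamba_chunk_scan_combined`, using the shapes from the docstring above. The sizes are arbitrary and only for illustration, and a CUDA device is assumed since the kernels are Triton-based:

import torch

batch, seqlen, nheads, headdim, ngroups, dstate, chunk_size = 2, 512, 8, 64, 1, 16, 256
device, dtype = "cuda", torch.bfloat16

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
dt = torch.rand(batch, seqlen, nheads, device=device, dtype=dtype)
A = -torch.rand(nheads, device=device)  # negative so the state decays
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
C = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=dtype)
D = torch.randn(nheads, device=device)

out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=D, dt_softplus=True)
# out: (batch, seqlen, nheads, headdim)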


def mamba_chunk_scan(
    x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    if seqlen % chunk_size != 0:
        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dt = dt.float()  # We want high precision for this before cumsum
    if dt_bias is not None:
        dt = dt + rearrange(dt_bias, "h -> h 1 1")
    if dt_softplus:
        dt = F.softplus(dt)
    dA = dt * rearrange(A, "h -> h 1 1")
    dA_cumsum = torch.cumsum(dA, dim=-1)
    # 1. Compute the state for each chunk
    states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
    # 2. Pass the state to all the chunks by weighted cumsum.
    states = rearrange(
        state_passing(
            rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
        )[0],
        "... (p n) -> ... p n",
        n=dstate,
    )
    # 3. Compute the output for each chunk
    out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
    return out


def ssd_chunk_scan_combined_ref(
    x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    dstate = B.shape[-1]
    if seqlen % chunk_size != 0:
        dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dt = dt.float()  # We want high precision for this before cumsum
    if dt_bias is not None:
        dt = dt + rearrange(dt_bias, "h -> h 1 1")
    if dt_softplus:
        dt = F.softplus(dt)
    dA = dt * rearrange(A, "h -> h 1 1")
    dA_cumsum = torch.cumsum(dA, dim=-1)
    # 1. Compute the state for each chunk
    states = chunk_state_ref(B, x, dt, dA_cumsum)
    states_dtype = states.dtype
    if states.dtype not in [torch.float32, torch.float64]:
        states = states.to(torch.float32)
    # 2. Pass the state to all the chunks by weighted cumsum.
    # state_passing_ref is much less numerically stable
    states = rearrange(
        state_passing_ref(
            rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]
        )[0],
        "... (p n) -> ... p n",
        n=dstate,
    )
    states = states.to(states_dtype)
    # 3. Compute the output for each chunk
    out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
    return out


def ssd_selective_scan(
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    """
    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
        A: (nheads) or (dim, dstate)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,) or (nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    from ..selective_scan_interface import selective_scan_fn

    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    x = rearrange(x, "b l h p -> b (h p) l")
    if dt.dim() == 3:
        dt = repeat(dt, "b l h -> b l h p", p=headdim)
    dt = rearrange(dt, "b l h p -> b (h p) l")
    if A.dim() == 1:
        A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
    else:
        A = A.to(dtype=torch.float32)
    B = rearrange(B, "b l g n -> b g n l")
    C = rearrange(C, "b l g n -> b g n l")
    if D is not None:
        if D.dim() == 2:
            D = rearrange(D, "h p -> (h p)")
        else:
            D = repeat(D, "h -> (h p)", p=headdim)
    if z is not None:
        z = rearrange(z, "b l h p -> b (h p) l")
    if dt_bias is not None:
        if dt_bias.dim() == 1:
            dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
        dt_bias = rearrange(dt_bias, "h p -> (h p)")
    if dt_limit != (0.0, float("inf")):
        if dt_bias is not None:
            dt = dt + rearrange(dt_bias, "d -> d 1")
        if dt_softplus:
            dt = F.softplus(dt)
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
        dt_bias = None
        dt_softplus = None
    out = selective_scan_fn(
        x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus
    )
    return rearrange(out, "b (h p) l -> b l h p", p=headdim)


def mamba_conv1d_scan_ref(
    xBC,
    conv1d_weight,
    conv1d_bias,
    dt,
    A,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    activation="silu",
    headdim=None,
    ngroups=1,
):
    """
    Argument:
        xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, dim)
        dt_bias: (nheads) or (nheads, headdim)
        headdim: if D is 1D and z is None, headdim must be passed in
    Return:
        out: (batch, seqlen, dim)
    """
    batch, seqlen, nheads = dt.shape[:3]
    assert nheads % ngroups == 0
    if z is not None:
        dim = z.shape[-1]
        assert dim % nheads == 0
        headdim = dim // nheads
    else:
        if D.dim() == 1:
            assert headdim is not None
        else:
            headdim = D.shape[1]
        dim = nheads * headdim
    xBC = rearrange(
        causal_conv1d_fn(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            activation=activation,
        ),
        "b d s -> b s d",
    )
    dstate = (xBC.shape[-1] - dim) // ngroups // 2
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
    B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
    z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
    out = ssd_selective_scan(
        x,
        dt.to(x.dtype),
        A,
        B,
        C,
        D=D.float(),
        z=z,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )
    return rearrange(out, "b s h p -> b s (h p)")


class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states=None,
        seq_idx=None,
        dt_limit=(0.0, float("inf")),
        return_final_states=False,
        activation="silu",
        rmsnorm_weight=None,
        rmsnorm_eps=1e-6,
        outproj_weight=None,
        outproj_bias=None,
        headdim=None,
        ngroups=1,
        norm_before_gate=True,
    ):
        assert activation in [None, "silu", "swish"]
        if D.dim() == 1:
            assert headdim is not None
            (nheads,) = D.shape
        else:
            nheads, headdim = D.shape
        batch, seqlen, _ = zxbcdt.shape
        dim = nheads * headdim
        assert nheads % ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        assert zxbcdt.shape == (
            batch,
            seqlen,
            2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads,
        )
        assert dt_bias.shape == (nheads,)
        assert A.shape == (nheads,)
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1
        )
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
        if rmsnorm_weight is None:
            out, out_x, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            out = rearrange(out, "b s h p -> b s (h p)")
            rstd = None
            if d_nonssm > 0:
                out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
        else:
            out_x, _, dt_out, dA_cumsum, states, final_states = (
                _mamba_chunk_scan_combined_fwd(
                    x,
                    dt,
                    A,
                    B,
                    C,
                    chunk_size=chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=dt_limit,
                )
            )
            # reshape input data into 2D tensor
            x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            rmsnorm_weight = rmsnorm_weight.contiguous()
            if d_nonssm == 0:
                out = None
            else:
                out01 = torch.empty(
                    (batch, seqlen, d_nonssm + dim),
                    dtype=x_rms.dtype,
                    device=x_rms.device,
                )
                out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
                _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
            out, _, rstd = _layer_norm_fwd(
                x_rms,
                rmsnorm_weight,
                None,
                rmsnorm_eps,
                z_rms,
                out=out,
                group_size=dim // ngroups,
                norm_before_gate=norm_before_gate,
                is_rms_norm=True,
            )
            if d_nonssm == 0:
                out = rearrange(out, "(b s) d -> b s d", b=batch)
            else:
                out = out01
        ctx.outproj_weight_dtype = (
            outproj_weight.dtype if outproj_weight is not None else None
        )
        if outproj_weight is not None:
            if torch.is_autocast_enabled():
                dtype = torch.get_autocast_gpu_dtype()
                out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
                outproj_bias = (
                    outproj_bias.to(dtype) if outproj_bias is not None else None
                )
            out = F.linear(out, outproj_weight, outproj_bias)
        else:
            assert outproj_bias is None
        ctx.save_for_backward(
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out_x,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        )
        ctx.dt_limit = dt_limit
        ctx.return_final_states = return_final_states
        ctx.activation = activation
        ctx.rmsnorm_eps = rmsnorm_eps
        ctx.norm_before_gate = norm_before_gate
        ctx.chunk_size = chunk_size
        ctx.headdim = headdim
        ctx.ngroups = ngroups
        return out if not return_final_states else (out, final_states)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        (
            zxbcdt,
            conv1d_weight,
            conv1d_bias,
            out,
            A,
            D,
            dt_bias,
            initial_states,
            seq_idx,
            rmsnorm_weight,
            rstd,
            outproj_weight,
            outproj_bias,
        ) = ctx.saved_tensors
        dfinal_states = args[0] if ctx.return_final_states else None
        headdim = ctx.headdim
        nheads = D.shape[0]
        dim = nheads * headdim
        assert nheads % ctx.ngroups == 0
        dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
        d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
        assert d_nonssm >= 0
        recompute_output = outproj_weight is not None
        if recompute_output:
            out_recompute = torch.empty(
                *out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype
            )
            out0_recompute, out1_recompute = out_recompute.split(
                [d_nonssm, dim], dim=-1
            )
        zx0, z, xBC, dt = torch.split(
            zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        # Recompute x, B, C
        xBC_conv = rearrange(
            causal_conv1d_cuda.causal_conv1d_fwd(
                rearrange(xBC, "b s d -> b d s"),
                conv1d_weight,
                conv1d_bias,
                seq_idx,
                None,
                None,
                ctx.activation in ["silu", "swish"],
            ),
            "b d s -> b s d",
        )
        x, B, C = torch.split(
            xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
        B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
        C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
        dzxbcdt = torch.empty_like(zxbcdt)
        dzx0, dz, dxBC_given, ddt_given = torch.split(
            dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1
        )
        dxBC = torch.empty_like(xBC)
        dx, dB, dC = torch.split(
            dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1
        )
        z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
        dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
        dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
        dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
        if outproj_weight is not None:
            dout_og = dout
            dout = F.linear(dout, outproj_weight.t())
        if d_nonssm > 0:
            dout0, dout = dout.split([d_nonssm, dim], dim=-1)
            _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
        dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
        if rmsnorm_weight is None:
            dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
            dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=z,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                    dz=dz,
                    recompute_output=recompute_output,
                )
            )
            out_for_linear = (
                rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
            )
            drmsnorm_weight = None
        else:
            batch = dout.shape[0]
            dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
            dz = rearrange(dz, "b l d -> (b l) d")
            x_rms = rearrange(out, "b s h p -> (b s) (h p)")
            z_rms = rearrange(z, "b s h p -> (b s) (h p)")
            out1_recompute = (
                rearrange(out1_recompute, "b s d -> (b s) d")
                if recompute_output
                else None
            )
            dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(
                dy_rms,
                x_rms,
                rmsnorm_weight,
                None,
                ctx.rmsnorm_eps,
                None,
                rstd,
                z_rms,
                group_size=dim // ctx.ngroups,
                norm_before_gate=ctx.norm_before_gate,
                is_rms_norm=True,
                recompute_output=recompute_output,
                dz=dz,
                out=out1_recompute if recompute_output else None,
            )
            out_for_linear = out_recompute if recompute_output else None
            dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
            dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = (
                _mamba_chunk_scan_combined_bwd(
                    dout,
                    x,
                    dt,
                    A,
                    B,
                    C,
                    out,
                    ctx.chunk_size,
                    D=D,
                    z=None,
                    dt_bias=dt_bias,
                    initial_states=initial_states,
                    dfinal_states=dfinal_states,
                    seq_idx=seq_idx,
                    dt_softplus=True,
                    dt_limit=ctx.dt_limit,
                    dx=dx,
                    ddt=ddt_given,
                    dB=dB,
                    dC=dC,
                )
            )

        if outproj_weight is not None:
            doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
            doutproj_bias = (
                dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
            )
        else:
            doutproj_weight, doutproj_bias = None, None
        dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
        dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            rearrange(dxBC, "b s d -> b d s"),
            seq_idx,
            None,
            None,
            dxBC_given,
            False,
            ctx.activation in ["silu", "swish"],
        )
        dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
        return (
            dzxbcdt,
            dweight,
            dbias,
            ddt_bias,
            dA,
            dD,
            None,
            dinitial_states,
            None,
            None,
            None,
            None,
            drmsnorm_weight,
            None,
            doutproj_weight,
            doutproj_bias,
            None,
            None,
            None,
        )


def mamba_split_conv1d_scan_combined(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    initial_states=None,
    seq_idx=None,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen), int32
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    return MambaSplitConv1dScanCombinedFn.apply(
        zxbcdt,
        conv1d_weight,
        conv1d_bias,
        dt_bias,
        A,
        D,
        chunk_size,
        initial_states,
        seq_idx,
        dt_limit,
        return_final_states,
        activation,
        rmsnorm_weight,
        rmsnorm_eps,
        outproj_weight,
        outproj_bias,
        headdim,
        ngroups,
        norm_before_gate,
    )
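
A minimal usage sketch for `mamba_split_conv1d_scan_combined`. The sizes, dtypes, and the way the packed `zxbcdt` tensor is built below are assumptions based only on the docstring above (with `d_nonssm = 0`, no RMSNorm, and no output projection); a CUDA device with the causal-conv1d extension installed is assumed:

import torch

nheads, headdim, ngroups, dstate, width, chunk_size = 8, 64, 1, 16, 4, 256
dim = nheads * headdim
batch, seqlen = 2, 512
device = "cuda"

# Packed projection output: [z (dim), x/B/C (dim + 2*ngroups*dstate), dt (nheads)]
zxbcdt = torch.randn(batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads, device=device)
conv1d_weight = torch.randn(dim + 2 * ngroups * dstate, width, device=device)
conv1d_bias = torch.randn(dim + 2 * ngroups * dstate, device=device)
dt_bias = torch.randn(nheads, device=device)
A = -torch.rand(nheads, device=device)
D = torch.randn(nheads, device=device)

out = mamba_split_conv1d_scan_combined(
    zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size,
    headdim=headdim, ngroups=ngroups,
)
# out: (batch, seqlen, dim) when outproj_weight is None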


def mamba_split_conv1d_scan_ref(
    zxbcdt,
    conv1d_weight,
    conv1d_bias,
    dt_bias,
    A,
    D,
    chunk_size,
    dt_limit=(0.0, float("inf")),
    activation="silu",
    rmsnorm_weight=None,
    rmsnorm_eps=1e-6,
    outproj_weight=None,
    outproj_bias=None,
    headdim=None,
    ngroups=1,
    norm_before_gate=True,
):
    """
    Argument:
        zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
        conv1d_weight: (dim + 2 * ngroups * dstate, width)
        conv1d_bias: (dim + 2 * ngroups * dstate,)
        dt_bias: (nheads,)
        A: (nheads)
        D: (nheads, headdim) or (nheads,)
        rmsnorm_weight: (dim,)
        outproj_weight: (out_dim, dim)
        outproj_bias: (out_dim,)
        headdim: if D is 1D, headdim must be passed in
        norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
    Return:
        out: (batch, seqlen, dim)
    """
    if D.dim() == 1:
        assert headdim is not None
        (nheads,) = D.shape
    else:
        nheads, headdim = D.shape
    assert nheads % ngroups == 0
    batch, seqlen, _ = zxbcdt.shape
    dim = nheads * headdim
    dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
    assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
    assert dt_bias.shape == (nheads,)
    assert A.shape == (nheads,)
    if rmsnorm_weight is not None:
        assert rmsnorm_weight.shape == (dim,)
    z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
    xBC = rearrange(
        causal_conv1d_fn(
            rearrange(xBC, "b s d -> b d s"),
            conv1d_weight,
            conv1d_bias,
            activation=activation,
        ),
        "b d s -> b s d",
    )
    x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
    x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
    B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
    C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
    z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
    out = ssd_selective_scan(
        x,
        dt.to(x.dtype),
        A,
        B,
        C,
        D=D.float(),
        z=z if rmsnorm_weight is None else None,
        dt_bias=dt_bias,
        dt_softplus=True,
        dt_limit=dt_limit,
    )
    out = rearrange(out, "b s h p -> b s (h p)")
    if rmsnorm_weight is not None:
        out = rmsnorm_fn(
            out,
            rmsnorm_weight,
            None,
            z=rearrange(z, "b l h p -> b l (h p)"),
            eps=rmsnorm_eps,
            norm_before_gate=norm_before_gate,
        )
    if outproj_weight is not None:
        out = F.linear(out, outproj_weight, outproj_bias)
    return out
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/ops/triton/ssd_state_passing.py
ADDED
@@ -0,0 +1,348 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['dim'],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
    # Matrix dimensions
    dim, nchunks, seqlen, chunk_size,
    # Strides
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
    stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
    stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
    stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_INITSTATES: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
    if HAS_INITSTATES:
        initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim

    if not HAS_INITSTATES:
        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
    else:
        initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    tl.store(out_ptrs, states, mask=offs_m < dim)
    out_ptrs += stride_out_chunk
    seq_idx = 0
    for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
            scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
            seq_idx = seq_idx_new
        states = scale * states + new_states
        if c < nchunks - 1:
            tl.store(out_ptrs, states, mask=offs_m < dim)
        else:
            tl.store(final_states_ptrs, states, mask=offs_m < dim)
        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk
87 |
+
|
88 |
+
@triton.autotune(
|
89 |
+
configs=[
|
90 |
+
triton.Config({'BLOCK_SIZE': 64}),
|
91 |
+
triton.Config({'BLOCK_SIZE': 128}),
|
92 |
+
triton.Config({'BLOCK_SIZE': 256}),
|
93 |
+
triton.Config({'BLOCK_SIZE': 512}),
|
94 |
+
triton.Config({'BLOCK_SIZE': 1024}),
|
95 |
+
triton.Config({'BLOCK_SIZE': 2048}),
|
96 |
+
],
|
97 |
+
key=['dim'],
|
98 |
+
)
|
99 |
+
@triton.jit
|
100 |
+
def _state_passing_bwd_kernel(
|
101 |
+
# Pointers to matrices
|
102 |
+
dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
|
103 |
+
dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
|
104 |
+
# Matrix dimensions
|
105 |
+
dim, nchunks, seqlen, chunk_size,
|
106 |
+
# Strides
|
107 |
+
stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
|
108 |
+
stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
|
109 |
+
stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
|
110 |
+
stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
|
111 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
112 |
+
stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
|
113 |
+
stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
|
114 |
+
stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
|
115 |
+
# Meta-parameters
|
116 |
+
CONVERT_STATES: tl.constexpr,
|
117 |
+
HAS_DFINAL_STATES: tl.constexpr,
|
118 |
+
HAS_DINITSTATES: tl.constexpr,
|
119 |
+
HAS_SEQ_IDX: tl.constexpr,
|
120 |
+
BLOCK_SIZE: tl.constexpr,
|
121 |
+
):
|
122 |
+
pid_b = tl.program_id(axis=1)
|
123 |
+
pid_h = tl.program_id(axis=2)
|
124 |
+
pid_m = tl.program_id(axis=0)
|
125 |
+
dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
|
126 |
+
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
|
127 |
+
ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
|
128 |
+
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
|
129 |
+
dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
|
130 |
+
if CONVERT_STATES:
|
131 |
+
states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
|
132 |
+
if HAS_DFINAL_STATES:
|
133 |
+
dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
|
134 |
+
if HAS_DINITSTATES:
|
135 |
+
dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
|
136 |
+
if HAS_SEQ_IDX:
|
137 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch
|
138 |
+
|
139 |
+
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
140 |
+
dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
|
141 |
+
out_ptrs = out_ptr + offs_m * stride_out_dim
|
142 |
+
dout_ptrs = dout_ptr + offs_m * stride_dout_dim
|
143 |
+
if CONVERT_STATES:
|
144 |
+
states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
|
145 |
+
|
146 |
+
if HAS_DFINAL_STATES:
|
147 |
+
dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
|
148 |
+
else:
|
149 |
+
dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
150 |
+
tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
|
151 |
+
if HAS_SEQ_IDX:
|
152 |
+
seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
|
153 |
+
dstates_ptrs -= stride_dstates_chunk
|
154 |
+
for c in range(nchunks - 1):
|
155 |
+
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
156 |
+
scale = tl.exp(dA_cs)
|
157 |
+
if HAS_SEQ_IDX:
|
158 |
+
seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
|
159 |
+
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
|
160 |
+
seq_idx = seq_idx_new
|
161 |
+
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
162 |
+
if CONVERT_STATES:
|
163 |
+
tl.store(states_converted_ptrs, out, mask=offs_m < dim)
|
164 |
+
ddA = tl.sum(out * dstates) * scale
|
165 |
+
tl.store(ddA_cs_ptr, ddA)
|
166 |
+
dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
167 |
+
dstates = scale * dstates + dout
|
168 |
+
tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
|
169 |
+
dout_ptrs -= stride_dout_chunk
|
170 |
+
dstates_ptrs -= stride_dstates_chunk
|
171 |
+
dA_cs_ptr -= stride_dA_cs_chunk
|
172 |
+
ddA_cs_ptr -= stride_ddA_cs_chunk
|
173 |
+
out_ptrs -= stride_out_chunk
|
174 |
+
if CONVERT_STATES:
|
175 |
+
states_converted_ptrs -= stride_out_chunk
|
176 |
+
if CONVERT_STATES:
|
177 |
+
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
178 |
+
tl.store(states_converted_ptrs, out, mask=offs_m < dim)
|
179 |
+
if not HAS_DINITSTATES:
|
180 |
+
tl.store(ddA_cs_ptr, 0.0)
|
181 |
+
else:
|
182 |
+
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
183 |
+
scale = tl.exp(dA_cs)
|
184 |
+
if HAS_SEQ_IDX:
|
185 |
+
scale = tl.where(seq_idx == 0, scale, 0.0)
|
186 |
+
out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
187 |
+
ddA = tl.sum(out * dstates) * scale
|
188 |
+
tl.store(ddA_cs_ptr, ddA)
|
189 |
+
dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
190 |
+
dstates = scale * dstates + dout
|
191 |
+
tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
|
192 |
+
|
193 |
+
|
194 |
+
def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
|
195 |
+
out_dtype=None):
|
196 |
+
batch, nchunks, nheads, dim = states.shape
|
197 |
+
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
198 |
+
if initial_states is not None:
|
199 |
+
assert initial_states.shape == (batch, nheads, dim)
|
200 |
+
if seq_idx is not None:
|
201 |
+
assert chunk_size is not None
|
202 |
+
seqlen = seq_idx.shape[-1]
|
203 |
+
assert seq_idx.shape == (batch, seqlen)
|
204 |
+
out_dtype = states.dtype if out_dtype is None else out_dtype
|
205 |
+
out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
|
206 |
+
final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
|
207 |
+
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
|
208 |
+
with torch.cuda.device(states.device.index):
|
209 |
+
_state_passing_fwd_kernel[grid](
|
210 |
+
states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
|
211 |
+
dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
|
212 |
+
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
|
213 |
+
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
|
214 |
+
final_states.stride(0), final_states.stride(1), final_states.stride(2),
|
215 |
+
dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
|
216 |
+
*((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
|
217 |
+
if initial_states is not None else (0, 0, 0)),
|
218 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
219 |
+
HAS_INITSTATES=initial_states is not None,
|
220 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
221 |
+
)
|
222 |
+
return out, final_states
|
223 |
+
|
224 |
+
|
225 |
+
def _state_passing_bwd(
|
226 |
+
states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
|
227 |
+
dstates_dtype=None, states_dtype=None, chunk_size=None
|
228 |
+
):
|
229 |
+
"""
|
230 |
+
states contains the initial_states at index 0. The final states are not included in states.
|
231 |
+
"""
|
232 |
+
batch, nchunks, nheads, dim = states.shape
|
233 |
+
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
234 |
+
assert dout.shape == (batch, nchunks, nheads, dim)
|
235 |
+
if seq_idx is not None:
|
236 |
+
assert chunk_size is not None
|
237 |
+
seqlen = seq_idx.shape[-1]
|
238 |
+
assert seq_idx.shape == (batch, seqlen)
|
239 |
+
dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
|
240 |
+
if states_dtype is not None and states_dtype != states.dtype:
|
241 |
+
states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
|
242 |
+
assert states_converted.stride() == states.stride()
|
243 |
+
else:
|
244 |
+
states_converted = None
|
245 |
+
if has_initial_states:
|
246 |
+
dinitstates = torch.empty_like(dstates[:, 0])
|
247 |
+
else:
|
248 |
+
dinitstates = None
|
249 |
+
if dfinal_states is not None:
|
250 |
+
assert dfinal_states.shape == (batch, nheads, dim)
|
251 |
+
BLOCK_SIZE_min = 64
|
252 |
+
n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
|
253 |
+
ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
|
254 |
+
dtype=torch.float32, device=dA_chunk_cumsum.device)
|
255 |
+
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
|
256 |
+
with torch.cuda.device(dout.device.index):
|
257 |
+
_state_passing_bwd_kernel[grid](
|
258 |
+
dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
|
259 |
+
dstates, ddA_chunk_cumsum, dinitstates, states_converted,
|
260 |
+
dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
|
261 |
+
dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
|
262 |
+
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
|
263 |
+
dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
|
264 |
+
*((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
|
265 |
+
if dfinal_states is not None else (0, 0, 0)),
|
266 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
267 |
+
dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
|
268 |
+
ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
|
269 |
+
*((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
|
270 |
+
if dinitstates is not None else (0, 0, 0)),
|
271 |
+
CONVERT_STATES=states_converted is not None,
|
272 |
+
HAS_DFINAL_STATES=dfinal_states is not None,
|
273 |
+
HAS_DINITSTATES=dinitstates is not None,
|
274 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
275 |
+
)
|
276 |
+
BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
|
277 |
+
n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
|
278 |
+
ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
|
279 |
+
if states_dtype is not None and states_dtype == states.dtype:
|
280 |
+
states_converted = states
|
281 |
+
return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
|
282 |
+
|
283 |
+
|
284 |
+
class StatePassingFn(torch.autograd.Function):
|
285 |
+
|
286 |
+
@staticmethod
|
287 |
+
def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
|
288 |
+
batch, nchunks, nheads, dim = states.shape
|
289 |
+
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
290 |
+
if states.stride(-1) != 1:
|
291 |
+
states = states.contiguous()
|
292 |
+
out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
|
293 |
+
ctx.save_for_backward(out, dA_chunk_cumsum)
|
294 |
+
ctx.has_initial_states = initial_states is not None
|
295 |
+
return out, final_states
|
296 |
+
|
297 |
+
@staticmethod
|
298 |
+
def backward(ctx, dout, dfinal_states):
|
299 |
+
out, dA_chunk_cumsum = ctx.saved_tensors
|
300 |
+
batch, nchunks, nheads, dim = out.shape
|
301 |
+
assert dout.shape == (batch, nchunks, nheads, dim)
|
302 |
+
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
|
303 |
+
assert dfinal_states.shape == (batch, nheads, dim)
|
304 |
+
if dout.stride(-1) != 1:
|
305 |
+
dout = dout.contiguous()
|
306 |
+
dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
|
307 |
+
out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states
|
308 |
+
)
|
309 |
+
return dstates, ddA_chunk_cumsum, dinitstates
|
310 |
+
|
311 |
+
|
312 |
+
def state_passing(states, dA_chunk_cumsum, initial_states=None):
|
313 |
+
"""
|
314 |
+
Argument:
|
315 |
+
states: (batch, nchunks, nheads, dim)
|
316 |
+
dA_chunk_cumsum: (batch, nheads, nchunks)
|
317 |
+
initial_states: (batch, nheads, dim)
|
318 |
+
Return:
|
319 |
+
out: (batch, nchunks, nheads, dim)
|
320 |
+
final_states: (batch, nheads, dim)
|
321 |
+
"""
|
322 |
+
return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
|
323 |
+
|
324 |
+
|
325 |
+
def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
|
326 |
+
"""
|
327 |
+
Argument:
|
328 |
+
states: (batch, nchunks, nheads, dim)
|
329 |
+
dA_chunk_cumsum: (batch, nheads, nchunks)
|
330 |
+
initial_states: (batch, nheads, dim)
|
331 |
+
Return:
|
332 |
+
out: (batch, nchunks, nheads, dim)
|
333 |
+
final_states: (batch, nheads, dim)
|
334 |
+
"""
|
335 |
+
if initial_states is None:
|
336 |
+
initial_states = torch.zeros_like(states[:, 0])
|
337 |
+
states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
|
338 |
+
dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
|
339 |
+
dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
|
340 |
+
nchunks = dA_chunk_cumsum.shape[-1]
|
341 |
+
# (batch, nheads, nchunks, nchunks)
|
342 |
+
dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
|
343 |
+
# (batch, nheads, nchunks, nchunks)
|
344 |
+
decay_chunk = torch.exp(dt_chunk_segment_sum)
|
345 |
+
causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
|
346 |
+
decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
|
347 |
+
out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
|
348 |
+
return out[:, :-1], out[:, -1]
|
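The kernel above carries per-chunk SSM states across chunks with the recurrence states_next = exp(dA_chunk_cumsum) * states_prev + states_in. A minimal sanity-check sketch (not part of this diff) comparing the Triton path against the pure-PyTorch reference; the import path and a CUDA device are assumptions:

import torch
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref

def check_state_passing(batch=2, nchunks=8, nheads=4, dim=64, device="cuda"):
    states = torch.randn(batch, nchunks, nheads, dim, device=device)
    dA_cs = torch.randn(batch, nheads, nchunks, device=device)
    out, final = state_passing(states, dA_cs)          # Triton kernel
    out_ref, final_ref = state_passing_ref(states, dA_cs)  # einsum reference
    # Both implement the chunk-level scan; tolerances are loose for fp32 exp().
    assert torch.allclose(out, out_ref, atol=1e-4)
    assert torch.allclose(final, final_ref, atol=1e-4)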
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/generation.py
ADDED
@@ -0,0 +1,390 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()


def modify_logits_for_min_p_filtering(logits, min_p):
    """Set the logits for non-min_p values to -inf. Done in-place."""
    if min_p <= 0.0 or min_p >= 1.0:
        return
    indices_to_remove = logits < min_p
    logits.masked_fill_(indices_to_remove, float("-Inf"))
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))


def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
    logits: (batch_size, vocab_size)
    prev_output_tokens: (batch_size, seq_len)
    """
    if repetition_penalty == 1.0:
        return logits
    score = torch.gather(logits, 1, prev_output_tokens)
    # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
    logits.scatter_(1, prev_output_tokens, score)
    return logits


def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            if min_p > 0.0:
                logits_top = logits.clone()
                max_prob = logits_top[..., 0].item()
                min_prob = max_prob * min_p
                modify_logits_for_min_p_filtering(logits_top, min_prob)
                if temperature != 1.0:
                    logits_top /= temperature
                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    cg=False,
    enable_timing=False,
    output_scores=False,
    streamer: Optional[TextStreamer] = None
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)

    if enable_timing:
        start.record()
    scores, sequences = [], [input_ids]
    sequences_cat = input_ids
    while not should_stop(sequences[-1], inference_params):
        logits = get_logits(sequences[-1], inference_params)
        if output_scores:
            scores.append(logits.clone())
        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(logits, inference_params)
        else:
            logits = modify_logit_for_repetition_penalty(
                logits, sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)
        if streamer is not None:
            streamer.put(sampled_tokens.cpu())
    if streamer is not None:
        streamer.end()
    if enable_timing:
        end.record()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        min_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p,
            temperature=temperature, output_scores=output_scores, **kwargs
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run
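A hedged usage sketch (not part of this diff) of the generation loop above, driven through GenerationMixin.generate, which forwards to decode(). The checkpoint and tokenizer names are assumptions; any MambaLMHeadModel works:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
input_ids = tokenizer("The state-space model", return_tensors="pt").input_ids.to("cuda")

# Greedy decoding (top_k=1); cg=True captures the single-token decode step in a CUDA graph.
out = model.generate(input_ids, max_length=64, top_k=1, cg=True)
# Nucleus sampling with a repetition penalty, returning per-step scores.
out = model.generate(
    input_ids, max_length=64, top_k=0, top_p=0.9, temperature=0.8,
    repetition_penalty=1.2, return_dict_in_generate=True, output_scores=True,
)
print(tokenizer.decode(out.sequences[0]))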
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/hf.py
ADDED
@@ -0,0 +1,23 @@
import json

import torch

from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


def load_config_hf(model_name):
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    return json.load(open(resolved_archive_file))


def load_state_dict_hf(model_name, device=None, dtype=None):
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)
    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict
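A hedged sketch (not part of this diff) of how these two helpers are typically combined to build a model from a Hub checkpoint; the checkpoint name is an assumption:

import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

config_data = load_config_hf("state-spaces/mamba-130m")   # plain dict parsed from config.json
config = MambaConfig(**config_data)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.float16)
model.load_state_dict(load_state_dict_hf("state-spaces/mamba-130m", device="cuda", dtype=torch.float16))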
build/torch25-cxx11-cu118-x86_64-linux/mamba_ssm/utils/torch.py
ADDED
@@ -0,0 +1,21 @@
import torch
from functools import partial
from typing import Callable


def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            kwargs["device_type"] = "cuda"
        return dec(*args, **kwargs)
    return decorator


if hasattr(torch.amp, "custom_fwd"):  # type: ignore[attr-defined]
    deprecated = True
    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
else:
    deprecated = False
    from torch.cuda.amp import custom_fwd, custom_bwd

custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
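A hedged sketch (not part of this diff): the wrapped custom_fwd/custom_bwd are drop-in decorators for autograd Functions whether the new torch.amp API or the deprecated torch.cuda.amp one is available. The class here is purely illustrative:

import torch
from mamba_ssm.utils.torch import custom_fwd, custom_bwd

class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    @custom_fwd          # device_type="cuda" is injected automatically when needed
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        return grad_out * 2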
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/__init__.py
ADDED
@@ -0,0 +1,14 @@
__version__ = "2.2.4"

from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from .modules.mamba_simple import Mamba
from .modules.mamba2 import Mamba2
from .models.mixer_seq_simple import MambaLMHeadModel

__all__ = [
    "selective_scan_fn",
    "mamba_inner_fn",
    "Mamba",
    "Mamba2",
    "MambaLMHeadModel",
]
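A hedged sketch (not part of this diff) of the public surface re-exported above; the shapes and hyperparameters are illustrative and a CUDA device is assumed:

import torch
from mamba_ssm import Mamba2

x = torch.randn(2, 128, 256, device="cuda", dtype=torch.bfloat16)  # (batch, seqlen, d_model)
block = Mamba2(d_model=256, d_state=64, d_conv=4, expand=2, device="cuda", dtype=torch.bfloat16)
y = block(x)  # same shape as x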
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_mamba_ssm_nmrmresto7zfi.abi3.so
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:854cfdd1c899869de1c88a6a56de1494a3d4a0edd1a04412167599485bc1093e
size 247806288
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/_ops.py
ADDED
@@ -0,0 +1,9 @@
import torch
from . import _mamba_ssm_nmrmresto7zfi
ops = torch.ops._mamba_ssm_nmrmresto7zfi

def add_op_namespace_prefix(op_name: str):
    """
    Prefix op by namespace.
    """
    return f"_mamba_ssm_nmrmresto7zfi::{op_name}"
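A hedged sketch (not part of this diff) of how this shim is consumed: the compiled extension registers its kernels under the hashed namespace, and callers reach them through the `ops` object. The op name below is purely illustrative:

from mamba_ssm._ops import ops, add_op_namespace_prefix

# Qualified name used when registering or looking up a custom op schema;
# "selective_scan_fwd" is an example name, not necessarily a real kernel.
print(add_op_namespace_prefix("selective_scan_fwd"))
# -> "_mamba_ssm_nmrmresto7zfi::selective_scan_fwd"
# The kernels themselves would then be invoked as attributes of `ops`.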
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/distributed_utils.py
ADDED
@@ -0,0 +1,144 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce-scatter the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input over the tensor parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(params_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be global rank, not group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
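A small worked example (not part of this diff) of the uneven-split arithmetic in get_dim_for_local_rank: with dim=10, world_size=4 and multiple_of=2 there are 5 "multiples"; the first rank gets two of them, the others one each:

from mamba_ssm.distributed.distributed_utils import get_dim_for_local_rank

dims = [get_dim_for_local_rank(10, world_size=4, local_rank=r, multiple_of=2) for r in range(4)]
assert dims == [4, 2, 2, 2] and sum(dims) == 10  # shards cover the full dimension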
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/distributed/tensor_parallel.py
ADDED
@@ -0,0 +1,326 @@
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import Tensor
|
9 |
+
from torch.distributed import ProcessGroup
|
10 |
+
from ..utils.torch import custom_bwd, custom_fwd
|
11 |
+
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from ..distributed.distributed_utils import (
|
15 |
+
all_gather_raw,
|
16 |
+
all_reduce,
|
17 |
+
all_reduce_raw,
|
18 |
+
reduce_scatter,
|
19 |
+
reduce_scatter_raw,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
class ParallelLinearFunc(torch.autograd.Function):
|
24 |
+
@staticmethod
|
25 |
+
@custom_fwd
|
26 |
+
def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
|
27 |
+
"""
|
28 |
+
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
|
29 |
+
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
|
30 |
+
"""
|
31 |
+
ctx.compute_weight_gradient = weight.requires_grad
|
32 |
+
ctx.process_group = process_group
|
33 |
+
ctx.sequence_parallel = sequence_parallel
|
34 |
+
|
35 |
+
if torch.is_autocast_enabled():
|
36 |
+
x = x.to(dtype=torch.get_autocast_gpu_dtype())
|
37 |
+
x = x.contiguous()
|
38 |
+
if process_group is not None and sequence_parallel:
|
39 |
+
# We want to kick off the all_gather early, before weight dtype conversion
|
40 |
+
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
41 |
+
else:
|
42 |
+
total_x = x
|
43 |
+
|
44 |
+
if torch.is_autocast_enabled():
|
45 |
+
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
|
46 |
+
bias = (
|
47 |
+
bias.to(dtype=torch.get_autocast_gpu_dtype())
|
48 |
+
if bias is not None
|
49 |
+
else None
|
50 |
+
)
|
51 |
+
weight = weight.contiguous()
|
52 |
+
if process_group is not None and sequence_parallel:
|
53 |
+
handle_x.wait()
|
54 |
+
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
|
55 |
+
batch_dim = batch_shape.numel()
|
56 |
+
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
|
57 |
+
output = F.linear(total_x, weight, bias)
|
58 |
+
if ctx.compute_weight_gradient:
|
59 |
+
ctx.save_for_backward(x, weight)
|
60 |
+
else:
|
61 |
+
ctx.save_for_backward(weight)
|
62 |
+
return output
|
63 |
+
|
64 |
+
@staticmethod
|
65 |
+
@custom_bwd
|
66 |
+
def backward(ctx, grad_output):
|
67 |
+
grad_output = grad_output.contiguous()
|
68 |
+
process_group = ctx.process_group
|
69 |
+
sequence_parallel = ctx.sequence_parallel
|
70 |
+
if ctx.compute_weight_gradient:
|
71 |
+
x, weight = ctx.saved_tensors
|
72 |
+
if process_group is not None and sequence_parallel:
|
73 |
+
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
74 |
+
else:
|
75 |
+
total_x = x
|
76 |
+
else:
|
77 |
+
(weight,) = ctx.saved_tensors
|
78 |
+
total_x = None
|
79 |
+
batch_shape = grad_output.shape[:-1]
|
80 |
+
batch_dim = batch_shape.numel()
|
81 |
+
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
82 |
+
if ctx.needs_input_grad[0]:
|
83 |
+
grad_input = F.linear(grad_output, weight.t())
|
84 |
+
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
|
85 |
+
if process_group is not None:
|
86 |
+
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
|
87 |
+
grad_input, handle_grad_input = reduce_fn(
|
88 |
+
grad_input, process_group, async_op=True
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
grad_input = None
|
92 |
+
if ctx.needs_input_grad[1]:
|
93 |
+
assert ctx.compute_weight_gradient
|
94 |
+
if process_group is not None and sequence_parallel:
|
95 |
+
handle_x.wait()
|
96 |
+
grad_weight = torch.einsum(
|
97 |
+
"bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
grad_weight = None
|
101 |
+
grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
|
102 |
+
if process_group is not None and ctx.needs_input_grad[0]:
|
103 |
+
handle_grad_input.wait()
|
104 |
+
return grad_input, grad_weight, grad_bias, None, None
|
105 |
+
|
106 |
+
|
107 |
+
def parallel_linear_func(
|
108 |
+
x: Tensor,
|
109 |
+
weight: Tensor,
|
110 |
+
bias: Optional[Tensor] = None,
|
111 |
+
process_group: Optional[ProcessGroup] = None,
|
112 |
+
sequence_parallel: bool = True,
|
113 |
+
):
|
114 |
+
return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
|
115 |
+
|
116 |
+
|
117 |
+
class ColumnParallelLinear(nn.Linear):
|
118 |
+
def __init__(
|
119 |
+
self,
|
120 |
+
in_features: int,
|
121 |
+
out_features: int,
|
122 |
+
process_group: ProcessGroup,
|
123 |
+
bias: bool = True,
|
124 |
+
sequence_parallel=True,
|
125 |
+
multiple_of=1,
|
126 |
+
device=None,
|
127 |
+
dtype=None,
|
128 |
+
) -> None:
|
129 |
+
world_size = torch.distributed.get_world_size(process_group)
|
130 |
+
if out_features % multiple_of:
|
131 |
+
raise ValueError(
|
132 |
+
f"out_features ({out_features}) must be a multiple of {multiple_of}"
|
133 |
+
)
|
134 |
+
multiple = out_features // multiple_of
|
135 |
+
# We want to split @multiple across world_size, but it could be an uneven split
|
136 |
+
div = multiple // world_size
|
137 |
+
mod = multiple % world_size
|
138 |
+
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
139 |
+
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
140 |
+
super().__init__(
|
141 |
+
in_features,
|
142 |
+
local_multiple * multiple_of,
|
143 |
+
bias=bias,
|
144 |
+
device=device,
|
145 |
+
dtype=dtype,
|
146 |
+
)
|
147 |
+
self.process_group = process_group
|
148 |
+
self.sequence_parallel = sequence_parallel
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
152 |
+
# we do an all_gather of x before doing the matmul.
|
153 |
+
# If not, then the input is already gathered.
|
154 |
+
return parallel_linear_func(
|
155 |
+
x,
|
156 |
+
self.weight,
|
157 |
+
self.bias,
|
158 |
+
process_group=self.process_group,
|
159 |
+
sequence_parallel=self.sequence_parallel,
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
class RowParallelLinear(nn.Linear):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
in_features: int,
|
167 |
+
out_features: int,
|
168 |
+
process_group: ProcessGroup,
|
169 |
+
bias: bool = True,
|
170 |
+
sequence_parallel=True,
|
171 |
+
multiple_of=1,
|
172 |
+
device=None,
|
173 |
+
dtype=None,
|
174 |
+
) -> None:
|
175 |
+
world_size = torch.distributed.get_world_size(process_group)
|
176 |
+
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(
                f"in_features ({in_features}) must be a multiple of {multiple_of}"
            )
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class VocabParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings


class ColumnParallelEmbedding(nn.Embedding):
    def __init__(
        self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
    ):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings,
                embed_dim,
                process_group=process_group,
                **factory_kwargs,
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(
                    seqlen, dtype=torch.long, device=input_ids.device
                )
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return (
            embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
        )
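A minimal single-process sketch (not part of the commit) of the vocab-partition masking that VocabParallelEmbedding.forward performs. The rank, shard size, and token ids below are made-up values for illustration; in the real module an all_reduce / reduce_scatter across ranks restores the masked rows.

import torch

# Hypothetical 2-way vocab parallelism: each rank owns a shard of 8 ids; this is rank 1.
rank, shard_vocab = 1, 8
vocab_start, vocab_end = rank * shard_vocab, (rank + 1) * shard_vocab

input_ids = torch.tensor([[3, 9, 12, 15]])
# Tokens outside [vocab_start, vocab_end) live in another rank's shard and get zeroed here.
mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
local_ids = (input_ids - vocab_start).masked_fill(mask, 0)

embedding = torch.nn.Embedding(shard_vocab, 4)
out = embedding(local_ids)
out[mask] = 0.0
print(out.shape)  # torch.Size([1, 4, 4]); rows for ids 3 are zero on this rank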
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/config_mamba.py
ADDED
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field


@dataclass
class MambaConfig:

    d_model: int = 2560
    d_intermediate: int = 0
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    pad_vocab_size_multiple: int = 8
    tie_embeddings: bool = True
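For reference (not part of the commit), the default vocab_size and pad_vocab_size_multiple above imply the padded vocabulary that MambaLMHeadModel (in mixer_seq_simple.py below) actually builds:

# Pure-Python sketch of the rounding done in MambaLMHeadModel.__init__
vocab_size, pad_vocab_size_multiple = 50277, 8
if vocab_size % pad_vocab_size_multiple != 0:
    vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
print(vocab_size)  # 50280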
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/models/mixer_seq_simple.py
ADDED
@@ -0,0 +1,338 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os
import copy

from collections import namedtuple

import torch
import torch.nn as nn

from .config_mamba import MambaConfig
from ..modules.mamba_simple import Mamba
from ..modules.mamba2 import Mamba2
from ..modules.mha import MHA
from ..modules.mlp import GatedMLP
from ..modules.block import Block
from ..utils.generation import GenerationMixin
from ..utils.hf import load_config_hf, load_state_dict_hf

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(
                f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
            )
        mixer_cls = partial(
            Mamba2 if ssm_layer == "Mamba2" else Mamba,
            layer_idx=layer_idx,
            **ssm_cfg,
            **factory_kwargs,
        )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity
    else:
        mlp_cls = partial(
            GatedMLP,
            hidden_features=d_intermediate,
            out_features=d_model,
            **factory_kwargs,
        )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        d_intermediate: int,
        vocab_size: int,
        ssm_cfg=None,
        attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    d_intermediate=d_intermediate,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
                n_residuals_per_layer=(
                    1 if d_intermediate == 0 else 2
                ),  # 2 if we have MLP
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states,
                residual,
                inference_params=inference_params,
                **mixer_kwargs,
            )
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        d_intermediate = config.d_intermediate
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        attn_layer_idx = config.attn_layer_idx
        attn_cfg = config.attn_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            d_intermediate=d_intermediate,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            attn_layer_idx=attn_layer_idx,
            attn_cfg=attn_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        if self.config.tie_embeddings:
            self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        inference_params=None,
        num_last_tokens=0,
        **mixer_kwargs,
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(
            input_ids, inference_params=inference_params, **mixer_kwargs
        )
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=4)
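A hedged usage sketch, not part of the commit: it assumes the mamba_ssm package built here is installed together with its Triton/CUDA kernels and that a CUDA device is available, since the fused layer-norm and scan paths above require one. The small d_model/n_layer values are arbitrary.

import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(d_model=256, n_layer=4, vocab_size=1000)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)

input_ids = torch.randint(0, 1000, (2, 64), device="cuda")
logits = model(input_ids).logits  # (batch, seqlen, vocab_size) == (2, 64, 1000)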
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/__init__.py
ADDED
File without changes
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/block.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
from typing import Optional

import torch
from torch import nn, Tensor

from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn


class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
        **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        if not self.fused_add_norm:
            residual = (
                (hidden_states + residual) if residual is not None else hidden_states
            )
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm),
            )
        hidden_states = self.mixer(
            hidden_states, inference_params=inference_params, **mixer_kwargs
        )

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm),
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )
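A minimal CPU sketch (not from the repository) of the non-fused residual bookkeeping in Block.forward: the block returns (mixer_output, residual), and the next block performs the Add before its own norm. The nn.Linear stand-in for the mixer is hypothetical and only there to keep the example runnable.

import torch
import torch.nn as nn

dim = 8
norm, mixer = nn.LayerNorm(dim), nn.Linear(dim, dim)  # stand-in mixer

def block_forward(hidden_states, residual=None):
    # Add -> LN -> Mixer, returning both the mixer output and the updated residual branch
    residual = hidden_states + residual if residual is not None else hidden_states
    hidden_states = norm(residual)
    return mixer(hidden_states), residual

h = torch.randn(2, 5, dim)
h, res = block_forward(h)        # first block: residual is None
h, res = block_forward(h, res)   # later blocks reuse the returned residual
print(h.shape, res.shape)        # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])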
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba2.py
ADDED
@@ -0,0 +1,502 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
except ImportError:
    causal_conv1d_varlen_states = None

try:
    from ..ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated

from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from ..distributed.distributed_utils import all_reduce, reduce_scatter

from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

from huggingface_hub import PyTorchModelHubMixin


class Mamba2(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        d_model,
        d_state=128,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=64,
        d_ssm=None,  # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
        ngroups=1,
        A_init_range=(1, 16),
        D_has_hdim=False,
        rmsnorm=True,
        norm_before_gate=False,
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        process_group=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.world_size = 1 if process_group is None else process_group.size()
        self.local_rank = 0 if process_group is None else process_group.rank()
        self.d_inner = (self.expand * self.d_model) // self.world_size
        assert self.d_inner * self.world_size == self.expand * self.d_model
        self.headdim = headdim
        self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
        assert ngroups % self.world_size == 0
        self.ngroups = ngroups // self.world_size
        assert self.d_ssm % self.headdim == 0
        self.nheads = self.d_ssm // self.headdim
        self.D_has_hdim = D_has_hdim
        self.rmsnorm = rmsnorm
        self.norm_before_gate = norm_before_gate
        self.dt_limit = dt_limit
        self.activation = "silu"
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        if self.process_group is None:
            self.in_proj = nn.Linear(
                self.d_model, d_in_proj, bias=bias, **factory_kwargs
            )
        else:
            self.in_proj = ColumnParallelLinear(
                self.d_model,
                d_in_proj * self.world_size,
                bias=bias,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

        conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
            *A_init_range
        )
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(
            torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
        )
        self.D._no_weight_decay = True

        if self.rmsnorm:
            assert RMSNormGated is not None
            self.norm = RMSNormGated(
                self.d_ssm,
                eps=1e-5,
                norm_before_gate=self.norm_before_gate,
                group_size=self.d_ssm // ngroups,
                **factory_kwargs,
            )

        if self.process_group is None:
            self.out_proj = nn.Linear(
                self.d_inner, self.d_model, bias=bias, **factory_kwargs
            )
        else:
            self.out_proj = RowParallelLinear(
                self.d_inner * self.world_size,
                self.d_model,
                bias=bias,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
                **factory_kwargs,
            )

    def forward(
        self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
    ):
        """
        u: (batch, seqlen, hidden_dim) if seqlen=None.
            If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
            split u during sequence parallel, we split the batch * seqlen dimension
            (in case batch is small).
        Returns: same shape as u
        """
        seqlen_og = seqlen
        if seqlen is None:
            batch, seqlen, dim = u.shape
        else:
            batch_seqlen, dim = u.shape
            batch = batch_seqlen // seqlen

        conv_state, ssm_state = None, None
        if inference_params is not None:
            inference_batch = (
                cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
            )
            conv_state, ssm_state = self._get_states_from_cache(
                inference_params, inference_batch
            )
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(u, conv_state, ssm_state)
                return out

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj) or (B * L, d_in_proj)
        if seqlen_og is not None:
            zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)
        dt_limit_kwargs = (
            {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
        )
        if self.use_mem_eff_path and inference_params is None:
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=(
                    rearrange(self.D, "(h p) -> h p", p=self.headdim)
                    if self.D_has_hdim
                    else self.D
                ),
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
                rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=None if self.D_has_hdim else self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=self.norm_before_gate,
                **dt_limit_kwargs,
            )
            if seqlen_og is not None:
                out = rearrange(out, "b l d -> (b l) d")
            if self.process_group is not None:
                reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
                out = reduce_fn(out, self.process_group)
        else:
            d_mlp = (
                zxbcdt.shape[-1]
                - 2 * self.d_ssm
                - 2 * self.ngroups * self.d_state
                - self.nheads
            ) // 2
            z0, x0, z, xBC, dt = torch.split(
                zxbcdt,
                [
                    d_mlp,
                    d_mlp,
                    self.d_ssm,
                    self.d_ssm + 2 * self.ngroups * self.d_state,
                    self.nheads,
                ],
                dim=-1,
            )
            if conv_state is not None:
                if cu_seqlens is None:
                    # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                    xBC_t = rearrange(xBC, "b l d -> b d l")
                    conv_state.copy_(
                        F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
                    )  # Update state (B D W)
                else:
                    assert (
                        causal_conv1d_varlen_states is not None
                    ), "varlen inference requires causal_conv1d package"
                    assert (
                        batch == 1
                    ), "varlen inference only supports batch dimension 1"
                    conv_varlen_states = causal_conv1d_varlen_states(
                        xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
                    )
                    conv_state.copy_(conv_varlen_states)
            assert self.activation in ["silu", "swish"]
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                assert (
                    seq_idx is None
                ), "varlen conv1d requires the causal_conv1d package"
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
                        :, : -(self.d_conv - 1)
                    ]
                )  # (B, L, self.d_ssm + 2 * ngroups * d_state)
            else:
                xBC = causal_conv1d_fn(
                    xBC.transpose(1, 2),
                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=seq_idx,
                ).transpose(1, 2)
            x, B, C = torch.split(
                xBC,
                [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
                dim=-1,
            )
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=(
                    rearrange(self.D, "(h p) -> h p", p=self.headdim)
                    if self.D_has_hdim
                    else self.D
                ),
                z=(
                    rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
                    if not self.rmsnorm
                    else None
                ),
                dt_bias=self.dt_bias,
                dt_softplus=True,
                seq_idx=seq_idx,
                cu_seqlens=cu_seqlens,
                **dt_limit_kwargs,
                return_final_states=ssm_state is not None,
                return_varlen_states=cu_seqlens is not None
                and inference_params is not None,
            )
            if ssm_state is not None:
                y, last_state, *rest = y
                if cu_seqlens is None:
                    ssm_state.copy_(last_state)
                else:
                    varlen_states = rest[0]
                    ssm_state.copy_(varlen_states)
            y = rearrange(y, "b l h p -> b l (h p)")
            if self.rmsnorm:
                y = self.norm(y, z)
            if d_mlp > 0:
                y = torch.cat([F.silu(z0) * x0, y], dim=-1)
            if seqlen_og is not None:
                y = rearrange(y, "b l d -> (b l) d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert (
            hidden_states.shape[1] == 1
        ), "Only support decoding with 1 token at a time for now"
        zxbcdt = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        d_mlp = (
            zxbcdt.shape[-1]
            - 2 * self.d_ssm
            - 2 * self.ngroups * self.d_state
            - self.nheads
        ) // 2
        z0, x0, z, xBC, dt = torch.split(
            zxbcdt,
            [
                d_mlp,
                d_mlp,
                self.d_ssm,
                self.d_ssm + 2 * self.ngroups * self.d_state,
                self.nheads,
            ],
            dim=-1,
        )

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(
                torch.roll(conv_state, shifts=-1, dims=-1)
            )  # Update state (B D W)
            conv_state[:, :, -1] = xBC
            xBC = torch.sum(
                conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
            )  # (B D)
            if self.conv1d.bias is not None:
                xBC = xBC + self.conv1d.bias
            xBC = self.act(xBC).to(dtype=dtype)
        else:
            xBC = causal_conv1d_update(
                xBC,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x, B, C = torch.split(
            xBC,
            [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
            dim=-1,
        )
        A = -torch.exp(self.A_log.float())  # (nheads,)

        # SSM step
        if selective_state_update is None:
            assert (
                self.ngroups == 1
            ), "Only support ngroups=1 for this inference code path"
            # Discretize A and B
            dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)
            dA = torch.exp(dt * A)  # (batch, nheads)
            x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
            dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
            ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
            y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
            y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
            y = rearrange(y, "b h p -> b (h p)")
            if not self.rmsnorm:
                y = y * self.act(z)  # (B D)
        else:
            A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
                dtype=torch.float32
            )
            dt = repeat(dt, "b h -> b h p", p=self.headdim)
            dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
            D = repeat(self.D, "h -> h p", p=self.headdim)
            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
            x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
            if not self.rmsnorm:
                z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
            y = selective_state_update(
                ssm_state,
                x_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=z if not self.rmsnorm else None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            y = rearrange(y, "b h p -> b (h p)")
        if self.rmsnorm:
            y = self.norm(y, z)
        if d_mlp > 0:
            y = torch.cat([F.silu(z0) * x0, y], dim=-1)
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size,
            self.d_conv,
            self.conv1d.weight.shape[0],
            device=device,
            dtype=conv_dtype,
        ).transpose(1, 2)
        ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
        ssm_state = torch.zeros(
            batch_size,
            self.nheads,
            self.headdim,
            self.d_state,
            device=device,
            dtype=ssm_dtype,
        )
        return conv_state, ssm_state

    def _get_states_from_cache(
        self, inference_params, batch_size, initialize_states=False
    ):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_conv,
                self.conv1d.weight.shape[0],
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            ).transpose(1, 2)
            ssm_state = torch.zeros(
                batch_size,
                self.nheads,
                self.headdim,
                self.d_state,
                device=self.in_proj.weight.device,
                dtype=self.in_proj.weight.dtype,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (
                conv_state,
                ssm_state,
            )
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[
                self.layer_idx
            ]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
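A worked example (not part of the commit) of the projection widths implied by the "# Order: [z, x, B, C, dt]" comment above, using the constructor defaults on a single GPU (world_size = 1, d_ssm = d_inner):

d_model, expand, headdim, d_state, ngroups = 2560, 2, 64, 128, 1
d_inner = expand * d_model                                  # 5120
nheads = d_inner // headdim                                 # 80
d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads    # z and x (2*d_inner) + B,C + dt
conv_dim = d_inner + 2 * ngroups * d_state                  # channels routed through the causal conv
print(d_in_proj, conv_dim)  # 10576 5376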
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba2_simple.py
ADDED
@@ -0,0 +1,229 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None

try:
    from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
except ImportError:
    RMSNormGated, LayerNorm = None, None

from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined


class Mamba2Simple(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=64,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=128,
        ngroups=1,
        A_init_range=(1, 16),
        dt_min=0.001,
        dt_max=0.1,
        dt_init_floor=1e-4,
        dt_limit=(0.0, float("inf")),
        learnable_init_states=False,
        activation="swish",
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=256,
        use_mem_eff_path=True,
        layer_idx=None,  # Absorb kwarg for general module
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.dt_limit = dt_limit
        self.learnable_init_states = learnable_init_states
        self.activation = activation
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_idx = layer_idx

        # Order: [z, x, B, C, dt]
        d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)

        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_dim,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        if self.conv_init is not None:
            nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
        # self.conv1d.weight._no_weight_decay = True

        if self.learnable_init_states:
            self.init_states = nn.Parameter(
                torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
            )
            self.init_states._no_weight_decay = True

        self.act = nn.SiLU()

        # Initialize log dt bias
        dt = torch.exp(
            torch.rand(self.nheads, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        # A parameter
        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
        A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
            *A_init_range
        )
        A_log = torch.log(A).to(dtype=dtype)
        self.A_log = nn.Parameter(A_log)
        # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.nheads, device=device))
        self.D._no_weight_decay = True

        # Extra normalization layer right before output projection
        assert RMSNormGated is not None
        self.norm = RMSNormGated(
            self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
        )

        self.out_proj = nn.Linear(
            self.d_inner, self.d_model, bias=bias, **factory_kwargs
        )

    def forward(self, u, seq_idx=None):
        """
        u: (B, L, D)
        Returns: same shape as u
        """
        batch, seqlen, dim = u.shape

        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj)
        A = -torch.exp(self.A_log)  # (nheads) or (d_inner, d_state)
        initial_states = (
            repeat(self.init_states, "... -> b ...", b=batch)
            if self.learnable_init_states
            else None
        )
        dt_limit_kwargs = (
            {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
        )

        if self.use_mem_eff_path:
            # Fully fused path
            out = mamba_split_conv1d_scan_combined(
                zxbcdt,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                seq_idx=seq_idx,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.eps,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.headdim,
                ngroups=self.ngroups,
                norm_before_gate=False,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
        else:
            z, xBC, dt = torch.split(
                zxbcdt,
                [
                    self.d_inner,
                    self.d_inner + 2 * self.ngroups * self.d_state,
                    self.nheads,
                ],
                dim=-1,
            )
            dt = F.softplus(dt + self.dt_bias)  # (B, L, nheads)
            assert self.activation in ["silu", "swish"]

            # 1D Convolution
            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                xBC = self.act(
                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
                )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                xBC = xBC[:, :seqlen, :]
            else:
                xBC = causal_conv1d_fn(
                    x=xBC.transpose(1, 2),
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)

            # Split into 3 main branches: X, B, C
            # These correspond to V, K, Q respectively in the SSM/attention duality
            x, B, C = torch.split(
                xBC,
                [
                    self.d_inner,
                    self.ngroups * self.d_state,
                    self.ngroups * self.d_state,
                ],
                dim=-1,
            )
            y = mamba_chunk_scan_combined(
                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
                dt,
                A,
                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                seq_idx=seq_idx,
                initial_states=initial_states,
                **dt_limit_kwargs,
            )
            y = rearrange(y, "b l h p -> b l (h p)")

            # Multiply "gate" branch and apply extra normalization layer
            y = self.norm(y, z)
            out = self.out_proj(y)
        return out
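A small check (not from the repository) that the inverse-softplus initialization used above does what its comment says: applying softplus to dt_bias recovers a dt sampled in [dt_min, dt_max].

import math
import torch
import torch.nn.functional as F

dt_min, dt_max = 0.001, 0.1
dt = torch.exp(torch.rand(16) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
inv_dt = dt + torch.log(-torch.expm1(-dt))   # softplus^{-1}(dt), same formula as the init
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)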
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mamba_simple.py
ADDED
@@ -0,0 +1,339 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from ..ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class Mamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        self.in_proj = nn.Linear(
            self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
        )

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner, bias=True, **factory_kwargs
        )

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(
            self.d_inner, self.d_model, bias=bias, **factory_kwargs
        )

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if (
            self.use_fast_path
            and causal_conv1d_fn is not None
            and inference_params is None
        ):  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(
                    F.pad(x, (self.d_conv - x.shape[-1], 0))
                )  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(
                x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
            )
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
|
215 |
+
)
|
216 |
+
if ssm_state is not None:
|
217 |
+
y, last_state = y
|
218 |
+
ssm_state.copy_(last_state)
|
219 |
+
y = rearrange(y, "b d l -> b l d")
|
220 |
+
out = self.out_proj(y)
|
221 |
+
return out
|
222 |
+
|
223 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
224 |
+
dtype = hidden_states.dtype
|
225 |
+
assert (
|
226 |
+
hidden_states.shape[1] == 1
|
227 |
+
), "Only support decoding with 1 token at a time for now"
|
228 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
229 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
230 |
+
|
231 |
+
# Conv step
|
232 |
+
if causal_conv1d_update is None:
|
233 |
+
conv_state.copy_(
|
234 |
+
torch.roll(conv_state, shifts=-1, dims=-1)
|
235 |
+
) # Update state (B D W)
|
236 |
+
conv_state[:, :, -1] = x
|
237 |
+
x = torch.sum(
|
238 |
+
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
239 |
+
) # (B D)
|
240 |
+
if self.conv1d.bias is not None:
|
241 |
+
x = x + self.conv1d.bias
|
242 |
+
x = self.act(x).to(dtype=dtype)
|
243 |
+
else:
|
244 |
+
x = causal_conv1d_update(
|
245 |
+
x,
|
246 |
+
conv_state,
|
247 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
248 |
+
self.conv1d.bias,
|
249 |
+
self.activation,
|
250 |
+
)
|
251 |
+
|
252 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
253 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
254 |
+
# Don't add dt_bias here
|
255 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
256 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
257 |
+
|
258 |
+
# SSM step
|
259 |
+
if selective_state_update is None:
|
260 |
+
# Discretize A and B
|
261 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
262 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
263 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
264 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
265 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
266 |
+
y = y + self.D.to(dtype) * x
|
267 |
+
y = y * self.act(z) # (B D)
|
268 |
+
else:
|
269 |
+
y = selective_state_update(
|
270 |
+
ssm_state,
|
271 |
+
x,
|
272 |
+
dt,
|
273 |
+
A,
|
274 |
+
B,
|
275 |
+
C,
|
276 |
+
self.D,
|
277 |
+
z=z,
|
278 |
+
dt_bias=self.dt_proj.bias,
|
279 |
+
dt_softplus=True,
|
280 |
+
)
|
281 |
+
|
282 |
+
out = self.out_proj(y)
|
283 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
284 |
+
|
285 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
286 |
+
device = self.out_proj.weight.device
|
287 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
288 |
+
conv_state = torch.zeros(
|
289 |
+
batch_size,
|
290 |
+
self.d_model * self.expand,
|
291 |
+
self.d_conv,
|
292 |
+
device=device,
|
293 |
+
dtype=conv_dtype,
|
294 |
+
)
|
295 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
296 |
+
# ssm_dtype = torch.float32
|
297 |
+
ssm_state = torch.zeros(
|
298 |
+
batch_size,
|
299 |
+
self.d_model * self.expand,
|
300 |
+
self.d_state,
|
301 |
+
device=device,
|
302 |
+
dtype=ssm_dtype,
|
303 |
+
)
|
304 |
+
return conv_state, ssm_state
|
305 |
+
|
306 |
+
def _get_states_from_cache(
|
307 |
+
self, inference_params, batch_size, initialize_states=False
|
308 |
+
):
|
309 |
+
assert self.layer_idx is not None
|
310 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
311 |
+
batch_shape = (batch_size,)
|
312 |
+
conv_state = torch.zeros(
|
313 |
+
batch_size,
|
314 |
+
self.d_model * self.expand,
|
315 |
+
self.d_conv,
|
316 |
+
device=self.conv1d.weight.device,
|
317 |
+
dtype=self.conv1d.weight.dtype,
|
318 |
+
)
|
319 |
+
ssm_state = torch.zeros(
|
320 |
+
batch_size,
|
321 |
+
self.d_model * self.expand,
|
322 |
+
self.d_state,
|
323 |
+
device=self.dt_proj.weight.device,
|
324 |
+
dtype=self.dt_proj.weight.dtype,
|
325 |
+
# dtype=torch.float32,
|
326 |
+
)
|
327 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (
|
328 |
+
conv_state,
|
329 |
+
ssm_state,
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[
|
333 |
+
self.layer_idx
|
334 |
+
]
|
335 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
336 |
+
if initialize_states:
|
337 |
+
conv_state.zero_()
|
338 |
+
ssm_state.zero_()
|
339 |
+
return conv_state, ssm_state
|
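A minimal usage sketch of the Mamba block above (not part of the packaged files; the sizes are hypothetical, and a CUDA device is assumed because the selective-scan kernels are GPU-only):

import torch
from mamba_ssm.modules.mamba_simple import Mamba

batch, seqlen, d_model = 2, 64, 256          # hypothetical sizes for illustration
x = torch.randn(batch, seqlen, d_model, device="cuda")

# layer_idx is only needed when decoding with an inference cache.
block = Mamba(d_model=d_model, d_state=16, d_conv=4, expand=2, layer_idx=0).to("cuda")
y = block(x)                                 # forward over the full sequence
assert y.shape == x.shape                    # output has the same (B, L, D) shape as the input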
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mha.py
ADDED
@@ -0,0 +1,294 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    from flash_attn import flash_attn_with_kvcache
except ImportError:
    flash_attn_with_kvcache = None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    assert layer_idx in inference_params.key_value_memory_dict
    kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        head_dim=None,  # If None, use embed_dim // num_heads
        mlp_dim=0,
        qkv_proj_bias=True,
        out_proj_bias=True,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        d_conv=0,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_interleaved=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.layer_idx = layer_idx
        self.d_conv = d_conv
        self.rotary_emb_dim = rotary_emb_dim
        self.softmax_scale = softmax_scale
        self.causal = causal

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        if head_dim is None:
            assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
        self.mlp_dim = math.ceil(mlp_dim / 256) * 256
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        out_dim = self.head_dim * self.num_heads

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )

        self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.d_conv > 0:
            self.conv1d = nn.Conv1d(
                qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
                **factory_kwargs
            )
        self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        if self.d_conv > 0:
            conv_state = torch.zeros(
                batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
            )
        else:
            conv_state = None
        kv_cache = torch.empty(
            batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
        )
        return kv_cache, conv_state

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        if self.rotary_emb_dim > 0:
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
        kv_cache = kv_cache[:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.softmax_scale,
            causal=self.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            k, v = kv.unbind(dim=-3)
            k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
            v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
            return F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
            ).transpose(1, 2)
        else:
            batch = q.shape[0]
            kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
            kv_cache = kv_cache[:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.softmax_scale,
                causal=self.causal,
            )

    def forward(self, x, inference_params=None):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
            inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
                x.shape[0], inference_params.max_seqlen, dtype=x.dtype
            )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        qkv = self.in_proj(x)
        if self.mlp_dim > 0:
            qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
            x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
            x_mlp = x_mlp_up * F.silu(x_mlp_gate)
        if self.d_conv > 0:
            # The inference code for conv1d is pretty messy, should clean it up
            if (inference_params is None or inference_params.seqlen_offset == 0):
                if causal_conv1d_fn is None:
                    qkv = rearrange(
                        self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
                    ).contiguous()
                else:
                    qkv = causal_conv1d_fn(
                        qkv.transpose(1, 2),
                        rearrange(self.conv1d.weight, "d 1 w -> d w"),
                        self.conv1d.bias
                    ).transpose(1, 2)
                if inference_params is not None:
                    _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
                    # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                    qkv_t = rearrange(qkv, "b l d -> b d l")
                    conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0)))  # Update state (B D W)
            else:
                _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
                assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
                qkv = qkv.squeeze(1)
                # Conv step
                if causal_conv1d_update is None:
                    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
                    conv_state[:, :, -1] = qkv
                    qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
                    if self.conv1d.bias is not None:
                        qkv = qkv + self.conv1d.bias
                else:
                    qkv = causal_conv1d_update(
                        qkv,
                        conv_state,
                        rearrange(self.conv1d.weight, "d 1 w -> d w"),
                        self.conv1d.bias
                    )
                qkv = qkv.unsqueeze(1)
        q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
        q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
        kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
        if (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
        ):
            if self.rotary_emb_dim > 0:
                q, kv = self.rotary_emb(
                    q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                )
            if inference_params is None:
                k, v = kv.unbind(dim=-3)
                k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
                v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
                context = F.scaled_dot_product_attention(
                    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
                ).transpose(1, 2)
            else:
                context = self._update_kvcache_attention(q, kv, inference_params)
        else:
            context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "... h d -> ... (h d)")
        if self.mlp_dim > 0:
            context = torch.cat([context, x_mlp], dim=-1)
        out = self.out_proj(context)
        return out
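A small self-attention sketch for the MHA module above (illustrative only; the sizes are made up, and since the flash-attn kv-cache path is only used during cached decoding, this full-sequence forward goes through F.scaled_dot_product_attention and also runs on CPU):

import torch
from mamba_ssm.modules.mha import MHA

batch, seqlen, embed_dim, num_heads = 2, 32, 512, 8   # hypothetical sizes
x = torch.randn(batch, seqlen, embed_dim)

# causal=True applies decoder-style masking; setting num_heads_kv < num_heads would switch to GQA/MQA.
attn = MHA(embed_dim=embed_dim, num_heads=num_heads, causal=True, layer_idx=0)
out = attn(x)                                          # (batch, seqlen, embed_dim)
assert out.shape == x.shape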
build/torch25-cxx11-cu121-x86_64-linux/mamba_ssm/modules/mlp.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F


class GatedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y, gate = y.chunk(2, dim=-1)
        y = y * self.activation(gate)
        y = self.fc2(y)
        return y
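A quick sketch of the GatedMLP above, showing how the hidden width is derived (example sizes are hypothetical):

import torch
from mamba_ssm.modules.mlp import GatedMLP

mlp = GatedMLP(in_features=256)   # hidden_features defaults to int(8 * 256 / 3) = 682, rounded up to 768 (a multiple of 128)
x = torch.randn(4, 16, 256)
y = mlp(x)                        # fc2(up * silu(gate)), with up and gate split from a single fc1 output
assert y.shape == x.shape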