drbh commited on
Commit
a7165c8
·
1 Parent(s): bdaa00d

feat: include source and enable build

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +1 -0
  2. build.toml +126 -0
  3. flake.lock +117 -0
  4. flake.nix +17 -0
  5. flash_attn/flash_api.cpp +1496 -0
  6. flash_attn/src/alibi.h +75 -0
  7. flash_attn/src/block_info.h +49 -0
  8. flash_attn/src/dropout.h +95 -0
  9. flash_attn/src/flash.h +194 -0
  10. flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu +14 -0
  11. flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu +14 -0
  12. flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu +14 -0
  13. flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu +14 -0
  14. flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +14 -0
  15. flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +14 -0
  16. flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +14 -0
  17. flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +14 -0
  18. flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu +14 -0
  19. flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu +14 -0
  20. flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu +14 -0
  21. flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu +14 -0
  22. flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu +14 -0
  23. flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu +14 -0
  24. flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu +14 -0
  25. flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu +14 -0
  26. flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu +14 -0
  27. flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu +14 -0
  28. flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu +14 -0
  29. flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu +14 -0
  30. flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu +14 -0
  31. flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu +14 -0
  32. flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu +14 -0
  33. flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu +14 -0
  34. flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu +14 -0
  35. flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu +14 -0
  36. flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu +14 -0
  37. flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu +14 -0
  38. flash_attn/src/flash_bwd_kernel.h +839 -0
  39. flash_attn/src/flash_bwd_launch_template.h +328 -0
  40. flash_attn/src/flash_bwd_preprocess_kernel.h +379 -0
  41. flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu +14 -0
  42. flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu +14 -0
  43. flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu +14 -0
  44. flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu +14 -0
  45. flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu +14 -0
  46. flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +14 -0
  47. flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu +14 -0
  48. flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +14 -0
  49. flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu +14 -0
  50. flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu +14 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .bak
build.toml ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [general]
2
+ name = "flash_attn"
3
+
4
+ [torch]
5
+ src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
6
+
7
+ [kernel.flash_attn]
8
+ cuda-capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
9
+ src = [
10
+ "flash_attn/flash_api.cpp",
11
+ "flash_attn/src/philox_unpack.cuh",
12
+ "flash_attn/src/namespace_config.h",
13
+ "flash_attn/src/hardware_info.h",
14
+ "flash_attn/src/flash.h",
15
+ "flash_attn/src/static_switch.h",
16
+ #
17
+ "flash_attn/src/alibi.h",
18
+ "flash_attn/src/block_info.h",
19
+ "flash_attn/src/dropout.h",
20
+
21
+ # TODO: dont skip bwd kernels
22
+
23
+ # "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
24
+ # "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
25
+ # "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
26
+ # "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
27
+ # "flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
28
+ # "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
29
+ # "flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
30
+ # "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
31
+ # "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
32
+ # "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
33
+ # "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
34
+ # "flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
35
+ # "flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
36
+ # "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
37
+ # "flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
38
+ # "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
39
+ # "flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
40
+ # "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
41
+ # "flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
42
+ # "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
43
+ # "flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
44
+ # "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
45
+ # "flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
46
+ # "flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
47
+ # "flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
48
+ # "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
49
+ # "flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
50
+ # "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
51
+ # "flash_attn/src/flash_bwd_kernel.h",
52
+ # "flash_attn/src/flash_bwd_launch_template.h",
53
+ # "flash_attn/src/flash_bwd_preprocess_kernel.h",
54
+
55
+ "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
56
+ "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
57
+ "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
58
+ "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
59
+ "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
60
+ "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
61
+ "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
62
+ "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
63
+ "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
64
+ "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
65
+ "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
66
+ "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
67
+ "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
68
+ "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
69
+ "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
70
+ "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
71
+ "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
72
+ "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
73
+ "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
74
+ "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
75
+ "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
76
+ "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
77
+ "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
78
+ "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
79
+ "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
80
+ "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
81
+ "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
82
+ "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
83
+ "flash_attn/src/flash_fwd_kernel.h",
84
+ "flash_attn/src/flash_fwd_launch_template.h",
85
+ "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
86
+ "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
87
+ "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
88
+ "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
89
+ "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
90
+ "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
91
+ "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
92
+ "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
93
+ "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
94
+ "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
95
+ "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
96
+ "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
97
+ "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
98
+ "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
99
+ "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
100
+ "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
101
+ "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
102
+ "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
103
+ "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
104
+ "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
105
+ "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
106
+ "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
107
+ "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
108
+ "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
109
+ "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
110
+ "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
111
+ "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
112
+ "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
113
+ "flash_attn/src/flash.h",
114
+ "flash_attn/src/generate_kernels.py",
115
+ "flash_attn/src/hardware_info.h",
116
+ "flash_attn/src/kernel_traits.h",
117
+ "flash_attn/src/mask.h",
118
+ "flash_attn/src/namespace_config.h",
119
+ "flash_attn/src/philox.cuh",
120
+ "flash_attn/src/philox_unpack.cuh",
121
+ "flash_attn/src/rotary.h",
122
+ "flash_attn/src/softmax.h",
123
+ "flash_attn/src/static_switch.h",
124
+ "flash_attn/src/utils.h",
125
+ ]
126
+ depends = ["torch", "cutlass_3_6"]
flake.lock ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1733328505,
6
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs",
41
+ "rocm-nix": "rocm-nix"
42
+ },
43
+ "locked": {
44
+ "lastModified": 1742582705,
45
+ "narHash": "sha256-1Vq5IauC/8fjBqcnMbDzckLN/XLIGwWr3/c2Wt3I2vs=",
46
+ "ref": "refs/heads/main",
47
+ "rev": "e06e3e72947fad8bfd2c1eb5d8e7f5ec01d359d6",
48
+ "revCount": 103,
49
+ "type": "git",
50
+ "url": "ssh://[email protected]/huggingface/kernel-builder"
51
+ },
52
+ "original": {
53
+ "type": "git",
54
+ "url": "ssh://[email protected]/huggingface/kernel-builder"
55
+ }
56
+ },
57
+ "nixpkgs": {
58
+ "locked": {
59
+ "lastModified": 1740557110,
60
+ "narHash": "sha256-D2waFyJkaepTchTrGVAIfCd/YP+37bgXWg9cXwuxuT0=",
61
+ "owner": "nixos",
62
+ "repo": "nixpkgs",
63
+ "rev": "b89a821293c3872992137114d0db9a791243a41b",
64
+ "type": "github"
65
+ },
66
+ "original": {
67
+ "owner": "nixos",
68
+ "ref": "nixos-unstable-small",
69
+ "repo": "nixpkgs",
70
+ "type": "github"
71
+ }
72
+ },
73
+ "rocm-nix": {
74
+ "inputs": {
75
+ "nixpkgs": [
76
+ "kernel-builder",
77
+ "nixpkgs"
78
+ ]
79
+ },
80
+ "locked": {
81
+ "lastModified": 1742285724,
82
+ "narHash": "sha256-2QQn9fzmF/SKW082kXpSrEBgfmwKO2RNT5R91Fn/K4M=",
83
+ "owner": "huggingface",
84
+ "repo": "rocm-nix",
85
+ "rev": "a90de1c2e5698b2f4fe984b5f0faf052f466be49",
86
+ "type": "github"
87
+ },
88
+ "original": {
89
+ "owner": "huggingface",
90
+ "repo": "rocm-nix",
91
+ "type": "github"
92
+ }
93
+ },
94
+ "root": {
95
+ "inputs": {
96
+ "kernel-builder": "kernel-builder"
97
+ }
98
+ },
99
+ "systems": {
100
+ "locked": {
101
+ "lastModified": 1681028828,
102
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
103
+ "owner": "nix-systems",
104
+ "repo": "default",
105
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "nix-systems",
110
+ "repo": "default",
111
+ "type": "github"
112
+ }
113
+ }
114
+ },
115
+ "root": "root",
116
+ "version": 7
117
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for ReLU kernel";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "git+ssh://[email protected]/huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+ };
17
+ }
flash_attn/flash_api.cpp ADDED
@@ -0,0 +1,1496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
6
+ // #include <torch/python.h>
7
+ #include <torch/nn/functional.h>
8
+ #include <c10/cuda/CUDAGuard.h>
9
+ #include <c10/cuda/CUDAStream.h>
10
+ #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
11
+ #include "src/philox_unpack.cuh" // For at::cuda::philox::unpack
12
+
13
+ #include <cutlass/numeric_types.h>
14
+
15
+ #include "src/namespace_config.h"
16
+ #include "src/hardware_info.h"
17
+ #include "src/flash.h"
18
+ #include "src/static_switch.h"
19
+
20
+ #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
21
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
22
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
23
+
24
+ namespace FLASH_NAMESPACE {
25
+
26
+ void set_params_fprop(Flash_fwd_params &params,
27
+ // sizes
28
+ const size_t b,
29
+ const size_t seqlen_q,
30
+ const size_t seqlen_k,
31
+ const size_t seqlen_q_rounded,
32
+ const size_t seqlen_k_rounded,
33
+ const size_t h,
34
+ const size_t h_k,
35
+ const size_t d,
36
+ const size_t d_rounded,
37
+ // device pointers
38
+ const at::Tensor q,
39
+ const at::Tensor k,
40
+ const at::Tensor v,
41
+ at::Tensor out,
42
+ void *cu_seqlens_q_d,
43
+ void *cu_seqlens_k_d,
44
+ void *seqused_k,
45
+ void *p_d,
46
+ void *softmax_lse_d,
47
+ float p_dropout,
48
+ float softmax_scale,
49
+ int window_size_left,
50
+ int window_size_right,
51
+ const float softcap,
52
+ bool seqlenq_ngroups_swapped=false,
53
+ const bool unpadded_lse=false) {
54
+
55
+ // Reset the parameters
56
+ params = {};
57
+
58
+ params.is_bf16 = q.dtype() == torch::kBFloat16;
59
+
60
+ // Set the pointers and strides.
61
+ params.q_ptr = q.data_ptr();
62
+ params.k_ptr = k.data_ptr();
63
+ params.v_ptr = v.data_ptr();
64
+ // All stride are in elements, not bytes.
65
+ params.q_row_stride = q.stride(-3);
66
+ params.k_row_stride = k.stride(-3);
67
+ params.v_row_stride = v.stride(-3);
68
+ params.q_head_stride = q.stride(-2);
69
+ params.k_head_stride = k.stride(-2);
70
+ params.v_head_stride = v.stride(-2);
71
+ params.o_ptr = out.data_ptr();
72
+ params.o_row_stride = out.stride(-3);
73
+ params.o_head_stride = out.stride(-2);
74
+
75
+ if (cu_seqlens_q_d == nullptr) {
76
+ params.q_batch_stride = q.stride(0);
77
+ params.k_batch_stride = k.stride(0);
78
+ params.v_batch_stride = v.stride(0);
79
+ params.o_batch_stride = out.stride(0);
80
+ if (seqlenq_ngroups_swapped) {
81
+ params.q_batch_stride *= seqlen_q;
82
+ params.o_batch_stride *= seqlen_q;
83
+ }
84
+ }
85
+
86
+ params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
87
+ params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
88
+ params.seqused_k = static_cast<int *>(seqused_k);
89
+
90
+ // P = softmax(QK^T)
91
+ params.p_ptr = p_d;
92
+
93
+ // Softmax sum
94
+ params.softmax_lse_ptr = softmax_lse_d;
95
+
96
+ // Set the dimensions.
97
+ params.b = b;
98
+ params.h = h;
99
+ params.h_k = h_k;
100
+ params.h_h_k_ratio = h / h_k;
101
+ params.seqlen_q = seqlen_q;
102
+ params.seqlen_k = seqlen_k;
103
+ params.seqlen_q_rounded = seqlen_q_rounded;
104
+ params.seqlen_k_rounded = seqlen_k_rounded;
105
+ params.d = d;
106
+ params.d_rounded = d_rounded;
107
+
108
+ // Set the different scale values.
109
+ #ifdef FLASHATTENTION_DISABLE_SOFTCAP
110
+ TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
111
+ #endif
112
+ if (softcap > 0.0) {
113
+ params.softcap = softmax_scale / softcap;
114
+ params.scale_softmax = softcap;
115
+ params.scale_softmax_log2 = softcap * M_LOG2E;
116
+ } else{
117
+ // Remove potential NaN
118
+ params.softcap = 0.0;
119
+ params.scale_softmax = softmax_scale;
120
+ params.scale_softmax_log2 = softmax_scale * M_LOG2E;
121
+ }
122
+
123
+ // Set this to probability of keeping an element to simplify things.
124
+ params.p_dropout = 1.f - p_dropout;
125
+ // Convert p from float to int so we don't have to convert the random uint to float to compare.
126
+ // [Minor] We want to round down since when we do the comparison we use <= instead of <
127
+ // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
128
+ // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
129
+ params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
130
+ params.rp_dropout = 1.f / params.p_dropout;
131
+ params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
132
+ TORCH_CHECK(p_dropout < 1.f);
133
+ #ifdef FLASHATTENTION_DISABLE_DROPOUT
134
+ TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
135
+ #endif
136
+
137
+ // Causal is the special case where window_size_right == 0 and window_size_left < 0.
138
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
139
+ params.is_causal = window_size_left < 0 && window_size_right == 0;
140
+
141
+ if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
142
+ if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
143
+ params.window_size_left = window_size_left;
144
+ params.window_size_right = window_size_right;
145
+
146
+ #ifdef FLASHATTENTION_DISABLE_LOCAL
147
+ TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
148
+ "This flash attention build does not support local attention.");
149
+ #endif
150
+
151
+ params.is_seqlens_k_cumulative = true;
152
+
153
+ #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
154
+ TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
155
+ #endif
156
+
157
+ params.unpadded_lse = unpadded_lse;
158
+ params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
159
+ }
160
+
161
+ void set_params_dgrad(Flash_bwd_params &params,
162
+ // sizes
163
+ const size_t b,
164
+ const size_t seqlen_q,
165
+ const size_t seqlen_k,
166
+ const size_t seqlen_q_rounded,
167
+ const size_t seqlen_k_rounded,
168
+ const size_t h,
169
+ const size_t h_k,
170
+ const size_t d,
171
+ const size_t d_rounded,
172
+ // device pointers
173
+ const at::Tensor q,
174
+ const at::Tensor k,
175
+ const at::Tensor v,
176
+ const at::Tensor out,
177
+ const at::Tensor dout,
178
+ at::Tensor dq,
179
+ at::Tensor dk,
180
+ at::Tensor dv,
181
+ void *cu_seqlens_q_d,
182
+ void *cu_seqlens_k_d,
183
+ void *dq_accum_d,
184
+ void *dk_accum_d,
185
+ void *dv_accum_d,
186
+ void *softmax_lse_d,
187
+ void *dsoftmax_sum_d,
188
+ float p_dropout,
189
+ float softmax_scale,
190
+ int window_size_left,
191
+ int window_size_right,
192
+ const float softcap,
193
+ bool deterministic,
194
+ const bool unpadded_lse) {
195
+
196
+ set_params_fprop(params,
197
+ b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
198
+ q, k, v, out,
199
+ cu_seqlens_q_d,
200
+ cu_seqlens_k_d,
201
+ nullptr,
202
+ nullptr,
203
+ softmax_lse_d,
204
+ p_dropout,
205
+ softmax_scale,
206
+ window_size_left,
207
+ window_size_right,
208
+ softcap,
209
+ false, // seqlenq_ngroups_swapped
210
+ unpadded_lse);
211
+
212
+ // Set the pointers and strides.
213
+ params.do_ptr = dout.data_ptr();
214
+ params.do_row_stride = dout.stride(-3);
215
+ params.do_head_stride = dout.stride(-2);
216
+ params.dq_ptr = dq.data_ptr();
217
+ params.dk_ptr = dk.data_ptr();
218
+ params.dv_ptr = dv.data_ptr();
219
+ params.dq_row_stride = dq.stride(-3);
220
+ params.dk_row_stride = dk.stride(-3);
221
+ params.dv_row_stride = dv.stride(-3);
222
+ params.dq_head_stride = dq.stride(-2);
223
+ params.dk_head_stride = dk.stride(-2);
224
+ params.dv_head_stride = dv.stride(-2);
225
+
226
+ if (cu_seqlens_q_d == nullptr) {
227
+ params.do_batch_stride = dout.stride(0);
228
+ params.dq_batch_stride = dq.stride(0);
229
+ params.dk_batch_stride = dk.stride(0);
230
+ params.dv_batch_stride = dv.stride(0);
231
+ }
232
+
233
+ params.dq_accum_ptr = dq_accum_d;
234
+ params.dk_accum_ptr = dk_accum_d;
235
+ params.dv_accum_ptr = dv_accum_d;
236
+
237
+ // Softmax sum
238
+ params.dsoftmax_sum = dsoftmax_sum_d;
239
+
240
+ params.deterministic = deterministic;
241
+ }
242
+
243
+ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
244
+ FP16_SWITCH(!params.is_bf16, [&] {
245
+ HEADDIM_SWITCH(params.d, [&] {
246
+ BOOL_SWITCH(params.is_causal, Is_causal, [&] {
247
+ if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
248
+ run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
249
+ } else {
250
+ run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
251
+ }
252
+ });
253
+ });
254
+ });
255
+ }
256
+
257
+ // Find the number of splits that maximizes the occupancy. For example, if we have
258
+ // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
259
+ // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
260
+ // splits as that would incur more HBM reads/writes.
261
+ // So we find the best efficiency, then find the smallest number of splits that gets 85%
262
+ // of the best efficiency.
263
+ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
264
+ // If we have enough to almost fill the SMs, then just use 1 split
265
+ if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
266
+ max_splits = std::min({max_splits, num_SMs, num_n_blocks});
267
+ float max_efficiency = 0.f;
268
+ std::vector<float> efficiency;
269
+ efficiency.reserve(max_splits);
270
+ auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
271
+ // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
272
+ // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
273
+ // (i.e. it's 11 splits anyway).
274
+ // So we check if the number of blocks per split is the same as the previous num_splits.
275
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
276
+ return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
277
+ };
278
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
279
+ if (!is_split_eligible(num_splits)) {
280
+ efficiency.push_back(0.f);
281
+ } else {
282
+ float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
283
+ float eff = n_waves / ceil(n_waves);
284
+ // printf("num_splits = %d, eff = %f\n", num_splits, eff);
285
+ if (eff > max_efficiency) { max_efficiency = eff; }
286
+ efficiency.push_back(eff);
287
+ }
288
+ }
289
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
290
+ if (!is_split_eligible(num_splits)) { continue; }
291
+ if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
292
+ // printf("num_splits chosen = %d\n", num_splits);
293
+ return num_splits;
294
+ }
295
+ }
296
+ return 1;
297
+ }
298
+
299
+ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
300
+ const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
301
+ const int head_size_rounded, const float p_dropout,
302
+ const int num_splits, const int num_sm, struct c10::TensorOptions opts) {
303
+
304
+ // This needs to match with run_mha_fwd_splitkv_dispatch
305
+ const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
306
+ const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
307
+ // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
308
+ // In any case we don't expect seqlen_q to be larger than 64 for inference.
309
+ const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
310
+ params.num_splits = num_splits;
311
+ at::Tensor softmax_lse_accum;
312
+ at::Tensor out_accum;
313
+
314
+ if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
315
+ if (num_splits < 1) {
316
+ // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
317
+ params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
318
+ }
319
+ if (params.num_splits > 1) {
320
+ softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
321
+ out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
322
+ params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
323
+ params.oaccum_ptr = out_accum.data_ptr();
324
+ }
325
+ TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
326
+ }
327
+
328
+ return std::make_tuple(softmax_lse_accum, out_accum);
329
+ }
330
+
331
+ void set_params_alibi(Flash_fwd_params &params, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
332
+ #ifdef FLASHATTENTION_DISABLE_ALIBI
333
+ TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
334
+ params.alibi_slopes_ptr = nullptr;
335
+ #else
336
+ if (alibi_slopes_.has_value()) {
337
+ auto alibi_slopes = alibi_slopes_.value();
338
+ TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
339
+ CHECK_DEVICE(alibi_slopes);
340
+ TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
341
+ TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
342
+ params.alibi_slopes_ptr = alibi_slopes.data_ptr();
343
+ params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
344
+ } else {
345
+ params.alibi_slopes_ptr = nullptr;
346
+ }
347
+ #endif
348
+ }
349
+
350
+ std::vector<at::Tensor>
351
+ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
352
+ const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
353
+ const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
354
+ std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
355
+ std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
356
+ const float p_dropout,
357
+ const float softmax_scale,
358
+ bool is_causal,
359
+ int window_size_left,
360
+ int window_size_right,
361
+ const float softcap,
362
+ const bool return_softmax,
363
+ std::optional<at::Generator> gen_) {
364
+
365
+ // Otherwise the kernel will be launched from cuda:0 device
366
+ at::cuda::CUDAGuard device_guard{q.device()};
367
+
368
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
369
+ bool is_sm8x_min = cc_major >= 8;
370
+ TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
371
+
372
+ auto q_dtype = q.dtype();
373
+ TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
374
+ "FlashAttention only support fp16 and bf16 data type");
375
+ TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
376
+ TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
377
+
378
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
379
+
380
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
381
+ TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
382
+ TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
383
+
384
+ const auto sizes = q.sizes();
385
+
386
+ const int batch_size = sizes[0];
387
+ int seqlen_q = sizes[1];
388
+ int num_heads = sizes[2];
389
+ const int head_size = sizes[3];
390
+ const int seqlen_k = k.size(1);
391
+ const int num_heads_k = k.size(2);
392
+ TORCH_CHECK(batch_size > 0, "batch size must be positive");
393
+ TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
394
+ TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
395
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
396
+
397
+ if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
398
+
399
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
400
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
401
+
402
+ // causal=true is the same as causal=false in this case
403
+ if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
404
+ if (is_causal) { window_size_right = 0; }
405
+
406
+ // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
407
+ // H/t Daniel Haziza
408
+ const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
409
+ const int ngroups = num_heads / num_heads_k;
410
+ if (seqlenq_ngroups_swapped) {
411
+ q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
412
+ seqlen_q = ngroups;
413
+ num_heads = num_heads_k;
414
+ }
415
+
416
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
417
+ CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
418
+ CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
419
+
420
+ at::Tensor out;
421
+ if (out_.has_value()) {
422
+ out = out_.value();
423
+ TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
424
+ CHECK_DEVICE(out);
425
+ TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
426
+ CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
427
+ if (seqlenq_ngroups_swapped) {
428
+ out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
429
+ }
430
+ } else {
431
+ out = torch::empty_like(q);
432
+ }
433
+
434
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
435
+ const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
436
+ const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
437
+ const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
438
+
439
+ auto opts = q.options();
440
+
441
+ auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
442
+ at::Tensor p;
443
+ // Only return softmax if there's dropout to reduce compilation time
444
+ if (return_softmax) {
445
+ TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
446
+ p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
447
+ }
448
+ else {
449
+ p = torch::empty({ 0 }, opts);
450
+ }
451
+
452
+ Flash_fwd_params params;
453
+ set_params_fprop(params,
454
+ batch_size,
455
+ seqlen_q, seqlen_k,
456
+ seqlen_q_rounded, seqlen_k_rounded,
457
+ num_heads, num_heads_k,
458
+ head_size, head_size_rounded,
459
+ q, k, v, out,
460
+ /*cu_seqlens_q_d=*/nullptr,
461
+ /*cu_seqlens_k_d=*/nullptr,
462
+ /*seqused_k=*/nullptr,
463
+ return_softmax ? p.data_ptr() : nullptr,
464
+ softmax_lse.data_ptr(),
465
+ p_dropout,
466
+ softmax_scale,
467
+ window_size_left,
468
+ window_size_right,
469
+ softcap
470
+ );
471
+
472
+ // Keep references to these tensors to extend their lifetime
473
+ at::Tensor softmax_lse_accum, out_accum;
474
+ std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
475
+ params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
476
+ head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
477
+
478
+ // number of times random will be generated per thread, to offset philox counter in thc random
479
+ // state
480
+ // We use a custom RNG that increases the offset by batch_size * nheads * 32.
481
+ int64_t counter_offset = params.b * params.h * 32;
482
+ auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
483
+ auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
484
+ // Forward kernel will populate memory with the seed and offset.
485
+ params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
486
+
487
+ if (p_dropout > 0.0) {
488
+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
489
+ gen_, at::cuda::detail::getDefaultCUDAGenerator());
490
+ // See Note [Acquire lock when using random generators]
491
+ std::lock_guard<std::mutex> lock(gen->mutex_);
492
+ params.philox_args = gen->philox_cuda_state(counter_offset);
493
+ }
494
+
495
+ set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
496
+
497
+ if (seqlen_k > 0) {
498
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
499
+ run_mha_fwd(params, stream);
500
+ } else {
501
+ // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
502
+ out.zero_();
503
+ softmax_lse.fill_(std::numeric_limits<float>::infinity());
504
+ }
505
+
506
+ if (seqlenq_ngroups_swapped) {
507
+ out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
508
+ q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
509
+ softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
510
+ }
511
+ return {out, softmax_lse, p, rng_state};
512
+ }
513
+
514
+ std::vector<at::Tensor>
515
+ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
516
+ const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
517
+ const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
518
+ std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
519
+ const at::Tensor &cu_seqlens_q, // b+1
520
+ const at::Tensor &cu_seqlens_k, // b+1
521
+ std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
522
+ std::optional<const at::Tensor> &leftpad_k_, // batch_size
523
+ std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
524
+ std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
525
+ int max_seqlen_q,
526
+ const int max_seqlen_k,
527
+ const float p_dropout,
528
+ const float softmax_scale,
529
+ const bool zero_tensors,
530
+ bool is_causal,
531
+ int window_size_left,
532
+ int window_size_right,
533
+ const float softcap,
534
+ const bool return_softmax,
535
+ std::optional<at::Generator> gen_) {
536
+
537
+ // Otherwise the kernel will be launched from cuda:0 device
538
+ at::cuda::CUDAGuard device_guard{q.device()};
539
+
540
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
541
+ bool is_sm8x_min = cc_major >= 8;
542
+ TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
543
+
544
+ auto q_dtype = q.dtype();
545
+ TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
546
+ "FlashAttention only support fp16 and bf16 data type");
547
+ TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
548
+ TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
549
+ TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
550
+ TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
551
+
552
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
553
+ CHECK_DEVICE(cu_seqlens_q);
554
+ CHECK_DEVICE(cu_seqlens_k);
555
+
556
+ at::Tensor block_table;
557
+ const bool paged_KV = block_table_.has_value();
558
+ if (paged_KV) {
559
+ block_table = block_table_.value();
560
+ CHECK_DEVICE(block_table);
561
+ TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
562
+ TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
563
+ }
564
+
565
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
566
+ TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
567
+ TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
568
+ CHECK_CONTIGUOUS(cu_seqlens_q);
569
+ CHECK_CONTIGUOUS(cu_seqlens_k);
570
+
571
+ const auto sizes = q.sizes();
572
+
573
+ const int batch_size = cu_seqlens_q.numel() - 1;
574
+ int num_heads = sizes[1];
575
+ const int head_size = sizes[2];
576
+ const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
577
+
578
+ if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
579
+
580
+ const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
581
+ const int num_blocks = !paged_KV ? 0 : k.size(0);
582
+ const int page_block_size = !paged_KV ? 1 : k.size(1);
583
+ TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
584
+
585
+ if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
586
+ if (is_causal) { window_size_right = 0; }
587
+
588
+ void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
589
+
590
+ // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
591
+ // H/t Daniel Haziza
592
+ const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
593
+ const int ngroups = num_heads / num_heads_k;
594
+ if (seqlenq_ngroups_swapped) {
595
+ q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
596
+ max_seqlen_q = ngroups;
597
+ num_heads = num_heads_k;
598
+ cu_seqlens_q_d = nullptr;
599
+ }
600
+
601
+ const int total_q = q.sizes()[0];
602
+
603
+ TORCH_CHECK(batch_size > 0, "batch size must be positive");
604
+ TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
605
+ TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
606
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
607
+
608
+ if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
609
+ if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
610
+
611
+ CHECK_SHAPE(q, total_q, num_heads, head_size);
612
+ if (!paged_KV) {
613
+ const int total_k = k.size(0);
614
+ CHECK_SHAPE(k, total_k, num_heads_k, head_size);
615
+ CHECK_SHAPE(v, total_k, num_heads_k, head_size);
616
+ } else {
617
+ CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
618
+ CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
619
+ CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
620
+ }
621
+
622
+ CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
623
+ CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
624
+ if (seqused_k.has_value()){
625
+ auto seqused_k_ = seqused_k.value();
626
+ TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
627
+ TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
628
+ TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
629
+ CHECK_SHAPE(seqused_k_, batch_size);
630
+ }
631
+
632
+ at::Tensor out;
633
+ if (out_.has_value()) {
634
+ out = out_.value();
635
+ TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
636
+ CHECK_DEVICE(out);
637
+ TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
638
+ CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
639
+ if (seqlenq_ngroups_swapped) {
640
+ out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
641
+ }
642
+ } else {
643
+ out = torch::empty_like(q);
644
+ }
645
+
646
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
647
+ const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
648
+ const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
649
+ const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
650
+
651
+ auto opts = q.options();
652
+ auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
653
+ at::Tensor p;
654
+ // Only return softmax if there's dropout to reduce compilation time
655
+ if (return_softmax) {
656
+ TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
657
+ p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
658
+ }
659
+ else {
660
+ p = torch::empty({ 0 }, opts);
661
+ }
662
+
663
+ if (zero_tensors) {
664
+ out.zero_();
665
+ softmax_lse.fill_(-std::numeric_limits<float>::infinity());
666
+ if (return_softmax) {p.zero_();}
667
+ }
668
+
669
+ Flash_fwd_params params;
670
+ set_params_fprop(params,
671
+ batch_size,
672
+ max_seqlen_q, max_seqlen_k,
673
+ seqlen_q_rounded, seqlen_k_rounded,
674
+ num_heads, num_heads_k,
675
+ head_size, head_size_rounded,
676
+ q, k, v, out,
677
+ cu_seqlens_q_d,
678
+ cu_seqlens_k.data_ptr(),
679
+ seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
680
+ return_softmax ? p.data_ptr() : nullptr,
681
+ softmax_lse.data_ptr(),
682
+ p_dropout,
683
+ softmax_scale,
684
+ window_size_left,
685
+ window_size_right,
686
+ softcap,
687
+ seqlenq_ngroups_swapped,
688
+ /*unpadded_lse*/true);
689
+ params.total_q = total_q;
690
+
691
+ if (paged_KV) {
692
+ params.block_table = block_table.data_ptr<int>();
693
+ params.block_table_batch_stride = block_table.stride(0);
694
+ params.k_batch_stride = k.stride(0);
695
+ params.v_batch_stride = v.stride(0);
696
+ }
697
+ params.page_block_size = page_block_size;
698
+ // Keep references to these tensors to extend their lifetime
699
+ at::Tensor softmax_lse_accum, out_accum;
700
+ if (seqlenq_ngroups_swapped) {
701
+ // Only apply split-k for decoding
702
+ std::tie(softmax_lse_accum, out_accum) =
703
+ set_params_splitkv(params, batch_size, num_heads, head_size,
704
+ max_seqlen_k, max_seqlen_q, head_size_rounded,
705
+ p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
706
+ }
707
+
708
+ if (leftpad_k_.has_value()) {
709
+ auto leftpad_k = leftpad_k_.value();
710
+ TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
711
+ TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
712
+ CHECK_DEVICE(leftpad_k);
713
+ CHECK_CONTIGUOUS(leftpad_k);
714
+ CHECK_SHAPE(leftpad_k, batch_size);
715
+ params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
716
+ }
717
+
718
+ // number of times random will be generated per thread, to offset philox counter in thc random
719
+ // state
720
+ // We use a custom RNG that increases the offset by batch_size * nheads * 32.
721
+ int64_t counter_offset = params.b * params.h * 32;
722
+ auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
723
+ auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
724
+ // Forward kernel will populate memory with the seed and offset.
725
+ params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
726
+
727
+ if (p_dropout > 0.0) {
728
+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
729
+ gen_, at::cuda::detail::getDefaultCUDAGenerator());
730
+ // See Note [Acquire lock when using random generators]
731
+ std::lock_guard<std::mutex> lock(gen->mutex_);
732
+ params.philox_args = gen->philox_cuda_state(counter_offset);
733
+ }
734
+
735
+ set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
736
+
737
+ if (max_seqlen_k > 0) {
738
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
739
+ run_mha_fwd(params, stream, paged_KV);
740
+ } else {
741
+ // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
742
+ out.zero_();
743
+ softmax_lse.fill_(std::numeric_limits<float>::infinity());
744
+ }
745
+
746
+ if (seqlenq_ngroups_swapped) {
747
+ int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
748
+ int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
749
+ out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
750
+ q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
751
+ softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
752
+ }
753
+
754
+ return {out, softmax_lse, p, rng_state};
755
+ }
756
+
757
+ void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
758
+ FP16_SWITCH(!params.is_bf16, [&] {
759
+ HEADDIM_SWITCH(params.d, [&] {
760
+ BOOL_SWITCH(params.is_causal, Is_causal, [&] {
761
+ run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
762
+ });
763
+ });
764
+ });
765
+ }
766
+
767
+ std::vector<at::Tensor>
768
+ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
769
+ const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
770
+ const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
771
+ const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
772
+ const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
773
+ const at::Tensor &softmax_lse, // b x h x seqlen_q
774
+ std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
775
+ std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
776
+ std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
777
+ std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
778
+ const float p_dropout, // probability to drop
779
+ const float softmax_scale,
780
+ const bool is_causal,
781
+ int window_size_left,
782
+ int window_size_right,
783
+ const float softcap,
784
+ const bool deterministic,
785
+ std::optional<at::Generator> gen_,
786
+ std::optional<at::Tensor> &rng_state) {
787
+
788
+ #ifdef FLASHATTENTION_DISABLE_BACKWARD
789
+ TORCH_CHECK(false, "This flash attention build does not support backward.");
790
+ #endif
791
+ if (is_causal) { window_size_right = 0; }
792
+
793
+ // Otherwise the kernel will be launched from cuda:0 device
794
+ at::cuda::CUDAGuard device_guard{q.device()};
795
+
796
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
797
+ bool is_sm8x_min = cc_major >= 8;
798
+ TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
799
+
800
+ bool is_dropout = p_dropout > 0.0;
801
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
802
+
803
+ auto q_dtype = q.dtype();
804
+ TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
805
+ "FlashAttention only support fp16 and bf16 data type");
806
+ TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
807
+ TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
808
+ TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
809
+ TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
810
+
811
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
812
+ CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
813
+
814
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
815
+ TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
816
+ TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
817
+ TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
818
+ TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
819
+
820
+ const auto sizes = q.sizes();
821
+
822
+ const int batch_size = sizes[0];
823
+ const int seqlen_q = sizes[1];
824
+ const int num_heads = sizes[2];
825
+ const int head_size = sizes[3];
826
+ const int seqlen_k = k.size(1);
827
+ const int num_heads_k = k.size(2);
828
+ TORCH_CHECK(batch_size > 0, "batch size must be positive");
829
+ TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
830
+ TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
831
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
832
+
833
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
834
+ const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
835
+ const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
836
+ const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
837
+
838
+ if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
839
+
840
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
841
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
842
+
843
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
844
+ CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
845
+ CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
846
+ CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
847
+ CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
848
+
849
+ at::Tensor dq, dk, dv;
850
+ if (dq_.has_value()) {
851
+ dq = dq_.value();
852
+ TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
853
+ CHECK_DEVICE(dq);
854
+ TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
855
+ CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
856
+ } else {
857
+ dq = torch::empty_like(q);
858
+ }
859
+ if (dk_.has_value()) {
860
+ dk = dk_.value();
861
+ TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
862
+ CHECK_DEVICE(dk);
863
+ TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
864
+ CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
865
+ } else {
866
+ dk = torch::empty_like(k);
867
+ }
868
+ if (dv_.has_value()) {
869
+ dv = dv_.value();
870
+ TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
871
+ CHECK_DEVICE(dv);
872
+ TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
873
+ CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
874
+ } else {
875
+ dv = torch::empty_like(v);
876
+ }
877
+
878
+ // bool loop = seqlen_k > blocksize_c;
879
+ // TODO: change later, for now set to true for simplicity
880
+ bool loop = true;
881
+
882
+ auto opts = q.options();
883
+ auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
884
+ at::Tensor dq_accum;
885
+ at::Tensor dk_accum, dv_accum;
886
+ if (loop) {
887
+ if (!deterministic) {
888
+ dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
889
+ } else {
890
+ const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
891
+ dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
892
+ }
893
+ // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
894
+ // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
895
+ }
896
+
897
+ at::Tensor dk_expanded, dv_expanded;
898
+ if (num_heads_k != num_heads) { // MQA / GQA
899
+ dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
900
+ dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
901
+ } else {
902
+ dk_expanded = dk;
903
+ dv_expanded = dv;
904
+ }
905
+
906
+ Flash_bwd_params params;
907
+
908
+ set_params_dgrad(params,
909
+ batch_size,
910
+ seqlen_q, seqlen_k,
911
+ seqlen_q_rounded, seqlen_k_rounded,
912
+ num_heads, num_heads_k,
913
+ head_size, head_size_rounded,
914
+ q, k, v, out,
915
+ dout, dq, dk_expanded, dv_expanded,
916
+ nullptr,
917
+ nullptr,
918
+ loop ? dq_accum.data_ptr() : nullptr,
919
+ // loop ? dk_accum.data_ptr() : nullptr,
920
+ // loop ? dv_accum.data_ptr() : nullptr,
921
+ nullptr,
922
+ nullptr,
923
+ softmax_lse.data_ptr(),
924
+ softmax_d.data_ptr(),
925
+ p_dropout,
926
+ softmax_scale,
927
+ window_size_left,
928
+ window_size_right,
929
+ softcap,
930
+ deterministic,
931
+ /*unpadded_lse*/false);
932
+ params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
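+ // stride(0) of the [nsplits, ...] accumulator lets the kernel address its split slice;
+ // it is 0 in the non-deterministic case because there is only a single slice.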
933
+
934
+ auto launch = &run_mha_bwd;
935
+
936
+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
937
+ gen_, at::cuda::detail::getDefaultCUDAGenerator());
938
+
939
+ // We use a custom RNG that increases the offset by batch_size * nheads * 32.
940
+ int64_t counter_offset = params.b * params.h * 32;
941
+
942
+ if ( rng_state.has_value() ) {
943
+ params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
944
+ } else if( is_dropout ) {
945
+ // See Note [Acquire lock when using random generators]
946
+ std::lock_guard<std::mutex> lock(gen->mutex_);
947
+ params.philox_args = gen->philox_cuda_state(counter_offset);
948
+ auto seeds = at::cuda::philox::unpack(params.philox_args);
949
+ params.rng_state[0] = std::get<0>(seeds);
950
+ params.rng_state[1] = std::get<1>(seeds);
951
+ }
952
+
953
+ set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
954
+
955
+ if (seqlen_q > 0) {
956
+ launch(params, stream);
957
+ } else {
958
+ // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
959
+ dk_expanded.zero_();
960
+ dv_expanded.zero_();
961
+ softmax_d.zero_();
962
+ }
963
+
964
+ // For MQA/GQA we need to sum dK and dV across the groups
965
+ if (num_heads_k != num_heads) {
966
+ at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
967
+ at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
968
+ }
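+ // e.g. with num_heads = 8 and num_heads_k = 2, the 8 expanded gradient heads are reshaped to
+ // (..., 2, 4, head_size) and reduced 4-to-1 over dim 3 into the 2 KV heads of dk / dv.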
969
+
970
+ return { dq, dk, dv, softmax_d };
971
+ }
972
+
973
+ std::vector<at::Tensor>
974
+ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
975
+ const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
976
+ const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
977
+ const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
978
+ const at::Tensor &out, // total_q x num_heads x head_size
979
+ const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
980
+ std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
981
+ std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
982
+ std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
983
+ const at::Tensor &cu_seqlens_q, // b+1
984
+ const at::Tensor &cu_seqlens_k, // b+1
985
+ std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
986
+ const int max_seqlen_q,
987
+ const int max_seqlen_k, // max sequence length to choose the kernel
988
+ const float p_dropout, // probability to drop
989
+ const float softmax_scale,
990
+ const bool zero_tensors,
991
+ const bool is_causal,
992
+ int window_size_left,
993
+ int window_size_right,
994
+ const float softcap,
995
+ const bool deterministic,
996
+ std::optional<at::Generator> gen_,
997
+ std::optional<at::Tensor> &rng_state) {
998
+
999
+ #ifdef FLASHATTENTION_DISABLE_BACKWARD
1000
+ TORCH_CHECK(false, "This flash attention build does not support backward.");
1001
+ #endif
1002
+ if (is_causal) { window_size_right = 0; }
1003
+
1004
+ // Otherwise the kernel will be launched from cuda:0 device
1005
+ at::cuda::CUDAGuard device_guard{q.device()};
1006
+
1007
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
1008
+ bool is_sm8x_min = cc_major >= 8;
1009
+ TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
1010
+
1011
+ bool is_dropout = p_dropout > 0.0;
1012
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
1013
+
1014
+ auto q_dtype = q.dtype();
1015
+ TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
1016
+ "FlashAttention only supports fp16 and bf16 data types");
1017
+ TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
1018
+ TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
1019
+ TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
1020
+ TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
1021
+ TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
1022
+ TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
1023
+
1024
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
1025
+ CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
1026
+ CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
1027
+
1028
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1029
+ TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1030
+ TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1031
+ TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
1032
+ TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
1033
+ CHECK_CONTIGUOUS(cu_seqlens_q);
1034
+ CHECK_CONTIGUOUS(cu_seqlens_k);
1035
+
1036
+ const auto sizes = q.sizes();
1037
+
1038
+ const int total_q = sizes[0];
1039
+ const int batch_size = cu_seqlens_q.numel() - 1;
1040
+ const int num_heads = sizes[1];
1041
+ const int head_size = sizes[2];
1042
+ const int total_k = k.size(0);
1043
+ const int num_heads_k = k.size(1);
1044
+ TORCH_CHECK(batch_size > 0, "batch size must be positive");
1045
+ TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
1046
+ TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
1047
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1048
+ if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
1049
+
1050
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1051
+ const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
1052
+ const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
1053
+ const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
1054
+
1055
+ if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
1056
+ if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
1057
+
1058
+ CHECK_SHAPE(q, total_q, num_heads, head_size);
1059
+ CHECK_SHAPE(k, total_k, num_heads_k, head_size);
1060
+ CHECK_SHAPE(v, total_k, num_heads_k, head_size);
1061
+ CHECK_SHAPE(out, total_q, num_heads, head_size);
1062
+ CHECK_SHAPE(dout, total_q, num_heads, head_size);
1063
+ CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
1064
+ CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
1065
+
1066
+ at::Tensor dq, dk, dv;
1067
+ if (dq_.has_value()) {
1068
+ dq = dq_.value();
1069
+ TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
1070
+ CHECK_DEVICE(dq);
1071
+ TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
1072
+ CHECK_SHAPE(dq, total_q, num_heads, head_size);
1073
+ } else {
1074
+ dq = torch::empty_like(q);
1075
+ }
1076
+ if (dk_.has_value()) {
1077
+ dk = dk_.value();
1078
+ TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
1079
+ CHECK_DEVICE(dk);
1080
+ TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
1081
+ CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
1082
+ } else {
1083
+ dk = torch::empty_like(k);
1084
+ }
1085
+ if (dv_.has_value()) {
1086
+ dv = dv_.value();
1087
+ TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
1088
+ CHECK_DEVICE(dv);
1089
+ TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
1090
+ CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
1091
+ } else {
1092
+ dv = torch::empty_like(v);
1093
+ }
1094
+
1095
+ // bool loop = max_seqlen_k > blocksize_c;
1096
+ // TODO: change later, for now set to true for simplicity
1097
+ bool loop = true;
1098
+
1099
+ auto opts = q.options();
1100
+ auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
1101
+ at::Tensor dq_accum;
1102
+ if (loop) {
1103
+ // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
1104
+ // because that would be too large if there is a very long sequence and the rest of the sequences are short.
1105
+ // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
1106
+ // Note that 128 is the max block size on the seqlen_q dimension.
1107
+ // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
1108
+ // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
1109
+ // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
1110
+ // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
1111
+ // Same holds for softmax_d, since LSE is stored in unpadded format.
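+ // Worked example (hypothetical values): with cu_seqlens_q = [0, 3, 10], sequence 0 occupies
+ // dq_accum rows [0, 3) and sequence 1 occupies rows [131, 138), leaving a 128-row gap between them.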
1112
+ if (!deterministic) {
1113
+ dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
1114
+ } else {
1115
+ const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
1116
+ dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
1117
+ }
1118
+ }
1119
+
1120
+ at::Tensor dk_expanded, dv_expanded;
1121
+ if (num_heads_k != num_heads) { // MQA / GQA
1122
+ dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
1123
+ dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
1124
+ } else {
1125
+ dk_expanded = dk;
1126
+ dv_expanded = dv;
1127
+ }
1128
+
1129
+ if( zero_tensors ) {
1130
+ dq.zero_();
1131
+ dk_expanded.zero_();
1132
+ dv_expanded.zero_();
1133
+ softmax_d.zero_();
1134
+ }
1135
+
1136
+ Flash_bwd_params params;
1137
+
1138
+ set_params_dgrad(params,
1139
+ batch_size,
1140
+ max_seqlen_q, max_seqlen_k,
1141
+ seqlen_q_rounded, seqlen_k_rounded,
1142
+ num_heads, num_heads_k,
1143
+ head_size, head_size_rounded,
1144
+ q, k, v, out,
1145
+ dout, dq, dk_expanded, dv_expanded,
1146
+ cu_seqlens_q.data_ptr(),
1147
+ cu_seqlens_k.data_ptr(),
1148
+ loop ? dq_accum.data_ptr() : nullptr,
1149
+ nullptr,
1150
+ nullptr,
1151
+ softmax_lse.data_ptr(),
1152
+ softmax_d.data_ptr(),
1153
+ p_dropout,
1154
+ softmax_scale,
1155
+ window_size_left,
1156
+ window_size_right,
1157
+ softcap,
1158
+ deterministic,
1159
+ /*unpadded_lse*/true);
1160
+ params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
1161
+ params.total_q = total_q;
1162
+
1163
+ auto launch = &run_mha_bwd;
1164
+
1165
+ auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
1166
+ gen_, at::cuda::detail::getDefaultCUDAGenerator());
1167
+
1168
+ // We use a custom RNG that increases the offset by batch_size * nheads * 32.
1169
+ int64_t counter_offset = params.b * params.h * 32;
1170
+
1171
+ if ( rng_state.has_value() ) {
1172
+ params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
1173
+ } else if( is_dropout ) {
1174
+ // See Note [Acquire lock when using random generators]
1175
+ std::lock_guard<std::mutex> lock(gen->mutex_);
1176
+ params.philox_args = gen->philox_cuda_state(counter_offset);
1177
+ auto seeds = at::cuda::philox::unpack(params.philox_args);
1178
+ params.rng_state[0] = std::get<0>(seeds);
1179
+ params.rng_state[1] = std::get<1>(seeds);
1180
+ }
1181
+
1182
+ set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
1183
+
1184
+ if (max_seqlen_q > 0) {
1185
+ launch(params, stream);
1186
+ } else {
1187
+ // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
1188
+ dk_expanded.zero_();
1189
+ dv_expanded.zero_();
1190
+ softmax_d.zero_();
1191
+ }
1192
+
1193
+ // For MQA/GQA we need to sum dK and dV across the groups
1194
+ if (num_heads_k != num_heads) {
1195
+ at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
1196
+ at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
1197
+ }
1198
+
1199
+ return { dq, dk, dv, softmax_d };
1200
+ }
1201
+
1202
+ std::vector<at::Tensor>
1203
+ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
1204
+ const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
1205
+ const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
1206
+ std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
1207
+ std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
1208
+ std::optional<const at::Tensor> &seqlens_k_, // batch_size
1209
+ std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
1210
+ std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
1211
+ std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
1212
+ std::optional<const at::Tensor> &leftpad_k_, // batch_size
1213
+ std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
1214
+ std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1215
+ std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
1216
+ const float softmax_scale,
1217
+ bool is_causal,
1218
+ int window_size_left,
1219
+ int window_size_right,
1220
+ const float softcap,
1221
+ bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
1222
+ int num_splits
1223
+ ) {
1224
+
1225
+ // Otherwise the kernel will be launched from cuda:0 device
1226
+ at::cuda::CUDAGuard device_guard{q.device()};
1227
+
1228
+ auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
1229
+ bool is_sm8x_min = cc_major >= 8;
1230
+ TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
1231
+
1232
+ auto q_dtype = q.dtype();
1233
+ TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
1234
+ "FlashAttention only supports fp16 and bf16 data types");
1235
+ TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
1236
+ TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
1237
+
1238
+ CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
1239
+
1240
+ TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1241
+ TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1242
+ TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1243
+
1244
+ at::Tensor block_table;
1245
+ const bool paged_KV = block_table_.has_value();
1246
+ if (paged_KV) {
1247
+ TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
1248
+ block_table = block_table_.value();
1249
+ CHECK_DEVICE(block_table);
1250
+ TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
1251
+ TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
1252
+ }
1253
+
1254
+ const auto sizes = q.sizes();
1255
+
1256
+ const int batch_size = sizes[0];
1257
+ int seqlen_q = sizes[1];
1258
+ int num_heads = sizes[2];
1259
+ const int head_size_og = sizes[3];
1260
+
1261
+ const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
1262
+ const int num_blocks = !paged_KV ? 0 : kcache.size(0);
1263
+ const int page_block_size = !paged_KV ? 1 : kcache.size(1);
1264
+ TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
1265
+ const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
1266
+ const int num_heads_k = kcache.size(2);
1267
+ const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
1268
+ TORCH_CHECK(batch_size > 0, "batch size must be positive");
1269
+ TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
1270
+ TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1271
+
1272
+ // causal=true is the same as causal=false in this case
1273
+ if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
1274
+ if (is_causal) { window_size_right = 0; }
1275
+
1276
+ // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
1277
+ // H/t Daniel Haziza
1278
+ const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
1279
+ if (seqlenq_ngroups_swapped) {
1280
+ const int ngroups = num_heads / num_heads_k;
1281
+ q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
1282
+ seqlen_q = ngroups;
1283
+ num_heads = num_heads_k;
1284
+ }
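+ // e.g. batch_size = 2, num_heads = 8, num_heads_k = 2: q goes from (2, 1, 8, d) to (2, 4, 2, d),
+ // so the 4 query groups per KV head are processed as 4 query positions of 2 heads.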
1285
+
1286
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
1287
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
1288
+
1289
+ CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
1290
+ if (!paged_KV) {
1291
+ CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
1292
+ CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
1293
+ } else {
1294
+ CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
1295
+ CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
1296
+ CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
1297
+ }
1298
+
1299
+ at::Tensor q_padded, kcache_padded, vcache_padded;
1300
+ if (head_size_og % 8 != 0) {
1301
+ q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1302
+ kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1303
+ vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1304
+ } else {
1305
+ q_padded = q;
1306
+ kcache_padded = kcache;
1307
+ vcache_padded = vcache;
1308
+ }
1309
+
1310
+ at::Tensor out;
1311
+ if (out_.has_value()) {
1312
+ out = out_.value();
1313
+ TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
1314
+ CHECK_DEVICE(out);
1315
+ TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
1316
+ CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
1317
+ if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
1318
+ } else {
1319
+ out = torch::empty_like(q_padded);
1320
+ }
1321
+
1322
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1323
+ const int head_size = round_multiple(head_size_og, 8);
1324
+ const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
1325
+ const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
1326
+ const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
1327
+
1328
+ auto opts = q.options();
1329
+
1330
+ auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
1331
+
1332
+ Flash_fwd_params params;
1333
+ set_params_fprop(params,
1334
+ batch_size,
1335
+ seqlen_q, seqlen_k,
1336
+ seqlen_q_rounded, seqlen_k_rounded,
1337
+ num_heads, num_heads_k,
1338
+ head_size, head_size_rounded,
1339
+ q_padded, kcache_padded, vcache_padded, out,
1340
+ /*cu_seqlens_q_d=*/nullptr,
1341
+ /*cu_seqlens_k_d=*/nullptr,
1342
+ /*seqused_k=*/nullptr,
1343
+ /*p_ptr=*/nullptr,
1344
+ softmax_lse.data_ptr(),
1345
+ /*p_dropout=*/0.f,
1346
+ softmax_scale,
1347
+ window_size_left,
1348
+ window_size_right,
1349
+ softcap
1350
+ );
1351
+
1352
+ at::Tensor k, v, k_padded, v_padded;
1353
+ if (k_.has_value()) {
1354
+ TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
1355
+ TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
1356
+ TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
1357
+ k = k_.value();
1358
+ v = v_.value();
1359
+ TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
1360
+ TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
1361
+ CHECK_DEVICE(k); CHECK_DEVICE(v);
1362
+ TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
1363
+ TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
1364
+ int seqlen_knew = k.size(1);
1365
+ CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
1366
+ CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
1367
+ if (head_size_og % 8 != 0) {
1368
+ k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1369
+ v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
1370
+ } else {
1371
+ k_padded = k;
1372
+ v_padded = v;
1373
+ }
1374
+ params.seqlen_knew = seqlen_knew;
1375
+ params.knew_ptr = k_padded.data_ptr();
1376
+ params.vnew_ptr = v_padded.data_ptr();
1377
+ // All strides are in elements, not bytes.
1378
+ params.knew_batch_stride = k_padded.stride(0);
1379
+ params.vnew_batch_stride = v_padded.stride(0);
1380
+ params.knew_row_stride = k_padded.stride(-3);
1381
+ params.vnew_row_stride = v_padded.stride(-3);
1382
+ params.knew_head_stride = k_padded.stride(-2);
1383
+ params.vnew_head_stride = v_padded.stride(-2);
1384
+ }
1385
+
1386
+ if (seqlens_k_.has_value()) {
1387
+ auto seqlens_k = seqlens_k_.value();
1388
+ TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
1389
+ CHECK_DEVICE(seqlens_k);
1390
+ CHECK_CONTIGUOUS(seqlens_k);
1391
+ CHECK_SHAPE(seqlens_k, batch_size);
1392
+ params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
1393
+ }
1394
+ params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
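+ // When per-batch seqlens_k is supplied, cu_seqlens_k holds actual lengths rather than cumulative
+ // offsets; BlockInfo (block_info.h) branches on is_seqlens_k_cumulative to pick the right meaning.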
1395
+ if (leftpad_k_.has_value()) {
1396
+ TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
1397
+ auto leftpad_k = leftpad_k_.value();
1398
+ TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
1399
+ CHECK_DEVICE(leftpad_k);
1400
+ CHECK_CONTIGUOUS(leftpad_k);
1401
+ CHECK_SHAPE(leftpad_k, batch_size);
1402
+ params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
1403
+ }
1404
+
1405
+ if (rotary_cos_.has_value()) {
1406
+ TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
1407
+ auto rotary_cos = rotary_cos_.value();
1408
+ CHECK_DEVICE(rotary_cos);
1409
+ params.rotary_dim = rotary_cos.size(1) * 2;
1410
+ TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
1411
+ TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
1412
+ const int seqlen_ro = rotary_cos.size(0);
1413
+ TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
1414
+ CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
1415
+ CHECK_CONTIGUOUS(rotary_cos);
1416
+ TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
1417
+
1418
+ TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
1419
+ auto rotary_sin = rotary_sin_.value();
1420
+ CHECK_DEVICE(rotary_sin);
1421
+ CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
1422
+ CHECK_CONTIGUOUS(rotary_sin);
1423
+ TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
1424
+ params.rotary_cos_ptr = rotary_cos.data_ptr();
1425
+ params.rotary_sin_ptr = rotary_sin.data_ptr();
1426
+ params.is_rotary_interleaved = is_rotary_interleaved;
1427
+ } else {
1428
+ params.rotary_dim = 0;
1429
+ }
1430
+
1431
+ if (cache_batch_idx_.has_value()) {
1432
+ auto cache_batch_idx = cache_batch_idx_.value();
1433
+ CHECK_DEVICE(cache_batch_idx);
1434
+ CHECK_CONTIGUOUS(cache_batch_idx);
1435
+ TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
1436
+ params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
1437
+ }
1438
+
1439
+ // Keep references to these tensors to extend their lifetime
1440
+ at::Tensor softmax_lse_accum, out_accum;
1441
+ std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
1442
+ params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
1443
+ head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts);
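+ // num_splits > 1 forces that many KV splits; num_splits == 0 is expected to let the heuristic
+ // inside set_params_splitkv pick a split count from the SM count and problem size.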
1444
+
1445
+ if (paged_KV) {
1446
+ params.block_table = block_table.data_ptr<int>();
1447
+ params.block_table_batch_stride = block_table.stride(0);
1448
+ }
1449
+ params.page_block_size = page_block_size;
1450
+
1451
+
1452
+ set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
1453
+
1454
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
1455
+ // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
1456
+ // or paged KV cache
1457
+ run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
1458
+
1459
+ if (head_size_og % 8 != 0) {
1460
+ out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
1461
+ if (out_.has_value()) { out_.value().copy_(out); }
1462
+ if (k_.has_value()) {
1463
+ // It's expensive to copy the KV cache here for the case where the head size is not divisible by 8,
1464
+ // but we don't expect to get this case in practice. This is just so that the code works for that case.
1465
+ kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
1466
+ vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
1467
+ }
1468
+ }
1469
+
1470
+ if (seqlenq_ngroups_swapped) {
1471
+ out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
1472
+ softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
1473
+ }
1474
+ return {out, softmax_lse};
1475
+ }
1476
+ } // namespace FLASH_NAMESPACE
1477
+
1478
+ // NOTE: wrap the namespaced functions so all types are doubles and longs
1479
+ std::vector<at::Tensor>
1480
+ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
1481
+ const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
1482
+ const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
1483
+ const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
1484
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1485
+ const double p_dropout,
1486
+ const double softmax_scale,
1487
+ bool is_causal,
1488
+ const int64_t window_size_left,
1489
+ const int64_t window_size_right,
1490
+ const double softcap,
1491
+ const bool return_softmax,
1492
+ const c10::optional<at::Generator> gen_) {
1493
+ // return FLASH_NAMESPACE::mha_fwd(q, k, v, out_, alibi_slopes_, p_dropout, softmax_scale, is_causal, window_size_left, window_size_right, softcap, return_softmax, gen_);
1494
+ // return dummy value for now
1495
+ return {};
1496
+ };
flash_attn/src/alibi.h ADDED
@@ -0,0 +1,75 @@
1
+ #include <cmath>
2
+
3
+ #include "namespace_config.h"
4
+ #include <cute/tensor.hpp>
5
+
6
+ #include <cutlass/cutlass.h>
7
+ #include <cutlass/array.h>
8
+
9
+ #include "utils.h"
10
+
11
+ namespace FLASH_NAMESPACE {
12
+
13
+ using namespace cute;
14
+
15
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
16
+
17
+ template <bool Is_causal>
18
+ struct Alibi {
19
+
20
+ const float alibi_slope;
21
+ const int max_seqlen_k, max_seqlen_q;
22
+
23
+ __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
24
+ : alibi_slope(alibi_slope)
25
+ , max_seqlen_k(max_seqlen_k)
26
+ , max_seqlen_q(max_seqlen_q) {
27
+ };
28
+
29
+
30
+ template <typename Engine, typename Layout>
31
+ __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
32
+ const int col_idx_offset_,
33
+ const int row_idx_offset,
34
+ const int warp_row_stride) {
35
+ // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
36
+ static_assert(Layout::rank == 2, "Only support 2D Tensor");
37
+ const int lane_id = threadIdx.x % 32;
38
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
39
+ if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
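+ // Adding slope * col_idx differs from the usual -slope * (row_idx - col_idx) causal bias only by
+ // a per-row constant, and softmax is invariant to per-row shifts, so the probabilities match.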
40
+ #pragma unroll
41
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
42
+ const int col_idx_base = col_idx_offset + nj * 8;
43
+ #pragma unroll
44
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
45
+ const int col_idx = col_idx_base + j;
46
+ #pragma unroll
47
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
48
+ tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
49
+ }
50
+ }
51
+ }
52
+ } else { // Bias depends on both row_idx and col_idx
53
+ #pragma unroll
54
+ for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
55
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
56
+ #pragma unroll
57
+ for (int i = 0; i < size<0, 0>(tensor); ++i) {
58
+ const int row_idx = row_idx_base + i * 8;
59
+ #pragma unroll
60
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
61
+ const int col_idx_base = col_idx_offset + nj * 8;
62
+ #pragma unroll
63
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
64
+ const int col_idx = col_idx_base + j;
65
+ tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
66
+ }
67
+ }
68
+ }
69
+ }
70
+ }
71
+ }
72
+
73
+ };
74
+
75
+ } // namespace FLASH_NAMESPACE
flash_attn/src/block_info.h ADDED
@@ -0,0 +1,49 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "namespace_config.h"
8
+ namespace FLASH_NAMESPACE {
9
+
10
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
11
+
12
+ template<bool Varlen=true>
13
+ struct BlockInfo {
14
+
15
+ template<typename Params>
16
+ __device__ BlockInfo(const Params &params, const int bidb)
17
+ : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
18
+ , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
19
+ , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
20
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
21
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
22
+ , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])
23
+ , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k)
24
+ , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
25
+ {
26
+ }
27
+
28
+ template <typename index_t>
29
+ __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
30
+ return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
31
+ }
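+ // Example (hypothetical values): in the varlen path with cu_seqlens_q = [0, 5, 12], batch index 1
+ // starts reading Q at offset 5 * row_stride; in the fixed-length path the offset is bidb * batch_stride.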
32
+
33
+ template <typename index_t>
34
+ __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
35
+ return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride;
36
+ }
37
+
38
+ const int sum_s_q;
39
+ const int sum_s_k;
40
+ const int actual_seqlen_q;
41
+ // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
42
+ const int leftpad_k;
43
+ const int seqlen_k_cache;
44
+ const int actual_seqlen_k;
45
+ };
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ } // namespace FLASH_NAMESPACE
flash_attn/src/dropout.h ADDED
@@ -0,0 +1,95 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "namespace_config.h"
8
+ #include "philox.cuh"
9
+ #include "utils.h"
10
+
11
+ namespace FLASH_NAMESPACE {
12
+
13
+ struct Dropout {
14
+
15
+ const unsigned long long seed, offset;
16
+ const uint8_t p_dropout_in_uint8_t;
17
+
18
+ __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
19
+ const uint8_t p_dropout_in_uint8_t,
20
+ const int bid, const int hid, const int tid, const int nheads)
21
+ : seed(seed)
22
+ , offset(offset + (bid * nheads + hid) * 32 + tid % 32)
23
+ , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
24
+ }
25
+
26
+ template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
27
+ __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
28
+ int block_row_start, int block_col_start, int block_row_stride) {
29
+ // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
30
+ Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));
31
+ using T = typename Engine::value_type;
32
+ auto encode_dropout = [](bool keep, T val) {
33
+ return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
34
+ };
35
+ static_assert(decltype(size<2>(tensor))::value % 2 == 0);
36
+ const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
37
+ const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
38
+ // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
39
+ #pragma unroll
40
+ for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
41
+ uint2 rowcol = make_uint2(block_row_start, block_col_start);
42
+ #pragma unroll
43
+ for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
44
+ // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
45
+ uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
46
+ // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
47
+ uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
48
+ // Special implementation for 16-bit types: we duplicate the threshold to the
49
+ // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
50
+ // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
51
+ // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
52
+ // the random value is less than the threshold.
53
+ // We then do a bit-wise AND between the mask and the original value (in 32-bit).
54
+ // We're exploiting the fact that floating point comparison is equivalent to integer
55
+ // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
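+ // Concretely: each 32-bit lane packs two dropout bytes zero-extended to 16 bits, set.le.u32.f16x2
+ // compares both halves against the duplicated threshold at once, and the resulting 0xffff / 0x0000
+ // half-masks zero out dropped fp16/bf16 elements via the bitwise AND below.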
56
+ if (!encode_dropout_in_sign_bit
57
+ && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
58
+ uint16_t rnd_16[16];
59
+ #pragma unroll
60
+ for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
61
+ uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
62
+ #pragma unroll
63
+ for (int j = 0; j < 2; j++) {
64
+ Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
65
+ // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
66
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
67
+ #pragma unroll
68
+ for (int i = 0; i < 4; i++) {
69
+ uint32_t mask;
70
+ asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
71
+ tensor_uint32(i) &= mask;
72
+ }
73
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
74
+ }
75
+ } else {
76
+ #pragma unroll
77
+ for (int j = 0; j < 2; j++) {
78
+ #pragma unroll
79
+ for (int i = 0; i < 8; i++) {
80
+ tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
81
+ }
82
+ Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
83
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
84
+ }
85
+ }
86
+ // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
87
+ // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
88
+ // // }
89
+ }
90
+ }
91
+ }
92
+
93
+ };
94
+
95
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash.h ADDED
@@ -0,0 +1,194 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "namespace_config.h"
8
+
9
+ #include <cuda.h>
10
+ #include <vector>
11
+
12
+ #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
13
+
14
+ namespace FLASH_NAMESPACE {
15
+ constexpr int TOTAL_DIM = 0;
16
+ constexpr int H_DIM = 1;
17
+ constexpr int D_DIM = 2;
18
+
19
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
20
+
21
+ struct Qkv_params {
22
+ using index_t = int64_t;
23
+ // The QKV matrices.
24
+ void *__restrict__ q_ptr;
25
+ void *__restrict__ k_ptr;
26
+ void *__restrict__ v_ptr;
27
+
28
+ // The stride between rows of the Q, K and V matrices.
29
+ index_t q_batch_stride;
30
+ index_t k_batch_stride;
31
+ index_t v_batch_stride;
32
+ index_t q_row_stride;
33
+ index_t k_row_stride;
34
+ index_t v_row_stride;
35
+ index_t q_head_stride;
36
+ index_t k_head_stride;
37
+ index_t v_head_stride;
38
+
39
+ // The number of heads.
40
+ int h, h_k;
41
+ // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
42
+ // different from nheads (query).
43
+ int h_h_k_ratio; // precompute h / h_k,
44
+ };
45
+
46
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ struct Flash_fwd_params : public Qkv_params {
49
+
50
+ // The O matrix (output).
51
+ void * __restrict__ o_ptr;
52
+ void * __restrict__ oaccum_ptr;
53
+
54
+ // The stride between rows of O.
55
+ index_t o_batch_stride;
56
+ index_t o_row_stride;
57
+ index_t o_head_stride;
58
+
59
+ // The pointer to the P matrix.
60
+ void * __restrict__ p_ptr;
61
+
62
+ // The pointer to the softmax sum.
63
+ void * __restrict__ softmax_lse_ptr;
64
+ void * __restrict__ softmax_lseaccum_ptr;
65
+
66
+ // The dimensions.
67
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
68
+
69
+ // The scaling factors for the kernel.
70
+ float scale_softmax;
71
+ float scale_softmax_log2;
72
+
73
+ // array of length b+1 holding starting offset of each sequence.
74
+ int * __restrict__ cu_seqlens_q;
75
+ int * __restrict__ cu_seqlens_k;
76
+ int * __restrict__ leftpad_k;
77
+
78
+ // If provided, the actual length of each k sequence.
79
+ int * __restrict__ seqused_k;
80
+
81
+ int *__restrict__ blockmask;
82
+
83
+ // The K_new and V_new matrices.
84
+ void * __restrict__ knew_ptr;
85
+ void * __restrict__ vnew_ptr;
86
+
87
+ // The stride between rows of the Q, K and V matrices.
88
+ index_t knew_batch_stride;
89
+ index_t vnew_batch_stride;
90
+ index_t knew_row_stride;
91
+ index_t vnew_row_stride;
92
+ index_t knew_head_stride;
93
+ index_t vnew_head_stride;
94
+
95
+ // The cos and sin matrices for rotary embedding.
96
+ void * __restrict__ rotary_cos_ptr;
97
+ void * __restrict__ rotary_sin_ptr;
98
+
99
+ // The indices to index into the KV cache.
100
+ int * __restrict__ cache_batch_idx;
101
+
102
+ // Paged KV cache
103
+ int * __restrict__ block_table;
104
+ index_t block_table_batch_stride;
105
+ int page_block_size;
106
+
107
+ // The dropout probability (probability of keeping an activation).
108
+ float p_dropout;
109
+ // uint32_t p_dropout_in_uint;
110
+ // uint16_t p_dropout_in_uint16_t;
111
+ uint8_t p_dropout_in_uint8_t;
112
+
113
+ // Scale factor of 1 / (1 - p_dropout).
114
+ float rp_dropout;
115
+ float scale_softmax_rp_dropout;
116
+
117
+ // Local window size
118
+ int window_size_left, window_size_right;
119
+ float softcap;
120
+
121
+ // Random state.
122
+ at::PhiloxCudaState philox_args;
123
+
124
+ // Pointer to the RNG seed (idx 0) and offset (idx 1).
125
+ uint64_t * rng_state;
126
+
127
+ bool is_bf16;
128
+ bool is_causal;
129
+
130
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
131
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
132
+ bool is_seqlens_k_cumulative;
133
+
134
+ bool is_rotary_interleaved;
135
+
136
+ int num_splits; // For split-KV version
137
+
138
+ void * __restrict__ alibi_slopes_ptr;
139
+ index_t alibi_slopes_batch_stride;
140
+
141
+ bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
142
+ bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
143
+ };
144
+
145
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
146
+
147
+ struct Flash_bwd_params : public Flash_fwd_params {
148
+
149
+ // The dO and dQKV matrices.
150
+ void *__restrict__ do_ptr;
151
+ void *__restrict__ dq_ptr;
152
+ void *__restrict__ dk_ptr;
153
+ void *__restrict__ dv_ptr;
154
+
155
+ // To accumulate dQ
156
+ void *__restrict__ dq_accum_ptr;
157
+ void *__restrict__ dk_accum_ptr;
158
+ void *__restrict__ dv_accum_ptr;
159
+
160
+ // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
161
+ // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
162
+ // dv_accum_ptr;
163
+
164
+ // The stride between rows of the dO, dQ, dK and dV matrices.
165
+ // TD [2022-04-16]: We're using 32-bit indexing to save registers.
166
+ // The code probably won't work for arrays larger than 2GB.
167
+ index_t do_batch_stride;
168
+ index_t do_row_stride;
169
+ index_t do_head_stride;
170
+ index_t dq_batch_stride;
171
+ index_t dk_batch_stride;
172
+ index_t dv_batch_stride;
173
+ index_t dq_row_stride;
174
+ index_t dk_row_stride;
175
+ index_t dv_row_stride;
176
+ index_t dq_head_stride;
177
+ index_t dk_head_stride;
178
+ index_t dv_head_stride;
179
+
180
+ // The pointer to the softmax d sum.
181
+ void *__restrict__ dsoftmax_sum;
182
+
183
+ bool deterministic;
184
+ index_t dq_accum_split_stride;
185
+ };
186
+
187
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
188
+
189
+ template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
190
+ template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
191
+
192
+ template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
193
+
194
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim128<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim192<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim256<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim32<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim32<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim64<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim64<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim96<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_bwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_bwd_<cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
+ run_mha_bwd_hdim96<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
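Note on the files above: each auto-generated flash_bwd_hdim*_sm80.cu translation unit pins one explicit specialization of run_mha_bwd_<dtype, head_dim, is_causal> to the matching run_mha_bwd_hdim* launcher, so every (dtype, head dimension, causal) combination compiles in its own file and the build parallelizes. As a reader aid only, the following is a minimal, hypothetical sketch of how a caller might route a runtime head dimension onto these specializations; it is not code from this commit (the commit's actual dispatch is not shown here), and only run_mha_bwd_ and Flash_bwd_params are names taken from this diff.

// Hypothetical dispatch sketch (not part of this commit): select one of the
// explicitly specialized run_mha_bwd_ instantiations for a runtime head dim.
template <typename elem_type, bool Is_causal>
void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream, const int head_dim) {
    if (head_dim <= 32)       { run_mha_bwd_<elem_type,  32, Is_causal>(params, stream); }
    else if (head_dim <= 64)  { run_mha_bwd_<elem_type,  64, Is_causal>(params, stream); }
    else if (head_dim <= 96)  { run_mha_bwd_<elem_type,  96, Is_causal>(params, stream); }
    else if (head_dim <= 128) { run_mha_bwd_<elem_type, 128, Is_causal>(params, stream); }
    else if (head_dim <= 160) { run_mha_bwd_<elem_type, 160, Is_causal>(params, stream); }
    else if (head_dim <= 192) { run_mha_bwd_<elem_type, 192, Is_causal>(params, stream); }
    else                      { run_mha_bwd_<elem_type, 256, Is_causal>(params, stream); }
}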
flash_attn/src/flash_bwd_kernel.h ADDED
@@ -0,0 +1,839 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "namespace_config.h"
8
+ #include <cute/tensor.hpp>
9
+
10
+ #include <cutlass/cutlass.h>
11
+ #include <cutlass/array.h>
12
+ #include <cutlass/numeric_types.h>
13
+
14
+ #include "block_info.h"
15
+ #include "kernel_traits.h"
16
+ #include "utils.h"
17
+ #include "softmax.h"
18
+ #include "mask.h"
19
+ #include "dropout.h"
20
+
21
+ #include "alibi.h"
22
+
23
+ namespace FLASH_NAMESPACE {
24
+
25
+ using namespace cute;
26
+
27
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
28
+
29
+ template <int MMA_N,
30
+ class... Args,
31
+ class TiledMMA>
32
+ CUTE_HOST_DEVICE
33
+ auto
34
+ make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
35
+ TiledMMA const& tiled_mma) {
36
+ constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
37
+ constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;
38
+ using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
39
+ constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
40
+ // Divide by 2 because right now we always use 2 for the ValLayout
41
+ constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
42
+ constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
43
+ // This gives the correct layout, idk why.
44
+ // auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
45
+ // Stride<Stride<_1, _64>, _8> >{},
46
+ // auto t = make_tile(Layout<Shape<_8, _2, _2>,
47
+ // Stride<_1, _64, _8> >{},
48
+ auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
49
+ Stride<_1, Int<MMAStride_N>, _8> >{}, // (1, 64, 8) or (1, 32, 8)
50
+ make_layout(Int<TileShape_K>{}));
51
+ // if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
52
+ return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);
53
+ }
54
+
55
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
56
+
57
+ template <int MMA_N,
58
+ class... Args,
59
+ class TiledMMA>
60
+ CUTE_HOST_DEVICE
61
+ auto
62
+ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
63
+ TiledMMA const& tiled_mma) {
64
+ constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;
65
+ constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
66
+ using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
67
+ constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
68
+ // Divide by 2 because right now we always use 2 for the ValLayout
69
+ constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
70
+ constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
71
+ auto t = make_tile(make_layout(Int<TileShape_M>{}),
72
+ Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
73
+ Stride<_1, Int<MMAStride_N>, _8> >{}); // (1, 64, 8) or (1, 32, 8)
74
+ // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
75
+ return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
76
+ }
77
+
78
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
79
+
80
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
81
+ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
82
+
83
+ using Element = typename Kernel_traits::Element;
84
+ using ElementAccum = typename Kernel_traits::ElementAccum;
85
+ using index_t = typename Kernel_traits::index_t;
86
+
87
+ // Shared memory.
88
+ extern __shared__ char smem_[];
89
+
90
+ // The thread index.
91
+ const int tidx = threadIdx.x;
92
+
93
+ constexpr int kBlockM = Kernel_traits::kBlockM;
94
+ constexpr int kBlockN = Kernel_traits::kBlockN;
95
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
96
+ constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value;
97
+ constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
98
+ constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
99
+
100
+ const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
101
+ if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
102
+
103
+ int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
104
+ if (Is_local) {
105
+ m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
106
+ }
107
+
108
+ const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
109
+ + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
110
+ const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
111
+ + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
112
+ const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
113
+ + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
114
+ const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
115
+ + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
116
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
117
+ + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
118
+ const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
119
+ + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
120
+ const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
121
+ + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
122
+ // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
123
+ + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
124
+ const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM;
125
+ // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
126
+ const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;
127
+
128
+ Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
129
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
130
+ make_stride(params.q_row_stride, _1{}));
131
+ Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
132
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
133
+ make_stride(params.k_row_stride, _1{}));
134
+ Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
135
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
136
+ make_stride(params.v_row_stride, _1{}));
137
+ Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
138
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
139
+ make_stride(params.do_row_stride, _1{}));
140
+ Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
141
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
142
+ make_stride(params.o_row_stride, _1{}));
143
+ Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
144
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
145
+ make_stride(params.dq_row_stride, _1{}));
146
+ Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
147
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
148
+ make_stride(params.h * params.d_rounded, _1{}));
149
+ Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
150
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
151
+ Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
152
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
153
+
154
+ Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
155
+ typename Kernel_traits::SmemLayoutQdO{});
156
+ Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
157
+ Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
158
+ // Double buffer for sQ
159
+ Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{});
160
+ Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
161
+ Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(),
162
+ typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
163
+ Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});
164
+ Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
165
+ Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{});
166
+ Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{});
167
+ Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
168
+ typename Kernel_traits::SmemLayoutPdS{});
169
+ Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
170
+ Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
171
+ Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{});
172
+ Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
173
+ Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
174
+ // sP and sdQ share the same memory so be careful
175
+ Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
176
+
177
+ typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
178
+ auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
179
+ using GmemTiledCopydO = std::conditional_t<
180
+ Is_first,
181
+ typename Kernel_traits::GmemTiledCopydO,
182
+ typename Kernel_traits::GmemTiledCopyQKV
183
+ >;
184
+ GmemTiledCopydO gmem_tiled_copy_dO;
185
+ auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
186
+ typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
187
+ auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
188
+ using GmemLayoutAtomdQaccum = std::conditional_t<
189
+ !Seq_parallel,
190
+ typename Kernel_traits::GmemTiledCopydQaccum,
191
+ typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd
192
+ >;
193
+ GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
194
+ auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
195
+
196
+ Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
197
+ Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
198
+ Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
199
+ Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);
200
+ Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
201
+ Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
202
+ Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
203
+ Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
204
+ Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
205
+ Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
206
+ Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
207
+ Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
208
+ // if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); }
209
+ // __syncthreads();
210
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) {
211
+ // printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data());
212
+ // }
213
+
214
+ typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
215
+ auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
216
+ Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K)
217
+ Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
218
+ Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K)
219
+ Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K)
220
+
221
+ typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
222
+ auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
223
+ Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N)
224
+ Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N)
225
+ Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N)
226
+ Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N)
227
+
228
+ typename Kernel_traits::TiledMmadQ tiled_mma_dq;
229
+ auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
230
+ Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N)
231
+ Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N)
232
+
233
+ Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
234
+ Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
235
+
236
+ //
237
+ // Copy Atom retiling
238
+ //
239
+
240
+ auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
241
+ auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
242
+ Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
243
+ Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
244
+
245
+ // auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
246
+ auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
247
+ auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
248
+ Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
249
+ // if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); }
250
+ // if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
251
+ Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);
252
+
253
+ // Partition sP and sdS to match the accumulator partitioning
254
+ // This has to be tiled_mma_sdp, not tiled_mma_dkv
255
+ // auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
256
+ auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
257
+ auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
258
+ Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
259
+ // if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); }
260
+ // if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); }
261
+ // if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) {
262
+ // printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data());
263
+ // }
264
+ Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
265
+
266
+ auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
267
+ auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
268
+ Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
269
+ Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
270
+
271
+ auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
272
+ auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
273
+ Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
274
+ Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);
275
+
276
+ auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);
277
+ auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);
278
+ Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);
279
+
280
+ auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);
281
+ auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
282
+ Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
283
+
284
+ auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
285
+ auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
286
+ Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
287
+
288
+ //
289
+ // PREDICATES
290
+ //
291
+
292
+ Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
293
+ Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
294
+ Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ);
295
+ Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV);
296
+
297
+ // Allocate predicate tensors for k
298
+ Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
299
+ Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
300
+
301
+ // Set predicates for k bounds
302
+ if (!Is_even_K) {
303
+ #pragma unroll
304
+ for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
305
+ #pragma unroll
306
+ for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
307
+ }
308
+
309
+ // Prologue
310
+
311
+ // We'll advance gdQ and gdQaccum before the 1st read/write.
312
+ tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
313
+ tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
314
+
315
+ int m_block = m_block_max - 1;
316
+ int m_block_min = (!Is_causal && !Is_local)
317
+ ? 0
318
+ : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
319
+ // If not local, we're guaranteed that m_block_min <= m_block:
320
+ // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
321
+ // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
322
+ // So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
323
+ // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
324
+ // So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
325
+ // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
326
+ // However, if local, then it is possible for some blocks of K & V to not attend to any query.
327
+ // We might need to exit early and write 0 to dK and dV for those blocks.
328
+ // Otherwise we get the wrong result for the case where we don't enter the for loop.
329
+ // And we might read OOB elements from gQ and gdO.
330
+ // This also covers the case where actual_seqlen_q == 0
331
+ if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
332
+ const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
333
+ + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
334
+ const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
335
+ + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
336
+ Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
337
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
338
+ make_stride(params.dk_row_stride, _1{}));
339
+ Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
340
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
341
+ make_stride(params.dv_row_stride, _1{}));
342
+ typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
343
+ auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
344
+ Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
345
+ Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
346
+ Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
347
+ Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
348
+ clear(tdKrdK);
349
+ clear(tdVrdV);
350
+ Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
351
+ Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
352
+ Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
353
+ #pragma unroll
354
+ for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
355
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
356
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
357
+ gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
358
+ );
359
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
360
+ gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
361
+ );
362
+ return;
363
+ }
364
+
365
+ if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
366
+ tQsQ.data() = tQsQ.data() + size(sQ);
367
+ tSsQ.data() = tSsQ.data() + size(sQ);
368
+ tdKsQt.data() = tdKsQt.data() + size(sQ);
369
+ }
370
+
371
+ if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }
372
+
373
+ if (Kernel_traits::Is_V_in_regs) {
374
+ // Clear the smem tiles to account for predicated off loads
375
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
376
+ gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
377
+ );
378
+ FLASH_NAMESPACE::cp_async_fence();
379
+ }
380
+
381
+ Tensor tdOrdO = make_fragment_like(tdOgdO);
382
+ Tensor tdOrO = make_fragment_like(tdOgO);
383
+ if (!Is_first) {
384
+ // Clear the smem tiles to account for predicated off loads
385
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
386
+ gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
387
+ );
388
+ } else {
389
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
390
+ gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
391
+ );
392
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
393
+ gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
394
+ );
395
+ }
396
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
397
+ gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
398
+ );
399
+
400
+ Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
401
+ Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N)
402
+ static_assert(decltype(size<0>(taccScS))::value == 4);
403
+ // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices.
404
+ Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
405
+ Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
406
+ #pragma unroll
407
+ for (int mi = 0; mi < size(lse); ++mi) {
408
+ const int row = get<0>(taccScS_row(mi));
409
+ lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
410
+ }
411
+ // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
412
+ // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
413
+ // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
414
+ // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.
415
+
416
+ // Tensor tKrK = make_fragment_like(tKsK);
417
+ // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
418
+ // cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
419
+ // // if (cute::thread(1, 0)) { print(tKrK); }
420
+
421
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
422
+ gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
423
+ );
424
+ if (!Kernel_traits::Is_V_in_regs) {
425
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
426
+ gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
427
+ );
428
+ }
429
+ FLASH_NAMESPACE::cp_async_fence();
430
+
431
+ // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
432
+ if (Is_first) {
433
+ cute::copy(tdOrdO, tdOsdO);
434
+ dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
435
+ Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
436
+ }
437
+
438
+ if (Kernel_traits::Is_V_in_regs) {
439
+ cute::cp_async_wait<1>();
440
+ __syncthreads();
441
+ Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);
442
+ CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M
443
+ cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
444
+ }
445
+
446
+ FLASH_NAMESPACE::Dropout dropout(params.rng_state[0], params.rng_state[1], params.p_dropout_in_uint8_t,
447
+ bidb, bidh, tidx, params.h);
448
+
449
+ clear(acc_dv);
450
+ clear(acc_dk);
451
+
452
+ const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
453
+ FLASH_NAMESPACE::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
454
+
455
+ for (; m_block >= m_block_min; --m_block) {
456
+ Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
457
+ clear(acc_s);
458
+ cute::cp_async_wait<0>();
459
+ __syncthreads();
460
+
461
+ Tensor dP_sum = make_fragment_like(lse);
462
+ #pragma unroll
463
+ for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
464
+
465
+ // if (cute::thread0()) { print(sK); }
466
+ // Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK);
467
+ // #pragma unroll
468
+ // for (int k = 0; k < size<2>(tSrK_copy_view); ++k) {
469
+ // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
470
+ // }
471
+ // if (cute::thread0()) { print(tSrK); }
472
+ FLASH_NAMESPACE::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
473
+ smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
474
+
475
+ if constexpr (Is_softcap) {
476
+ FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap);
477
+ }
478
+
479
+ // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
480
+ Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout()));
481
+ // if (cute::thread(32, 0)) { print(scores); }
482
+
483
+ // Softcapping - calculating dTanh and scaling dS later with it
484
+ [[maybe_unused]] Tensor dtanh = make_tensor_like(scores);
485
+ if constexpr (Is_softcap) {
486
+ FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
487
+ }
488
+
489
+ // Alibi
490
+ if (Has_alibi) {
491
+ alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
492
+ m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
493
+ }
494
+
495
+ // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
496
+ // actual_seqlen_k, because acc_s would be some finite value for those indices.
497
+ // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
498
+ // so the result would still be correct.
499
+ // However, it's possible that the values in acc_s are so large that they overflow
500
+ // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
501
+ // So we need to mask out the elements beyond actual_seqlen_k.
502
+ if (!Is_causal && !Is_local) {
503
+ if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
504
+ FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k,
505
+ n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
506
+ }
507
+ } else if (Is_causal) {
508
+ // Putting this causal masking right after acc_s is *much* slower for some reason.
509
+ // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
510
+ // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
511
+ // But we still want to mask out elements beyond actual_seqlen_k.
512
+ if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
513
+ || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
514
+ FLASH_NAMESPACE::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
515
+ binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
516
+ binfo.actual_seqlen_q,
517
+ // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
518
+ AtomLayoutMS * 16);
519
+ }
520
+ } else if (Is_local) {
521
+ if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
522
+ || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
523
+ || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
524
+ FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
525
+ binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
526
+ binfo.actual_seqlen_q, AtomLayoutMS * 16,
527
+ params.window_size_left, params.window_size_right);
528
+ }
529
+
530
+ }
531
+
532
+ // if (cute::thread(32, 0)) { print(scores); }
533
+ // Compute the exponential value.
534
+ FLASH_NAMESPACE::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
535
+ if constexpr (Is_dropout) {
536
+ int warp_id = tidx / 32;
537
+ int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
538
+ // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
539
+ static_assert(MMA_N_SdP % 2 == 0);
540
+ int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
541
+ dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
542
+ acc_s, block_row_idx, block_col_idx, AtomLayoutMS
543
+ );
544
+ }
545
+ // Convert scores from fp32 to fp16/bf16
546
+ Tensor rP = !Is_dropout
547
+ ? FLASH_NAMESPACE::convert_type<Element>(acc_s)
548
+ : FLASH_NAMESPACE::convert_type_relu<Element>(acc_s);
549
+ // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
550
+ // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
551
+ Tensor tPrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
552
+ Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N)
553
+ cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
554
+ // if (cute::thread0()) { print(tPaP); }
555
+ // __syncthreads();
556
+ // if (cute::thread0()) { print(sP); }
557
+
558
+ Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
559
+ CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA
560
+ CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA
561
+ CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA
562
+
563
+ clear(acc_dp);
564
+ // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dp.layout()));
565
+ // #pragma unroll
566
+ // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) {
567
+ // #pragma unroll
568
+ // for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) {
569
+ // acc_dp_reshaped(mi, ni) = -dP_sum(mi);
570
+ // }
571
+ // }
572
+
573
+ // if (cute::thread0()) { print(dP_sum); }
574
+
575
+ FLASH_NAMESPACE::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
576
+ acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
577
+ smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
578
+ );
579
+
580
+ // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
581
+ Tensor dS = make_tensor(acc_dp.data(), scores.layout());
582
+ auto pointwise_mult = [](float p, float dp, float d) {
583
+ return p * (!Is_dropout || p >= 0 ? dp - d : d);
584
+ };
585
+ #pragma unroll
586
+ for (int mi = 0; mi < size<0>(dS); ++mi) {
587
+ #pragma unroll
588
+ for (int ni = 0; ni < size<1>(dS); ++ni) {
589
+ float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
590
+ if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
591
+ dS(mi, ni) = scaled_ds;
592
+ }
593
+ }
594
+ // if (cute::thread0()) { print(dS); }
595
+
596
+ Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
597
+ tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded));
598
+ if (Is_first || Seq_parallel) {
599
+ clear(acc_dq);
600
+ } else {
601
+ // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum
602
+ Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),
603
+ make_layout(get<0>(acc_dq.layout()),
604
+ get<2>(acc_dq.layout()),
605
+ get<1>(acc_dq.layout())));
606
+ cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped);
607
+ }
608
+
609
+ if (Double_buffer && m_block > m_block_min) {
610
+ // Double buffer for sQ
611
+ const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
612
+ tQsQ.data() = tQsQ.data() + sQ_offset;
613
+ tSsQ.data() = tSsQ.data() + sQ_offset;
614
+ // Advance gQ
615
+ tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
616
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
617
+ FLASH_NAMESPACE::cp_async_fence();
618
+ }
619
+
620
+ Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
621
+ // Convert dS from fp32 to fp16
622
+ Tensor tdSrdS = FLASH_NAMESPACE::convert_type<Element>(dS_reshaped);
623
+ // if (cute::thread0()) { print(tPrP); }
624
+ Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N)
625
+ cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
626
+ __syncthreads();
627
+
628
+ // Layout p_l = tPrP.layout();
629
+ // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
630
+ // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
631
+ // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
632
+ // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
633
+ FLASH_NAMESPACE::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
634
+ smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
635
+ // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
636
+ // if (cute::thread0()) { print(acc_dv); }
637
+
638
+ __syncthreads(); // Need syncthreads since we're writing to the same sdO location
639
+
640
+ if (m_block > m_block_min) {
641
+ // Advance gdO
642
+ tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
643
+ if (Is_first) {
644
+ tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
645
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
646
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
647
+ } else {
648
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
649
+ FLASH_NAMESPACE::cp_async_fence();
650
+ }
651
+ }
652
+
653
+ FLASH_NAMESPACE::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
654
+ smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
655
+ // if (cute::thread0()) { print(acc_dq); }
656
+
657
+ if (m_block > m_block_min) {
658
+ gLSE.data() = gLSE.data() + (-int(kBlockM));
659
+ #pragma unroll
660
+ for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
661
+ gdPsum.data() = gdPsum.data() + (-int(kBlockM));
662
+ }
663
+
664
+ if (!Is_last) {
665
+ // Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum
666
+ Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),
667
+ make_layout(get<0>(acc_dq.layout()),
668
+ get<2>(acc_dq.layout()),
669
+ get<1>(acc_dq.layout())));
670
+ if (!Seq_parallel) {
671
+ cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum);
672
+ } else {
673
+ // if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); }
674
+ CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
675
+ #pragma unroll
676
+ for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
677
+ }
678
+ } else {
679
+ #pragma unroll
680
+ for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
681
+ // Convert acc_dq from fp32 to fp16
682
+ Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
683
+ Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
684
+ cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
685
+ }
686
+
687
+ FLASH_NAMESPACE::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
688
+ smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
689
+ // if (cute::thread0()) { print(acc_dk); }
690
+ if (Double_buffer) { // Double buffer for sQ
691
+ tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
692
+ }
693
+ if (!Double_buffer && m_block > m_block_min) {
694
+ __syncthreads();
695
+ // Advance gQ
696
+ tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
697
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
698
+ FLASH_NAMESPACE::cp_async_fence();
699
+ }
700
+
701
+ if (Is_first && m_block > m_block_min) {
702
+ cute::copy(tdOrdO, tdOsdO);
703
+ dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
704
+ Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
705
+ }
706
+
707
+ if (Is_last) {
708
+ __syncthreads();
709
+ Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
710
+ cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
711
+ tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
712
+ Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
713
+ Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
714
+ #pragma unroll
715
+ for (int m = 0; m < size<1>(tdQgdQ); ++m) {
716
+ if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
717
+ cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
718
+ }
719
+ }
720
+ }
721
+
722
+ }
723
+
724
+ // Epilogue
725
+
726
+ if (Is_dropout) {
727
+ #pragma unroll
728
+ for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }
729
+ }
730
+ #pragma unroll
731
+ for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }
732
+
733
+ // Convert acc_dv from fp32 to fp16
734
+ Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
735
+ Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);
736
+
737
+ Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
738
+ Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
739
+
740
+ // Partition sdV and sdK to match the accumulator partitioning
741
+ auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
742
+ auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
743
+ Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
744
+ Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
745
+ Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
746
+ Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
747
+
748
+ // We need syncthreads here since we're writing to the same location as sK and sV.
749
+ // Without syncthreads, some thread might modify the location of sK while another thread
750
+ // is reading it for dQ gemm, leading to a race condition.
751
+ // If Is_last, there's already a __syncthreads() at the end of the loop.
752
+ if (!Is_last) { __syncthreads(); }
753
+
754
+ cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
755
+ cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
756
+
757
+ const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
758
+ + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
759
+ const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
760
+ + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
761
+ Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
762
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
763
+ make_stride(params.dk_row_stride, _1{}));
764
+ Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
765
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
766
+ make_stride(params.dv_row_stride, _1{}));
767
+
768
+ typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
769
+ auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
770
+ Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
771
+ Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
772
+ Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
773
+ Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
774
+
775
+ __syncthreads();
776
+ Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
777
+ cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
778
+ Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
779
+ cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
780
+ Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
781
+ Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
782
+ Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
783
+ #pragma unroll
784
+ for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
785
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
786
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
787
+ gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
788
+ );
789
+ FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
790
+ gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
791
+ );
792
+
793
+ }
794
+
795
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
796
+
797
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
798
+ inline __device__ void compute_dq_dk_dv(const Params &params) {
799
+
800
+ // The block index for the batch.
801
+ const int bidb = blockIdx.x;
802
+ // const int bidb = blockIdx.y;
803
+ // The block index for the head.
804
+ const int bidh = blockIdx.y;
805
+ // const int bidh = blockIdx.z;
806
+ // The thread index.
807
+ const int tidx = threadIdx.x;
808
+
809
+ const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
810
+ if (n_block_max == 1) {
811
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
812
+ } else {
813
+ // Iterating backward from n_block_max - 1 to 0 might save 1 register
814
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
815
+ for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
816
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
817
+ }
818
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
819
+ }
820
+ }
821
+
822
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
823
+
824
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
825
+ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
826
+
827
+ // The block index for the batch.
828
+ const int bidb = blockIdx.y;
829
+ // The block index for the head.
830
+ const int bidh = blockIdx.z;
831
+
832
+ // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
833
+ for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
834
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
835
+ }
836
+ }
837
+
838
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
839
+ } // namespace FLASH_NAMESPACE
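Note on the file above: the elementwise update inside the main loop of compute_dq_dk_dv_1colblock (the pointwise_mult lambda) is the softmax backward identity. With the probabilities P recomputed by scale_apply_exp2 and the per-row term D that dot_do_o derives from rowsum(dO * O) and stores in dsoftmax_sum, it computes dS = P * (dP - D); dropout is handled separately through the sign-bit encoding applied earlier. The scalar sketch below illustrates that identity only, ignoring dropout and softcap, and is not code from this commit.

// Hypothetical scalar reference (not part of this commit) for the per-element
// dS update done in compute_dq_dk_dv_1colblock, without dropout or softcap:
//   p  = exp(scale * s - lse)     // recomputed attention probability
//   ds = p * (dp - dp_sum_row)    // softmax backward; dp_sum_row = rowsum(dO * O)
inline float ds_reference(float p, float dp, float dp_sum_row) {
    return p * (dp - dp_sum_row);
}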
flash_attn/src/flash_bwd_launch_template.h ADDED
@@ -0,0 +1,328 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "namespace_config.h"
8
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
9
+
10
+ #include "static_switch.h"
11
+ #include "hardware_info.h"
12
+ #include "flash.h"
13
+ #include "flash_bwd_preprocess_kernel.h"
14
+ #include "flash_bwd_kernel.h"
15
+
16
+ namespace FLASH_NAMESPACE {
17
+
18
+ // Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
19
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
20
+ #define ARCH_SUPPORTS_FLASH
21
+ #define KERNEL_PARAM_MODIFIER __grid_constant__
22
+ #else
23
+ #define KERNEL_PARAM_MODIFIER
24
+ #endif
25
+
26
+ // Define a macro for unsupported architecture handling to centralize the error message
27
+ #define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
28
+
29
+ // Use a macro to clean up kernel definitions
30
+ #define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
31
+ template<typename Kernel_traits, __VA_ARGS__> \
32
+ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
33
+
34
+ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
35
+ #if defined(ARCH_SUPPORTS_FLASH)
36
+ FLASH_NAMESPACE::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
37
+ #else
38
+ FLASH_UNSUPPORTED_ARCH
39
+ #endif
40
+ }
41
+
42
+ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
43
+ #if defined(ARCH_SUPPORTS_FLASH)
44
+ static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
45
+ FLASH_NAMESPACE::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
46
+ #else
47
+ FLASH_UNSUPPORTED_ARCH
48
+ #endif
49
+ }
50
+
51
+
52
+ template<bool Clear_dQaccum=true, typename Kernel_traits>
53
+ __global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
54
+ FLASH_NAMESPACE::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
55
+ }
56
+
57
+ template<typename Kernel_traits>
58
+ __global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
59
+ FLASH_NAMESPACE::clear_dKVaccum<Kernel_traits>(params);
60
+ }
61
+
62
+ template<typename Kernel_traits>
63
+ __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
64
+ FLASH_NAMESPACE::convert_dQ<Kernel_traits>(params, nsplits);
65
+ }
66
+
67
+ template<typename Kernel_traits>
68
+ __global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
69
+ FLASH_NAMESPACE::convert_dKV<Kernel_traits>(params);
70
+ }
71
+
72
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
73
+ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
74
+ const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
75
+ dim3 grid_m(num_m_block, params.b, params.h);
76
+ const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
77
+ int gridDimx = num_n_block;
78
+ if (params.deterministic) {
79
+ int num_sm = get_num_sm(get_current_device());
80
+ gridDimx = (num_sm + params.b * params.h - 1) / (params.b * params.h);
81
+ }
82
+ dim3 grid_n(gridDimx, params.b, params.h);
83
+
84
+ if (!params.deterministic) {
85
+ flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
86
+ } else {
87
+ flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
88
+ }
89
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
90
+
91
+ // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
92
+ // a multiple of kBlockN, we'll need to apply mask in the loop.
93
+ const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
94
+ const bool is_even_K = params.d == Kernel_traits::kHeadDim;
95
+ constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
96
+ // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
97
+ BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
98
+ EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
99
+ LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
100
+ ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
101
+ SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
102
+ // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
103
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
104
+ // If Is_local, set Is_causal to false
105
+ auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
106
+ // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
107
+ if (smem_size_dq_dk_dv >= 48 * 1024) {
108
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
109
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
110
+ }
111
+ kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
112
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
113
+ });
114
+ });
115
+ });
116
+ });
117
+ });
118
+
119
+ auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
120
+ if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
121
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
122
+ kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
123
+ }
124
+ kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
125
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
126
+ }
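The nested BOOL_SWITCH / EVENK_SWITCH / LOCAL_SWITCH / ALIBI_SWITCH / SOFTCAP_SWITCH calls above convert runtime flags into compile-time template parameters, so each combination gets its own fully specialized kernel. A minimal sketch of the generic pattern behind such a switch macro (a generic illustration, not a copy of static_switch.h):

    // Runs the lambda with a constexpr bool, so the flag can be forwarded as a template argument.
    #define MY_BOOL_SWITCH(COND, CONST_NAME, ...)          \
        [&] {                                              \
            if (COND) {                                    \
                constexpr static bool CONST_NAME = true;   \
                return __VA_ARGS__();                      \
            } else {                                       \
                constexpr static bool CONST_NAME = false;  \
                return __VA_ARGS__();                      \
            }                                              \
        }()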
127
+
128
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
129
+ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
130
+ #ifndef FLASHATTENTION_DISABLE_BACKWARD
131
+ run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout, Is_causal>(params, stream);
132
+ #endif
133
+ }
134
+
135
+ template<typename T, bool Is_causal>
136
+ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
137
+ constexpr static int Headdim = 32;
138
+ int device;
139
+ cudaGetDevice(&device);
140
+ int max_smem_per_block;
141
+ cudaError status_ = cudaDeviceGetAttribute(
142
+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
143
+ if (status_ != cudaSuccess) {
144
+ C10_CUDA_CHECK(status_);
145
+ }
146
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
147
+ if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
148
+ if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
149
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
150
+ } else {
151
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
152
+ }
153
+ } else { // 96 KB
154
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
155
+ }
156
+ });
157
+ }
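The 104 KB figure in the comment above follows directly from the expression in the if-condition (a quick arithmetic check, not extra source; the grouping of terms is taken verbatim from the code, which does not spell out what each factor corresponds to in the tiling):

    // (3*128 + 2*128) * 32 = 640 * 32 = 20480; 2 * 128 * 128 = 32768;
    // 2 * (20480 + 32768) = 106496 bytes = 104 KB.
    static_assert(2 * ((3 * 128 + 2 * 128) * 32 + 2 * 128 * 128) == 104 * 1024,
                  "104 KB shared-memory threshold for Headdim = 32");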
158
+
159
+ template<typename T, bool Is_causal>
160
+ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
161
+ constexpr static int Headdim = 64;
162
+ int device;
163
+ cudaGetDevice(&device);
164
+ int max_smem_per_block;
165
+ cudaError status_ = cudaDeviceGetAttribute(
166
+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
167
+ if (status_ != cudaSuccess) {
168
+ C10_CUDA_CHECK(status_);
169
+ }
170
+ // printf("max_smem_per_block = %d\n", max_smem_per_block);
171
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
172
+ // Changing AtomLayoutMdQ from 2 to 4 takes the same time
173
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
174
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
175
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
176
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
177
+ // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
178
+ if (max_smem_per_block >= 144 * 1024) {
179
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
180
+ // This has a lot of register spilling
181
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
182
+ } else {
183
+ // if (params.h == params.h_k) {
184
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
185
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
186
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
187
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
188
+ // } else {
189
+ // }
190
+ }
191
+ });
192
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
193
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
194
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
195
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
196
+ // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
197
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
198
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
199
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
200
+
201
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
202
+ }
203
+
204
+ template<typename T, bool Is_causal>
205
+ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
206
+ constexpr static int Headdim = 96;
207
+ int device;
208
+ cudaGetDevice(&device);
209
+ int max_smem_per_block;
210
+ cudaError status_ = cudaDeviceGetAttribute(
211
+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
212
+ if (status_ != cudaSuccess) {
213
+ C10_CUDA_CHECK(status_);
214
+ }
215
+ // printf("max_smem_per_block = %d\n", max_smem_per_block);
216
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
217
+ if (max_smem_per_block >= 116 * 1024) {
218
+ if constexpr(!Is_dropout) { // 92KB
219
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
220
+ } else { // 116 KB
221
+ // This is faster for dropout since we don't have many registers to spare
222
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
223
+ }
224
+ } else {
225
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
226
+ }
227
+ });
228
+ }
229
+
230
+ template<typename T, bool Is_causal>
231
+ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
232
+ constexpr static int Headdim = 128;
233
+ int device;
234
+ cudaGetDevice(&device);
235
+ int max_smem_per_block;
236
+ cudaError status_ = cudaDeviceGetAttribute(
237
+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
238
+ if (status_ != cudaSuccess) {
239
+ C10_CUDA_CHECK(status_);
240
+ }
241
+ // printf("max_smem_per_block = %d\n", max_smem_per_block);
242
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
243
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
244
+ // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
245
+ // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
246
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
247
+ if (max_smem_per_block >= 144 * 1024) {
248
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
249
+ // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
250
+ // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
251
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
252
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
253
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
254
+ } else {
255
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
256
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout, Is_causal>(params, stream);
257
+ }
258
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
259
+
260
+ // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
261
+ });
262
+ }
263
+
264
+ template<typename T, bool Is_causal>
265
+ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
266
+ constexpr static int Headdim = 160;
267
+ int device;
268
+ cudaGetDevice(&device);
269
+ int max_smem_per_block;
270
+ cudaError status_ = cudaDeviceGetAttribute(
271
+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
272
+ if (status_ != cudaSuccess) {
273
+ C10_CUDA_CHECK(status_);
274
+ }
275
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
276
+ if (max_smem_per_block >= 116 * 1024) {
277
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
278
+ } else {
279
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
280
+ }
281
+ });
282
+ }
283
+
284
+ template<typename T, bool Is_causal>
285
+ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
286
+ constexpr static int Headdim = 192;
287
+ int device;
288
+ cudaGetDevice(&device);
289
+ int max_smem_per_block;
290
+ cudaError status_ = cudaDeviceGetAttribute(
291
+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
292
+ if (status_ != cudaSuccess) {
293
+ C10_CUDA_CHECK(status_);
294
+ }
295
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
296
+ if (max_smem_per_block >= 136 * 1024) {
297
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
298
+ } else {
299
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout, Is_causal>(params, stream);
300
+ }
301
+ });
302
+ }
303
+
304
+ template<typename T, bool Is_causal>
305
+ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
306
+ constexpr static int Headdim = 256;
307
+ int device;
308
+ cudaGetDevice(&device);
309
+ int max_smem_per_block;
310
+ cudaError status_ = cudaDeviceGetAttribute(
311
+ &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
312
+ if (status_ != cudaSuccess) {
313
+ C10_CUDA_CHECK(status_);
314
+ }
315
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
316
+ if (max_smem_per_block >= 176 * 1024) { // H100
317
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
318
+ } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
319
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout, Is_causal>(params, stream);
320
+ } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
321
+ if constexpr (!Is_dropout) {
322
+ run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false, Is_causal>(params, stream);
323
+ }
324
+ }
325
+ });
326
+ }
327
+
328
+ } // namespace FLASH_NAMESPACE
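One pattern repeated in every launcher above deserves a note: any kernel that needs more than 48 KB of dynamic shared memory must opt in via cudaFuncSetAttribute before launch, otherwise the launch fails at runtime. A self-contained sketch of that pattern in isolation; my_kernel, smem_bytes and launch_my_kernel are illustrative names, not symbols from this file:

    // Opt-in pattern for large dynamic shared-memory allocations.
    __global__ void my_kernel(const float *in, float *out) {
        extern __shared__ char smem_[];   // dynamic shared memory, sized at launch time
        // ... kernel body elided ...
    }

    void launch_my_kernel(const float *in, float *out, cudaStream_t stream) {
        constexpr int smem_bytes = 64 * 1024;   // anything above 48 KB requires the attribute below
        if (smem_bytes >= 48 * 1024) {
            cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
        }
        my_kernel<<<dim3(1), 128, smem_bytes, stream>>>(in, out);
    }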
flash_attn/src/flash_bwd_preprocess_kernel.h ADDED
@@ -0,0 +1,379 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include "namespace_config.h"
8
+ #include <cute/tensor.hpp>
9
+
10
+ #include <cutlass/cutlass.h>
11
+ #include <cutlass/array.h>
12
+ #include <cutlass/numeric_types.h>
13
+
14
+ #include "block_info.h"
15
+ #include "kernel_traits.h"
16
+ #include "utils.h"
17
+
18
+ namespace FLASH_NAMESPACE {
19
+
20
+ using namespace cute;
21
+
22
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
23
+
24
+ template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
25
+ inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
26
+ Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
27
+ static_assert(Layout0::rank == 3, "Only support 3D Tensor");
28
+ static_assert(Layout1::rank == 1, "Only support 1D Tensor");
29
+ CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
30
+ // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
31
+ // The last coordinate is the "page".
32
+ Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
33
+ make_layout(get<0>(do_.layout()),
34
+ get<2>(do_.layout()))));
35
+ Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
36
+ Tensor do_fp32 = FLASH_NAMESPACE::convert_type<float>(do_reshaped);
37
+ Tensor o_fp32 = FLASH_NAMESPACE::convert_type<float>(o_reshaped);
38
+ #pragma unroll
39
+ for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
40
+ float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
41
+ #pragma unroll
42
+ for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
43
+ dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
44
+ }
45
+ FLASH_NAMESPACE::SumOp<float> sum_op;
46
+ dP_sum_cur = FLASH_NAMESPACE::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
47
+ if (threadIdx.x % THREADS_PER_ROW == 0) {
48
+ dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
49
+ }
50
+ }
51
+ }
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ // Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
56
+ // This is used in the case where we want to parallelize the backward across seqlen_k.
57
+ template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
58
+ inline __device__ void compute_dot_do_o(const Params &params) {
59
+ using Element = typename Kernel_traits::Element;
60
+ using ElementAccum = typename Kernel_traits::ElementAccum;
61
+ using index_t = typename Kernel_traits::index_t;
62
+
63
+ const int m_block = blockIdx.x;
64
+ // The block index for the batch.
65
+ const int bidb = blockIdx.y;
66
+ // The block index for the head.
67
+ const int bidh = blockIdx.z;
68
+ // The thread index.
69
+ const int tidx = threadIdx.x;
70
+
71
+ constexpr int kBlockM = Kernel_traits::kBlockM;
72
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
73
+
74
+ const BlockInfo binfo(params, bidb);
75
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
76
+
77
+ const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
78
+ + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
79
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
80
+ + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
81
+ const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
82
+ + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
83
+ // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
84
+ const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;
85
+
86
+ Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
87
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
88
+ make_stride(params.do_row_stride, _1{}));
89
+ Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
90
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
91
+ make_stride(params.o_row_stride, _1{}));
92
+ Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
93
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
94
+ make_stride(params.h * params.d_rounded, _1{}));
95
+ Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
96
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
97
+
98
+ typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
99
+ auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
100
+ // TODO: careful, we're zeroing out dQaccum with type float4, but when
101
+ // we do atomicAdds, we use type float. The layouts are different. Check this.
102
+ typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
103
+ auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
104
+
105
+ Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
106
+ Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
107
+ Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
108
+
109
+ Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
110
+ Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
111
+
112
+ // Allocate predicate tensors for k
113
+ Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
114
+ // Set predicates for k bounds
115
+ #pragma unroll
116
+ for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
117
+
118
+ Tensor tdOrdO = make_fragment_like(tdOgdO);
119
+ Tensor tdOrO = make_fragment_like(tdOgO);
120
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
121
+ gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
122
+ );
123
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
124
+ gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
125
+ );
126
+ // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
127
+ // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
128
+ // so that (dP - dP_sum) is on the same scale.
129
+ dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
130
+ Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
131
+ if (Clear_dQaccum) {
132
+ // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
133
+ // do atomicAdds on.
134
+ Tensor zero = make_fragment_like(tdQgdQaccum);
135
+ clear(zero);
136
+ cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
137
+ }
138
+ }
139
+
140
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
141
+
142
+ template<typename Kernel_traits, typename Params>
143
+ inline __device__ void clear_dKVaccum(const Params &params) {
144
+ using ElementAccum = typename Kernel_traits::ElementAccum;
145
+ using index_t = typename Kernel_traits::index_t;
146
+
147
+ const int n_block = blockIdx.x;
148
+ // The block index for the batch.
149
+ const int bidb = blockIdx.y;
150
+ // The block index for the head.
151
+ const int bidh = blockIdx.z;
152
+ // The thread index.
153
+ const int tidx = threadIdx.x;
154
+
155
+ constexpr int kBlockN = Kernel_traits::kBlockN;
156
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
157
+
158
+ const BlockInfo binfo(params, bidb);
159
+ if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
160
+
161
+ const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
162
+
163
+ Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
164
+ Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
165
+ Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
166
+ Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
167
+
168
+ typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
169
+ auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
170
+ Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
171
+ Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
172
+ Tensor zero = make_fragment_like(tdKgdKaccum);
173
+ clear(zero);
174
+ cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
175
+ cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
176
+ }
177
+
178
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
179
+
180
+ // Convert dQ from dQaccum (in float) to fp16/bf16.
181
+ // This is used in the case where we want to parallelize the backward across seqlen_k.
182
+ template<typename Kernel_traits, typename Params>
183
+ inline __device__ void convert_dQ(const Params &params, const int nsplits) {
184
+ using Element = typename Kernel_traits::Element;
185
+ using ElementAccum = typename Kernel_traits::ElementAccum;
186
+ using index_t = typename Kernel_traits::index_t;
187
+
188
+ // Shared memory.
189
+ extern __shared__ char smem_[];
190
+
191
+ const int m_block = blockIdx.x;
192
+ // The block index for the batch.
193
+ const int bidb = blockIdx.y;
194
+ // The block index for the head.
195
+ const int bidh = blockIdx.z;
196
+ // The thread index.
197
+ const int tidx = threadIdx.x;
198
+
199
+ constexpr int kBlockM = Kernel_traits::kBlockM;
200
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
201
+
202
+ const BlockInfo binfo(params, bidb);
203
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
204
+
205
+ const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
206
+ + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
207
+ const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
208
+ + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128ll * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
209
+
210
+ Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
211
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
212
+ make_stride(params.dq_row_stride, _1{}));
213
+ Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
214
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
215
+ make_stride(params.h * params.d_rounded, _1{}));
216
+
217
+ Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
218
+ typename Kernel_traits::SmemLayoutdQ{});
219
+
220
+ typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
221
+ auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
222
+ typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
223
+ auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
224
+
225
+ typename Kernel_traits::TiledMmadQ tiled_mma_dq;
226
+ auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
227
+ auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
228
+ Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
229
+
230
+ Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
231
+ Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
232
+ Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
233
+
234
+ Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
235
+ CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
236
+
237
+ Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
238
+ clear(acc_dq);
239
+ for (int s = 0; s < nsplits; ++s) {
240
+ cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
241
+ #pragma unroll
242
+ for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
243
+ tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
244
+ }
245
+ #pragma unroll
246
+ for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
247
+ // Convert acc_dq from fp32 to fp16
248
+ Tensor rdQ = FLASH_NAMESPACE::convert_type<Element>(acc_dq);
249
+ Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
250
+ cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
251
+ __syncthreads();
252
+ Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
253
+ cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
254
+
255
+ Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
256
+ Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
257
+ Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
258
+ #pragma unroll
259
+ for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
260
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
261
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
262
+ gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
263
+ );
264
+ }
265
+
266
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
267
+
268
+ // Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
269
+ // This is used in the case where we want to parallelize the backward across seqlen_q.
270
+ template<typename Kernel_traits, typename Params>
271
+ inline __device__ void convert_dKV(const Params &params) {
272
+ using Element = typename Kernel_traits::Element;
273
+ using ElementAccum = typename Kernel_traits::ElementAccum;
274
+ using index_t = typename Kernel_traits::index_t;
275
+
276
+ // Shared memory.
277
+ extern __shared__ char smem_[];
278
+
279
+ const int n_block = blockIdx.x;
280
+ // The block index for the batch.
281
+ const int bidb = blockIdx.y;
282
+ // The block index for the head.
283
+ const int bidh = blockIdx.z;
284
+ // The thread index.
285
+ const int tidx = threadIdx.x;
286
+
287
+ constexpr int kBlockN = Kernel_traits::kBlockN;
288
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
289
+
290
+ const BlockInfo binfo(params, bidb);
291
+ if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
292
+
293
+ const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
294
+ + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
295
+ const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
296
+ + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
297
+ const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
298
+ + n_block * kBlockN) * params.d_rounded;
299
+
300
+ Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
301
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
302
+ make_stride(params.dk_row_stride, _1{}));
303
+ Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
304
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
305
+ make_stride(params.dv_row_stride, _1{}));
306
+ Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
307
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
308
+ Stride<Int<kHeadDim>, _1>{});
309
+ Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
310
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
311
+ Stride<Int<kHeadDim>, _1>{});
312
+
313
+ Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
314
+ typename Kernel_traits::SmemLayoutdKV{});
315
+ Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
316
+
317
+ typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
318
+ auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
319
+ typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
320
+ auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
321
+
322
+ typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
323
+ auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
324
+ auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
325
+ Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
326
+ Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
327
+
328
+ Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
329
+ Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
330
+ Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
331
+ Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
332
+ Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
333
+ Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
334
+
335
+ Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
336
+ Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
337
+ CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
338
+ CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
339
+
340
+ Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
341
+ Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
342
+ cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
343
+ cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
344
+ #pragma unroll
345
+ for (int i = 0; i < size(acc_dk); ++i) {
346
+ acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
347
+ }
348
+ #pragma unroll
349
+ for (int i = 0; i < size(acc_dv); ++i) {
350
+ acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
351
+ }
352
+ // Convert acc_dk from fp32 to fp16
353
+ Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
354
+ Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);
355
+ Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
356
+ Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
357
+ cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
358
+ cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
359
+ __syncthreads();
360
+ Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
361
+ Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
362
+ cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
363
+ cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
364
+
365
+ Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
366
+ Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
367
+ Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
368
+ #pragma unroll
369
+ for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
370
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
371
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
372
+ gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
373
+ );
374
+ FLASH_NAMESPACE::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
375
+ gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
376
+ );
377
+ }
378
+
379
+ } // namespace flash
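Stripped of tiling, copy atoms and layouts, the preprocess step computes, for every query row, the row-wise dot product of dO and O (softmax_d), pre-scaled by p_dropout so that dP - dP_sum stays on one scale; the 1/p_dropout correction is applied once to the final dQ/dK/dV. A plain CPU reference of that semantics, with illustrative names (rows, head_dim, p_dropout as free parameters), intended only as a mental model of what compute_dot_do_o writes out:

    #include <vector>

    std::vector<float> dot_do_o_reference(const std::vector<float> &dO, const std::vector<float> &O,
                                          int rows, int head_dim, float p_dropout) {
        std::vector<float> dP_sum(rows, 0.f);
        for (int m = 0; m < rows; ++m) {
            float acc = 0.f;
            for (int k = 0; k < head_dim; ++k) {
                acc += dO[m * head_dim + k] * O[m * head_dim + k];   // dot(dO_row, O_row)
            }
            dP_sum[m] = acc * p_dropout;   // kept scaled down by p_dropout, as in the kernel comment
        }
        return dP_sum;
    }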
flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
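The remaining .cu files in this commit follow the same shape: one translation unit per (dtype, head dimension, causal) combination, each providing a single explicit specialization of run_mha_fwd_, which keeps individual compile times small and lets the build parallelize across files. A hypothetical dispatcher illustrating how such specializations are reached at runtime (the actual dispatch is not shown in this diff; the head-dim boundaries below are purely illustrative):

    template <typename T, bool Is_causal>
    void dispatch_hdim(Flash_fwd_params &params, cudaStream_t stream) {
        if (params.d <= 128)      { run_mha_fwd_<T, 128, Is_causal>(params, stream); }
        else if (params.d <= 160) { run_mha_fwd_<T, 160, Is_causal>(params, stream); }
        else                      { run_mha_fwd_<T, 192, Is_causal>(params, stream); }
        // ... remaining head dimensions elided ...
    }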
flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu ADDED
@@ -0,0 +1,14 @@
1
+ // Copyright (c) 2024, Tri Dao.
2
+ // Splitting the different head dimensions to different files to speed up compilation.
3
+ // This file is auto-generated. See "generate_kernels.py"
4
+ #include "namespace_config.h"
5
+ #include "flash_fwd_launch_template.h"
6
+
7
+ namespace FLASH_NAMESPACE {
8
+
9
+ template<>
10
+ void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
11
+ run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
12
+ }
13
+
14
+ } // namespace FLASH_NAMESPACE