drbh committed
Commit 39b4aba · Parent: a7165c8

feat: pass vars into fwd and include build
Files changed (41)
  1. .gitattributes +1 -0
  2. build.toml +71 -70
  3. build/torch25-cxx11-cu118-x86_64-linux/flash_attn/__init__.py +37 -0
  4. build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  5. build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
  6. build/torch25-cxx11-cu121-x86_64-linux/flash_attn/__init__.py +37 -0
  7. build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  8. build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_ops.py +9 -0
  9. build/torch25-cxx11-cu124-x86_64-linux/flash_attn/__init__.py +37 -0
  10. build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  11. build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_ops.py +9 -0
  12. build/torch25-cxx98-cu118-x86_64-linux/flash_attn/__init__.py +37 -0
  13. build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  14. build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
  15. build/torch25-cxx98-cu121-x86_64-linux/flash_attn/__init__.py +37 -0
  16. build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  17. build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_ops.py +9 -0
  18. build/torch25-cxx98-cu124-x86_64-linux/flash_attn/__init__.py +37 -0
  19. build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  20. build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_ops.py +9 -0
  21. build/torch26-cxx11-cu118-x86_64-linux/flash_attn/__init__.py +37 -0
  22. build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  23. build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
  24. build/torch26-cxx11-cu124-x86_64-linux/flash_attn/__init__.py +37 -0
  25. build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  26. build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py +9 -0
  27. build/torch26-cxx11-cu126-x86_64-linux/flash_attn/__init__.py +37 -0
  28. build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  29. build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py +9 -0
  30. build/torch26-cxx98-cu118-x86_64-linux/flash_attn/__init__.py +37 -0
  31. build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  32. build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
  33. build/torch26-cxx98-cu124-x86_64-linux/flash_attn/__init__.py +37 -0
  34. build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  35. build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py +9 -0
  36. build/torch26-cxx98-cu126-x86_64-linux/flash_attn/__init__.py +37 -0
  37. build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
  38. build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py +9 -0
  39. flake.lock +4 -4
  40. flash_attn/flash_api.cpp +20 -4
  41. torch-ext/flash_attn/__init__.py +2 -2
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.so filter=lfs diff=lfs merge=lfs -text
build.toml CHANGED
@@ -13,13 +13,24 @@ src = [
   "flash_attn/src/hardware_info.h",
   "flash_attn/src/flash.h",
   "flash_attn/src/static_switch.h",
- #
   "flash_attn/src/alibi.h",
   "flash_attn/src/block_info.h",
   "flash_attn/src/dropout.h",
+ "flash_attn/src/flash.h",
+ "flash_attn/src/generate_kernels.py",
+ "flash_attn/src/hardware_info.h",
+ "flash_attn/src/kernel_traits.h",
+ "flash_attn/src/mask.h",
+ "flash_attn/src/namespace_config.h",
+ "flash_attn/src/philox.cuh",
+ "flash_attn/src/philox_unpack.cuh",
+ "flash_attn/src/rotary.h",
+ "flash_attn/src/softmax.h",
+ "flash_attn/src/static_switch.h",
+ "flash_attn/src/utils.h",
+
+ ## TODO: include bwd kernels

- # TODO: dont skip bwd kernels
-
   # "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
   # "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
   # "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
@@ -52,75 +63,65 @@ src = [
   # "flash_attn/src/flash_bwd_launch_template.h",
   # "flash_attn/src/flash_bwd_preprocess_kernel.h",

- "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
+ ## TODO: include fwd kernels
+
+ # "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_kernel.h",
   "flash_attn/src/flash_fwd_launch_template.h",
- "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
- "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
- "flash_attn/src/flash.h",
- "flash_attn/src/generate_kernels.py",
- "flash_attn/src/hardware_info.h",
- "flash_attn/src/kernel_traits.h",
- "flash_attn/src/mask.h",
- "flash_attn/src/namespace_config.h",
- "flash_attn/src/philox.cuh",
- "flash_attn/src/philox_unpack.cuh",
- "flash_attn/src/rotary.h",
- "flash_attn/src/softmax.h",
- "flash_attn/src/static_switch.h",
- "flash_attn/src/utils.h",
+ # "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+ # "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
   ]
   depends = ["torch", "cutlass_3_6"]
build/torch25-cxx11-cu118-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
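For orientation, here is a minimal sketch of how this generated wrapper might be called once a matching build variant is on the import path. The shapes, scale, and flag values below are illustrative assumptions, not values taken from this commit, and passing None for alibi_slopes and gen assumes the underlying op treats those arguments as optional:

import torch
from flash_attn import mha_fwd

# Assumed layout: (batch, seqlen, num_heads, head_dim) = (2, 128, 4, 64).
q = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)  # preallocated output buffer handed to the kernel

mha_fwd(
    q, k, v, out,
    None,         # alibi_slopes: no ALiBi bias
    0.0,          # p_dropout: no dropout
    64 ** -0.5,   # softmax_scale: 1/sqrt(head_dim)
    True,         # is_causal
    -1, -1,       # window_size_left / window_size_right: no local windowing
    0.0,          # softcap: disabled
    False,        # return_softmax
    None,         # gen: fall back to the default CUDA generator
)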
build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2a6f11f1665f62c8f3b96cd843c806b737966575c28804c602bc68d089c1759
+size 17469320
build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
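A hedged sketch of where a helper like add_op_namespace_prefix typically comes in: qualifying the op name for torch.library registrations. The register_fake call below is an assumption about downstream usage (for example, making the op traceable by torch.compile), not something this commit adds:

import torch
from flash_attn._ops import add_op_namespace_prefix

# Hypothetical meta ("fake") implementation registered under the fully
# qualified name "_flash_attn_a7165c8_dirty::mha_fwd"; requires the
# compiled .so to be loaded so the op is already defined.
@torch.library.register_fake(add_op_namespace_prefix("mha_fwd"))
def _(q, k, v, out, alibi_slopes, p_dropout, softmax_scale, is_causal,
      window_size_left, window_size_right, softcap, return_softmax, gen):
    # Output shape mirrors q; enough for shape propagation during tracing.
    return torch.empty_like(q)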
build/torch25-cxx11-cu121-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e481af6967a53e2017631ade57897e3ef32e1a13e8badb11310df46e8748dab
+size 17561616
build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch25-cxx11-cu124-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df22c84c094e57e3e08c4adb615637c8e1a10fc914f9601a372eb1749ffcda12
+size 17820800
build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch25-cxx98-cu118-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd9920d56ee47082c06be48f07d20a869864954713bb8d05991dfcf01992cc6b
+size 17461960
build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch25-cxx98-cu121-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:806a71437827eb1724e80bbaf1cee7f1ef0242cd7c9a34b7e6ff696a8536f16a
+size 17558544
build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch25-cxx98-cu124-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:098cc28e134482be440715e9df0fe5b3e4023c1b5ca2c562da39571b630c4d73
+size 17817728
build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db8e9a06cafa5dffe988c22df459745deb3ee1b22b084e53ed6429e49867aae7
+size 17469464
build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3177cf407996b4f51ee139bfa4dcaf647fd659429cf9901ade2ac08117e20f9d
+size 17821096
build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43f21a0f290a6f42e004303c760e5aacc851ad55bd9093cea4752c0a7d6b202e
+size 17981304
build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:434702696304310402d3ce50496e7f9f113b632ebc90ef602e255562a54d480a
+size 17462256
build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d90d1be6a4a87ec538a3b009356af76b6c1a1b5b18ce1e69b0fe8b0316972090
+size 17817920
build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff763dccb46211a07fab8e63cfd96f76984cd994525d6c8ce0e274489e8099ca
+size 17978128
build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
flake.lock CHANGED
@@ -41,11 +41,11 @@
       "rocm-nix": "rocm-nix"
     },
     "locked": {
-      "lastModified": 1742582705,
-      "narHash": "sha256-1Vq5IauC/8fjBqcnMbDzckLN/XLIGwWr3/c2Wt3I2vs=",
+      "lastModified": 1742905006,
+      "narHash": "sha256-SCi1f5Lti4AM0kNPlAidcgN/5YM4HgJP4KwCsMrB0IE=",
       "ref": "refs/heads/main",
-      "rev": "e06e3e72947fad8bfd2c1eb5d8e7f5ec01d359d6",
-      "revCount": 103,
+      "rev": "517a2bf2d0a3f1faf058ab995b6ca280b0999e7c",
+      "revCount": 105,
       "type": "git",
       "url": "ssh://git@github.com/huggingface/kernel-builder"
     },
flash_attn/flash_api.cpp CHANGED
@@ -1490,7 +1490,23 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x roun
      const double softcap,
      const bool return_softmax,
      const c10::optional<at::Generator> gen_) {
-    // return FLASH_NAMESPACE::mha_fwd(q, k, v, out_, alibi_slopes_, p_dropout, softmax_scale, is_causal, window_size_left, window_size_right, softcap, return_softmax, gen_);
-    // return dummy value for now
-    return {};
-};
+
+    auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
+
+    // Prepare the optional arguments as non-const references.
+    std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
+    std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
+
+    if (!out.has_value()){
+        out = torch::empty_like(q);
+    }
+
+    // Convert double to float and int64_t to int.
+    float p_dropout_float = static_cast<float>(p_dropout);
+    float softmax_scale_float = static_cast<float>(softmax_scale);
+    float softcap_float = static_cast<float>(softcap);
+    int window_size_left_int = static_cast<int>(window_size_left);
+    int window_size_right_int = static_cast<int>(window_size_right);
+
+    return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
+}
torch-ext/flash_attn/__init__.py CHANGED
@@ -19,7 +19,7 @@ def mha_fwd(
     return_softmax: bool,
     gen: Optional[torch.Generator],
 ) -> torch.Tensor:
-    return ops.mha_fwd(
+    ops.mha_fwd(
         q,
         k,
         v,
@@ -34,4 +34,4 @@ def mha_fwd(
         return_softmax,
         gen,
     )
-    return out
+    return out
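With this change the wrapper calls ops.mha_fwd for its effect on the preallocated out buffer and then returns that buffer explicitly, instead of returning whatever the op itself produced. A hedged sketch of what that means for a caller (shapes and flag values are illustrative assumptions, not taken from this commit):

import torch
from flash_attn import mha_fwd

q = torch.randn(2, 128, 4, 64, device="cuda", dtype=torch.float16)
out = torch.empty_like(q)  # reusable across calls; no per-step allocation

result = mha_fwd(q, q, q, out, None, 0.0, 64 ** -0.5, True, -1, -1, 0.0, False, None)
assert result is out  # the tensor handed in is the tensor handed back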