drbh committed
Commit · 39b4aba
1 Parent(s): a7165c8

feat: pass vars into fwd and include build
Files changed:

- .gitattributes +1 -0
- build.toml +71 -70
- build/torch25-cxx11-cu118-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch25-cxx11-cu121-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch25-cxx11-cu124-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch25-cxx98-cu118-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch25-cxx98-cu121-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch25-cxx98-cu124-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch26-cxx11-cu118-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch26-cxx11-cu124-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch26-cxx11-cu126-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch26-cxx98-cu118-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch26-cxx98-cu124-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py +9 -0
- build/torch26-cxx98-cu126-x86_64-linux/flash_attn/__init__.py +37 -0
- build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so +3 -0
- build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py +9 -0
- flake.lock +4 -4
- flash_attn/flash_api.cpp +20 -4
- torch-ext/flash_attn/__init__.py +2 -2

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text

build.toml CHANGED
@@ -13,13 +13,24 @@ src = [
     "flash_attn/src/hardware_info.h",
     "flash_attn/src/flash.h",
     "flash_attn/src/static_switch.h",
-    #
     "flash_attn/src/alibi.h",
     "flash_attn/src/block_info.h",
     "flash_attn/src/dropout.h",
+    "flash_attn/src/flash.h",
+    "flash_attn/src/generate_kernels.py",
+    "flash_attn/src/hardware_info.h",
+    "flash_attn/src/kernel_traits.h",
+    "flash_attn/src/mask.h",
+    "flash_attn/src/namespace_config.h",
+    "flash_attn/src/philox.cuh",
+    "flash_attn/src/philox_unpack.cuh",
+    "flash_attn/src/rotary.h",
+    "flash_attn/src/softmax.h",
+    "flash_attn/src/static_switch.h",
+    "flash_attn/src/utils.h",
+
+    ## TODO: include bwd kernels
 
-    # TODO: dont skip bwd kernels
-
     # "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
     # "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
     # "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
@@ -52,75 +63,65 @@ src = [
     # "flash_attn/src/flash_bwd_launch_template.h",
     # "flash_attn/src/flash_bwd_preprocess_kernel.h",
 
-
-
-    "flash_attn/src/…  [25 flash_fwd_*.cu entries removed; file names truncated in the page rendering]
+    ## TODO: include fwd kernels
+
+    # "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
     "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
     "flash_attn/src/flash_fwd_kernel.h",
     "flash_attn/src/flash_fwd_launch_template.h",
-    "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
-    "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
-    "flash_attn/src/flash.h",
-    "flash_attn/src/generate_kernels.py",
-    "flash_attn/src/hardware_info.h",
-    "flash_attn/src/kernel_traits.h",
-    "flash_attn/src/mask.h",
-    "flash_attn/src/namespace_config.h",
-    "flash_attn/src/philox.cuh",
-    "flash_attn/src/philox_unpack.cuh",
-    "flash_attn/src/rotary.h",
-    "flash_attn/src/softmax.h",
-    "flash_attn/src/static_switch.h",
-    "flash_attn/src/utils.h",
+    # "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+    # "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
 ]
 depends = ["torch", "cutlass_3_6"]
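
The kernel source lists follow a regular naming scheme: head dimension (in lexicographic order: 128, 160, 192, 256, 32, 64, 96), dtype (bf16 or fp16), and an optional _causal suffix. Either commented-out block can therefore be regenerated mechanically when the kernels are re-enabled; a throwaway Python sketch, assuming the file set really is this full cross product:

    # Regenerate the commented-out kernel entries for build.toml.
    def kernel_entries(split: bool = False) -> list[str]:
        stem = "flash_fwd_split" if split else "flash_fwd"
        return [
            f'    # "flash_attn/src/{stem}_hdim{hdim}_{dtype}{causal}_sm80.cu",'
            for hdim in (128, 160, 192, 256, 32, 64, 96)  # lexicographic order
            for dtype in ("bf16", "fp16")
            for causal in ("_causal", "")
        ]

    print("\n".join(kernel_entries()))            # forward kernels
    print("\n".join(kernel_entries(split=True)))  # split-KV forward kernels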

build/torch25-cxx11-cu118-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch
+
+from ._ops import ops
+
+def mha_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    alibi_slopes: torch.Tensor,
+    p_dropout: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    return_softmax: bool,
+    gen: Optional[torch.Generator],
+) -> torch.Tensor:
+    return ops.mha_fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+    return out
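
For orientation, a minimal call of the wrapper above might look like the sketch below. The shapes, dtypes, and the use of None for alibi_slopes and gen are illustrative assumptions (the annotations say torch.Tensor, but the C++ side in flash_api.cpp treats both as optionals), not something this commit itself demonstrates:

    import torch
    import flash_attn  # the built package above, assumed installed

    head_size = 64
    # Assumed layout: batch x seqlen x num_heads x head_size, fp16 on GPU.
    q = torch.randn(2, 128, 8, head_size, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = torch.empty_like(q)

    flash_attn.mha_fwd(
        q, k, v, out,
        None,               # alibi_slopes: no ALiBi bias
        0.0,                # p_dropout
        head_size ** -0.5,  # softmax_scale
        True,               # is_causal
        -1, -1,             # window_size_left / window_size_right: disabled
        0.0,                # softcap: disabled
        False,              # return_softmax
        None,               # gen: fall back to the default CUDA generator
    )
    # The attention result lands in `out`.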

build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2a6f11f1665f62c8f3b96cd843c806b737966575c28804c602bc68d089c1759
+size 17469320

build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+import torch
+from . import _flash_attn_a7165c8_dirty
+ops = torch.ops._flash_attn_a7165c8_dirty
+
+def add_op_namespace_prefix(op_name: str):
+    """
+    Prefix op by namespace.
+    """
+    return f"_flash_attn_a7165c8_dirty::{op_name}"
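
add_op_namespace_prefix maps a bare op name onto this build's versioned namespace, which is useful wherever a fully qualified op name is required. A hypothetical use, e.g. registering a shape-only fake kernel so the op can be traced (nothing in this commit does this, and it assumes a PyTorch recent enough to provide torch.library.register_fake):

    import torch
    from flash_attn._ops import add_op_namespace_prefix

    qualified = add_op_namespace_prefix("mha_fwd")
    # -> "_flash_attn_a7165c8_dirty::mha_fwd"

    @torch.library.register_fake(qualified)
    def _mha_fwd_fake(q, k, v, out, alibi_slopes, p_dropout, softmax_scale,
                      is_causal, window_size_left, window_size_right,
                      softcap, return_softmax, gen):
        # Shape/dtype propagation only; no kernel is launched.
        return torch.empty_like(q)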

build/torch25-cxx11-cu121-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e481af6967a53e2017631ade57897e3ef32e1a13e8badb11310df46e8748dab
+size 17561616

build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx11-cu124-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df22c84c094e57e3e08c4adb615637c8e1a10fc914f9601a372eb1749ffcda12
+size 17820800

build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx98-cu118-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd9920d56ee47082c06be48f07d20a869864954713bb8d05991dfcf01992cc6b
+size 17461960

build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx98-cu121-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:806a71437827eb1724e80bbaf1cee7f1ef0242cd7c9a34b7e6ff696a8536f16a
+size 17558544

build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx98-cu124-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:098cc28e134482be440715e9df0fe5b3e4023c1b5ca2c562da39571b630c4d73
+size 17817728

build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx11-cu118-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db8e9a06cafa5dffe988c22df459745deb3ee1b22b084e53ed6429e49867aae7
+size 17469464

build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx11-cu124-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3177cf407996b4f51ee139bfa4dcaf647fd659429cf9901ade2ac08117e20f9d
+size 17821096

build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx11-cu126-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43f21a0f290a6f42e004303c760e5aacc851ad55bd9093cea4752c0a7d6b202e
+size 17981304

build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx98-cu118-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:434702696304310402d3ce50496e7f9f113b632ebc90ef602e255562a54d480a
+size 17462256

build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx98-cu124-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d90d1be6a4a87ec538a3b009356af76b6c1a1b5b18ce1e69b0fe8b0316972090
+size 17817920

build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx98-cu126-x86_64-linux/flash_attn/__init__.py ADDED (+37 lines; identical to the torch25-cxx11-cu118 copy above)

build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_a7165c8_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff763dccb46211a07fab8e63cfd96f76984cd994525d6c8ce0e274489e8099ca
+size 17978128

build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py ADDED (+9 lines; identical to the torch25-cxx11-cu118 copy above)

flake.lock CHANGED
@@ -41,11 +41,11 @@
       "rocm-nix": "rocm-nix"
     },
     "locked": {
-      "lastModified": …,
-      "narHash": "sha256-…",
+      "lastModified": 1742905006,
+      "narHash": "sha256-SCi1f5Lti4AM0kNPlAidcgN/5YM4HgJP4KwCsMrB0IE=",
       "ref": "refs/heads/main",
-      "rev": "…",
-      "revCount": …,
+      "rev": "517a2bf2d0a3f1faf058ab995b6ca280b0999e7c",
+      "revCount": 105,
       "type": "git",
       "url": "ssh://[email protected]/huggingface/kernel-builder"
     },

(The removed values are truncated in the page rendering.)

flash_attn/flash_api.cpp CHANGED
@@ -1490,7 +1490,23 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x roun
     const double softcap,
     const bool return_softmax,
     const c10::optional<at::Generator> gen_) {
-    …  [4 removed lines; content truncated in the page rendering]
+
+    auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
+
+    // Prepare the optional arguments as non-const references.
+    std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
+    std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
+
+    if (!out.has_value()){
+        out = torch::empty_like(q);
+    }
+
+    // Convert double to float and int64_t to int.
+    float p_dropout_float = static_cast<float>(p_dropout);
+    float softmax_scale_float = static_cast<float>(softmax_scale);
+    float softcap_float = static_cast<float>(softcap);
+    int window_size_left_int = static_cast<int>(window_size_left);
+    int window_size_right_int = static_cast<int>(window_size_right);
+
+    return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
+}
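
This shim exists because the Torch op schema surfaces Python floats and ints as double and int64_t, and the optional tensors arrive as const references; the new body narrows them to the float/int/non-const forms the inner FLASH_NAMESPACE::mha_fwd expects, allocates out when the caller passed none, and substitutes the default CUDA generator for a missing gen. Calling the raw op from Python would then look roughly like this sketch (the exact return value depends on the inner mha_fwd):

    import torch
    from flash_attn._ops import ops  # the versioned namespace from _ops.py

    q = torch.randn(1, 16, 4, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Python floats arrive in C++ as double and are cast down to float;
    # None for out / alibi_slopes / gen exercises the optional handling.
    result = ops.mha_fwd(q, k, v, None, None, 0.0, 64 ** -0.5,
                         False, -1, -1, 0.0, False, None)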

torch-ext/flash_attn/__init__.py CHANGED
@@ -19,7 +19,7 @@ def mha_fwd(
     return_softmax: bool,
     gen: Optional[torch.Generator],
 ) -> torch.Tensor:
-    …  [removed line truncated in the page rendering]
+    ops.mha_fwd(
         q,
         k,
         v,
@@ -34,4 +34,4 @@ def mha_fwd(
         return_softmax,
         gen,
     )
-    …  [removed line truncated in the page rendering]
+    return out