File size: 5,934 Bytes
a7165c8
 
 
 
 
 
 
9002ff5
d774688
 
 
 
9002ff5
d774688
9002ff5
a7165c8
 
 
 
 
 
 
 
 
 
39b4aba
 
 
 
 
 
 
 
 
 
 
 
 
b0d3c12
a7165c8
b0d3c12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9002ff5
 
 
 
b0d3c12
 
 
 
 
 
 
 
9002ff5
 
 
a7165c8
b0d3c12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9002ff5
 
 
 
b0d3c12
 
 
 
 
 
 
 
a7165c8
 
b0d3c12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9002ff5
 
 
 
b0d3c12
 
 
 
 
 
 
 
a7165c8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
[general]
# Identifier for this build configuration.
name = 'flash_attn'

[torch]
# Files under torch-ext/ registered for the [torch] section:
# the binding implementation and its header.
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h",
]

[kernel.flash_attn]
# CUDA compute capabilities to build for. Only 8.x and 9.0 are enabled; every
# kernel translation unit listed in `src` below is an *_sm80 variant.
cuda-capabilities = [
  # Disabled: "7.0", "7.2", "7.5",
  "8.0",
  "8.6",
  "8.7",
  "8.9",
  "9.0",
]
# Every path is listed exactly once.
src = [
  "flash_attn/flash_api.cpp",

  ## shared headers / helpers (alphabetical)
  "flash_attn/src/alibi.h",
  "flash_attn/src/block_info.h",
  "flash_attn/src/dropout.h",
  "flash_attn/src/flash.h",
  # NOTE(review): generate_kernels.py is a Python script, not a compiled
  # source (presumably the generator for the *_sm80.cu files) — confirm it
  # needs to be listed here.
  "flash_attn/src/generate_kernels.py",
  "flash_attn/src/hardware_info.h",
  "flash_attn/src/kernel_traits.h",
  "flash_attn/src/mask.h",
  "flash_attn/src/namespace_config.h",
  "flash_attn/src/philox.cuh",
  "flash_attn/src/philox_unpack.cuh",
  "flash_attn/src/rotary.h",
  "flash_attn/src/softmax.h",
  "flash_attn/src/static_switch.h",
  "flash_attn/src/utils.h",

  ## bwd kernels

  "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_kernel.h",
  "flash_attn/src/flash_bwd_launch_template.h",
  "flash_attn/src/flash_bwd_preprocess_kernel.h",

  ## fwd kernels
  "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_kernel.h",
  "flash_attn/src/flash_fwd_launch_template.h",
  "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
]
depends = ["torch", "cutlass_3_6"]