# flash-attn / build.toml
# Last change: 39b4aba by drbh, "feat: pass vars into fwd and include build"
# General build settings for this kernel package.
[general]
name = "flash_attn"

# Sources for the Torch binding layer under torch-ext/.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]

[kernel.flash_attn]
# Per the upstream FlashAttention README, FlashAttention-2 (CUDA) supports
# only Ampere, Ada and Hopper GPUs. Every kernel translation unit listed
# below is also an `_sm80` variant. Volta/Turing targets (7.x) therefore
# only cost build time and binary size and are not built.
cuda-capabilities = ["8.0", "8.6", "8.7", "8.9", "9.0"]
src = [
    "flash_attn/flash_api.cpp",
    # Shared headers, sorted alphabetically; keep each path listed only once.
    "flash_attn/src/alibi.h",
    "flash_attn/src/block_info.h",
    "flash_attn/src/dropout.h",
    "flash_attn/src/flash.h",
    # NOTE(review): this is a Python generator script, not a C++/CUDA source.
    # Presumably listed so it ships with the sources; confirm it belongs here.
    "flash_attn/src/generate_kernels.py",
    "flash_attn/src/hardware_info.h",
    "flash_attn/src/kernel_traits.h",
    "flash_attn/src/mask.h",
    "flash_attn/src/namespace_config.h",
    "flash_attn/src/philox.cuh",
    "flash_attn/src/philox_unpack.cuh",
    "flash_attn/src/rotary.h",
    "flash_attn/src/softmax.h",
    "flash_attn/src/static_switch.h",
    "flash_attn/src/utils.h",
    # TODO: include bwd kernels.
    # "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
    # "flash_attn/src/flash_bwd_kernel.h",
    # "flash_attn/src/flash_bwd_launch_template.h",
    # "flash_attn/src/flash_bwd_preprocess_kernel.h",
    # TODO: include the remaining fwd kernels. Only the hdim96 / fp16 /
    # non-causal kernel (plus its kernel and launch-template headers) is
    # compiled for now.
    # NOTE(review): confirm flash_api.cpp does not reference fwd, split or bwd
    # instantiations that are not compiled here; otherwise the extension may
    # fail to link or load.
    # "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
    "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
    "flash_attn/src/flash_fwd_kernel.h",
    "flash_attn/src/flash_fwd_launch_template.h",
    # TODO: include fwd split kernels.
    # "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
    # "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
]
depends = ["torch", "cutlass_3_6"]