drbh commited on
Commit
9002ff5
·
1 Parent(s): 39b4aba

fix: limit build head size and cuda caps for dev

Browse files
Files changed (4) hide show
  1. build.toml +21 -17
  2. flake.lock +4 -4
  3. flake.nix +1 -1
  4. flash_attn/src/static_switch.h +28 -23
build.toml CHANGED
@@ -5,7 +5,11 @@ name = "flash_attn"
5
  src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
6
 
7
  [kernel.flash_attn]
8
- cuda-capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
 
 
 
 
9
  src = [
10
  "flash_attn/flash_api.cpp",
11
  "flash_attn/src/philox_unpack.cuh",
@@ -47,10 +51,10 @@ src = [
47
  # "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
48
  # "flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
49
  # "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
50
- # "flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
51
- # "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
52
- # "flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
53
- # "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
54
  # "flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
55
  # "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
56
  # "flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
@@ -59,9 +63,9 @@ src = [
59
  # "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
60
  # "flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
61
  # "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
62
- # "flash_attn/src/flash_bwd_kernel.h",
63
- # "flash_attn/src/flash_bwd_launch_template.h",
64
- # "flash_attn/src/flash_bwd_preprocess_kernel.h",
65
 
66
  ## TODO: include fwd kernels
67
 
@@ -81,10 +85,10 @@ src = [
81
  # "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
82
  # "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
83
  # "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
84
- # "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
85
- # "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
86
- # "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
87
- # "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
88
  # "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
89
  # "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
90
  # "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
@@ -92,7 +96,7 @@ src = [
92
  # "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
93
  # "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
94
  # "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
95
- "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
96
  "flash_attn/src/flash_fwd_kernel.h",
97
  "flash_attn/src/flash_fwd_launch_template.h",
98
  # "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
@@ -111,10 +115,10 @@ src = [
111
  # "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
112
  # "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
113
  # "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
114
- # "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
115
- # "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
116
- # "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
117
- # "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
118
  # "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
119
  # "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
120
  # "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
 
5
  src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
6
 
7
  [kernel.flash_attn]
8
+ cuda-capabilities = [
9
+ # "7.0", "7.2", "7.5", "8.0", "8.6", "8.7",
10
+ "8.9",
11
+ # "9.0",
12
+ ]
13
  src = [
14
  "flash_attn/flash_api.cpp",
15
  "flash_attn/src/philox_unpack.cuh",
 
51
  # "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
52
  # "flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
53
  # "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
54
+ "flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
55
+ "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
56
+ "flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
57
+ "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
58
  # "flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
59
  # "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
60
  # "flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
 
63
  # "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
64
  # "flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
65
  # "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
66
+ "flash_attn/src/flash_bwd_kernel.h",
67
+ "flash_attn/src/flash_bwd_launch_template.h",
68
+ "flash_attn/src/flash_bwd_preprocess_kernel.h",
69
 
70
  ## TODO: include fwd kernels
71
 
 
85
  # "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
86
  # "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
87
  # "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
88
+ "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
89
+ "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
90
+ "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
91
+ "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
92
  # "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
93
  # "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
94
  # "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
 
96
  # "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
97
  # "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
98
  # "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
99
+ # "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
100
  "flash_attn/src/flash_fwd_kernel.h",
101
  "flash_attn/src/flash_fwd_launch_template.h",
102
  # "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
 
115
  # "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
116
  # "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
117
  # "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
118
+ "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
119
+ "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
120
+ "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
121
+ "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
122
  # "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
123
  # "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
124
  # "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
flake.lock CHANGED
@@ -41,11 +41,11 @@
41
  "rocm-nix": "rocm-nix"
42
  },
43
  "locked": {
44
- "lastModified": 1742905006,
45
- "narHash": "sha256-SCi1f5Lti4AM0kNPlAidcgN/5YM4HgJP4KwCsMrB0IE=",
46
  "ref": "refs/heads/main",
47
- "rev": "517a2bf2d0a3f1faf058ab995b6ca280b0999e7c",
48
- "revCount": 105,
49
  "type": "git",
50
  "url": "ssh://[email protected]/huggingface/kernel-builder"
51
  },
 
41
  "rocm-nix": "rocm-nix"
42
  },
43
  "locked": {
44
+ "lastModified": 1742916494,
45
+ "narHash": "sha256-crH7vQjPJZ1yrS0GA/waYvsRLLUN4PIj85L+Rpy0Q+U=",
46
  "ref": "refs/heads/main",
47
+ "rev": "faf433757fefae660ada7d003e394b1939989a5c",
48
+ "revCount": 106,
49
  "type": "git",
50
  "url": "ssh://[email protected]/huggingface/kernel-builder"
51
  },
flake.nix CHANGED
@@ -2,7 +2,7 @@
2
  description = "Flake for ReLU kernel";
3
 
4
  inputs = {
5
- kernel-builder.url = "git+ssh://git@github.com/huggingface/kernel-builder";
6
  };
7
 
8
  outputs =
 
2
  description = "Flake for ReLU kernel";
3
 
4
  inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
  };
7
 
8
  outputs =
flash_attn/src/static_switch.h CHANGED
@@ -87,28 +87,33 @@
87
  } \
88
  }()
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  #define HEADDIM_SWITCH(HEADDIM, ...) \
91
- [&] { \
92
- if (HEADDIM <= 32) { \
93
- constexpr static int kHeadDim = 32; \
94
- return __VA_ARGS__(); \
95
- } else if (HEADDIM <= 64) { \
96
- constexpr static int kHeadDim = 64; \
97
- return __VA_ARGS__(); \
98
- } else if (HEADDIM <= 96) { \
99
- constexpr static int kHeadDim = 96; \
100
- return __VA_ARGS__(); \
101
- } else if (HEADDIM <= 128) { \
102
- constexpr static int kHeadDim = 128; \
103
- return __VA_ARGS__(); \
104
- } else if (HEADDIM <= 160) { \
105
- constexpr static int kHeadDim = 160; \
106
- return __VA_ARGS__(); \
107
- } else if (HEADDIM <= 192) { \
108
- constexpr static int kHeadDim = 192; \
109
- return __VA_ARGS__(); \
110
- } else if (HEADDIM <= 256) { \
111
- constexpr static int kHeadDim = 256; \
112
- return __VA_ARGS__(); \
113
- } \
114
  }()
 
87
  } \
88
  }()
89
 
90
+ // #define HEADDIM_SWITCH(HEADDIM, ...) \
91
+ // [&] { \
92
+ // if (HEADDIM <= 32) { \
93
+ // constexpr static int kHeadDim = 32; \
94
+ // return __VA_ARGS__(); \
95
+ // } else if (HEADDIM <= 64) { \
96
+ // constexpr static int kHeadDim = 64; \
97
+ // return __VA_ARGS__(); \
98
+ // } else if (HEADDIM <= 96) { \
99
+ // constexpr static int kHeadDim = 96; \
100
+ // return __VA_ARGS__(); \
101
+ // } else if (HEADDIM <= 128) { \
102
+ // constexpr static int kHeadDim = 128; \
103
+ // return __VA_ARGS__(); \
104
+ // } else if (HEADDIM <= 160) { \
105
+ // constexpr static int kHeadDim = 160; \
106
+ // return __VA_ARGS__(); \
107
+ // } else if (HEADDIM <= 192) { \
108
+ // constexpr static int kHeadDim = 192; \
109
+ // return __VA_ARGS__(); \
110
+ // } else if (HEADDIM <= 256) { \
111
+ // constexpr static int kHeadDim = 256; \
112
+ // return __VA_ARGS__(); \
113
+ // } \
114
+ // }()
115
  #define HEADDIM_SWITCH(HEADDIM, ...) \
116
+ [&] { \
117
+ constexpr static int kHeadDim = 32; \
118
+ return __VA_ARGS__(); \
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  }()