Spaces: Running on Zero
bf16 fa3
Browse files
fa3.py
CHANGED
@@ -10,8 +10,7 @@ _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_f
|
|
10 |
|
11 |
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
|
12 |
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
13 |
-
|
14 |
-
outputs, lse = _flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype))
|
15 |
return outputs
|
16 |
|
17 |
@flash_attn_func.register_fake
|
|
|
10 |
|
11 |
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Expose the flash-attn3 kernel as the ``flash::flash_attn_func`` custom op.

    ``q``, ``k`` and ``v`` are handed to the kernel exactly as received; this
    wrapper applies no dtype conversion. The op is registered with
    ``mutates_args=()``, i.e. it is declared not to modify its inputs.

    Returns:
        The first element of the pair produced by ``_flash_attn_func``.
    """
    # The kernel yields a 2-tuple; only the first element is handed back.
    # NOTE(review): the second element is presumably the softmax
    # log-sum-exp (the original local was named ``lse``) -- confirm
    # against the kernel's documentation.
    attn_out, _lse = _flash_attn_func(q, k, v)
    return attn_out
|
15 |
|
16 |
@flash_attn_func.register_fake
|