cbensimon HF Staff committed
Commit c60d44a · 1 Parent(s): 7301ed0
Files changed (1)
  1. fa3.py +1 -2
fa3.py CHANGED
@@ -10,8 +10,7 @@ _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
 
 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
 def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-    dtype = torch.float8_e4m3fn
-    outputs, lse = _flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype))
+    outputs, lse = _flash_attn_func(q, k, v)
     return outputs
 
 @flash_attn_func.register_fake
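
The commit drops the unconditional downcast of q, k, and v to torch.float8_e4m3fn, so the FA3 kernel now runs in the inputs' original dtype. For context, here is a minimal sketch of what fa3.py plausibly looks like after this change; the imports (import torch, from kernels import get_kernel) and the register_fake body are assumptions, since the diff only shows the hunk above:

import torch
from kernels import get_kernel

# Load the vLLM FlashAttention-3 kernel from the Hugging Face Hub.
_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Run the kernel in the inputs' original dtype; the fp8
    # (torch.float8_e4m3fn) cast was removed by this commit.
    outputs, lse = _flash_attn_func(q, k, v)
    return outputs


@flash_attn_func.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Hypothetical fake (meta) implementation, not shown in the diff:
    # mirror the query's shape and dtype for torch.compile tracing
    # without running the kernel.
    return torch.empty_like(q)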