	bf16 fa3
fa3.py CHANGED
    
@@ -10,8 +10,7 @@ _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
 
 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
 def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-
-    outputs, lse = _flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype))
+    outputs, lse = _flash_attn_func(q, k, v)
     return outputs
 
 @flash_attn_func.register_fake
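For context, here is a minimal sketch of what the patched registration could look like end to end. The imports and the fake implementation are assumptions filled in around the hunk above, not the Space's actual file; the change itself just drops the .to(dtype) casts, presumably because q/k/v already arrive in bf16 (per the "bf16 fa3" commit message).

import torch
from kernels import get_kernel  # Hugging Face `kernels` package (assumed import)

# Pull the FlashAttention-3 kernel from the Hub, as in the hunk context above.
_flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # After this commit the tensors are passed through as-is (expected bf16)
    # instead of being cast with .to(dtype) on every call.
    outputs, lse = _flash_attn_func(q, k, v)
    return outputs


@flash_attn_func.register_fake
def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Illustrative fake (meta) implementation: the attention output is assumed
    # to match the query's shape and dtype, so tracing never runs the kernel.
    return torch.empty_like(q)

With the fake registered, torch.compile can trace through flash_attn_func using only shape and dtype information, while eager execution still dispatches to the Hub kernel.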