# flash-attn/tests/test_flash_attn.py
import torch
import flash_attn
# TODO: improve and add more tests
def test_flash_attn():
    # FlashAttention kernels run on CUDA and expect fp16/bf16 tensors in
    # (batch, seqlen, nheads, headdim) layout.
    q = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    out = torch.empty(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    # One fp32 slope per head; zeros make ALiBi a no-op, which avoids the
    # NaNs that uninitialized (torch.empty) slopes could inject.
    alibi_slopes = torch.zeros(4, device="cuda", dtype=torch.float32)
    p_dropout = 0.1
    softmax_scale = 1.0
    is_causal = False
    # -1 disables the sliding-window limits; (0, 0) would restrict each
    # query to attending only to its own position.
    window_size_left = -1
    window_size_right = -1
    softcap = 0.0
    return_softmax = False
    gen = None
    # Assuming the standard flash-attn fwd binding, mha_fwd returns
    # (out, softmax_lse, p, rng_state); the attention output is the first
    # element, not a bare tensor.
    out, softmax_lse, p, rng_state = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
    assert out.shape == (2, 5, 4, 8)
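

# A minimal correctness sketch toward the TODO above, assuming mha_fwd
# follows the standard flash-attn fwd binding used in the shape test: with
# dropout, ALiBi, windowing, and softcap all disabled, the output should
# match PyTorch's reference scaled_dot_product_attention. Tolerances are
# loose to allow for fp16 accumulation differences.
def test_flash_attn_matches_sdpa():
    torch.manual_seed(0)
    batch, seqlen, nheads, headdim = 2, 5, 4, 8
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    out = torch.empty_like(q)

    softmax_scale = headdim ** -0.5
    out, softmax_lse, p, rng_state = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        None,  # alibi_slopes: disabled
        0.0,   # p_dropout: off so the comparison is deterministic
        softmax_scale,
        False,  # is_causal
        -1,     # window_size_left: -1 disables the sliding window
        -1,     # window_size_right
        0.0,    # softcap
        False,  # return_softmax
        None,   # gen
    )

    # SDPA expects (batch, nheads, seqlen, headdim), so transpose in and out.
    ref = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2),
        k.transpose(1, 2),
        v.transpose(1, 2),
        scale=softmax_scale,
    ).transpose(1, 2)
    torch.testing.assert_close(out, ref, atol=2e-3, rtol=2e-3)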