import torch
import flash_attn
# TODO: improve and add more tests
def test_flash_attn():
    """Smoke-test the raw ``flash_attn.mha_fwd`` binding on a tiny input.

    Tensors follow the (batch, seqlen, num_heads, head_dim) layout used
    throughout this call: (2, 5, 4, 8).

    NOTE(review): flash-attn kernels generally require CUDA tensors in
    fp16/bf16; the CPU float32 inputs built here may be rejected by the
    binding -- confirm against the installed flash-attn build.
    """
    batch, seqlen, num_heads, head_dim = 2, 5, 4, 8

    # Seed so failures are reproducible run-to-run.
    torch.manual_seed(0)
    q = torch.randn(batch, seqlen, num_heads, head_dim)
    k = torch.randn(batch, seqlen, num_heads, head_dim)
    v = torch.randn(batch, seqlen, num_heads, head_dim)
    # Pre-allocated output buffer passed to the kernel.
    out = torch.empty(batch, seqlen, num_heads, head_dim)
    # One ALiBi slope per head; contents are uninitialized (torch.empty).
    alibi_slopes = torch.empty(num_heads)

    p_dropout = 0.1
    softmax_scale = 1.0
    is_causal = False
    # NOTE(review): flash-attn conventionally uses -1 to disable local
    # (sliding-window) attention; 0/0 restricts each query to attending
    # only its own position -- confirm which behavior is intended here.
    window_size_left = 0
    window_size_right = 0
    softcap = 0.0
    return_softmax = False
    gen = None  # optional RNG generator for dropout

    out = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
    # NOTE(review): some flash-attn versions return a tuple
    # (out, softmax_lse, ...) rather than a single tensor -- verify
    # before relying on .shape of the raw return value.
    assert out.shape == (batch, seqlen, num_heads, head_dim)