import torch

import flash_attn


# TODO: improve and add more tests
def test_flash_attn():
    # flash_attn kernels operate on fp16/bf16 CUDA tensors laid out as
    # (batch_size, seqlen, nheads, headdim).
    q = torch.randn(2, 5, 4, 8, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 5, 4, 8, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 5, 4, 8, dtype=torch.float16, device="cuda")
    out = torch.empty_like(q)
    # One ALiBi slope per head; zeros add no positional bias (torch.empty
    # would leave the slopes uninitialized).
    alibi_slopes = torch.zeros(4, dtype=torch.float32, device="cuda")
    p_dropout = 0.1
    softmax_scale = 1.0
    is_causal = False
    # -1 disables the sliding-window restriction on each side.
    window_size_left = -1
    window_size_right = -1
    softcap = 0.0
    return_softmax = False
    gen = None
    ret = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
    # mha_fwd may return auxiliary tensors (e.g. the softmax LSE) alongside
    # the attention output; keep only the output tensor.
    out = ret[0] if isinstance(ret, (list, tuple)) else ret
    assert out.shape == (2, 5, 4, 8)
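

# A hedged reference check, as invited by the TODO above: assuming mha_fwd
# computes standard scaled-dot-product attention, its output with dropout
# disabled, zero ALiBi slopes, and no windowing should match a naive PyTorch
# implementation. The fp16 tolerances below are assumptions, not documented
# guarantees.
def test_flash_attn_matches_reference():
    torch.manual_seed(0)
    q = torch.randn(2, 5, 4, 8, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 5, 4, 8, dtype=torch.float16, device="cuda")
    v = torch.randn(2, 5, 4, 8, dtype=torch.float16, device="cuda")
    out = torch.empty_like(q)
    alibi_slopes = torch.zeros(4, dtype=torch.float32, device="cuda")
    softmax_scale = 1.0
    ret = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        0.0,  # p_dropout: disabled so the output is deterministic
        softmax_scale,
        False,  # is_causal
        -1,  # window_size_left: -1 disables local attention
        -1,  # window_size_right
        0.0,  # softcap
        False,  # return_softmax
        None,  # gen
    )
    out = ret[0] if isinstance(ret, (list, tuple)) else ret
    # Naive fp32 reference: softmax(q @ k^T * scale) @ v, per batch and head.
    # permute(0, 2, 1, 3) moves heads before seqlen so matmul batches per head.
    qf, kf, vf = (t.float().permute(0, 2, 1, 3) for t in (q, k, v))
    scores = torch.matmul(qf, kf.transpose(-2, -1)) * softmax_scale
    ref = torch.matmul(torch.softmax(scores, dim=-1), vf).permute(0, 2, 1, 3)
    torch.testing.assert_close(out.float(), ref, atol=2e-3, rtol=2e-3)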