// Declaration of mha_fwd. Only the prototype is visible here; the definition
// lives elsewhere, so the notes below describe the signature, not the behavior.
// Judging by the name, this is presumably the forward pass of multi-head
// attention -- TODO confirm against the definition.
//
// NOTE(review): every line of this declaration is wrapped in `| ... | |`,
// which looks like table/diff-extraction residue rather than C++ -- confirm
// against the real header before compiling.
//
// Parameters (shapes are taken from the inline comments on each line):
//   q             - batch_size x seqlen_q x num_heads   x round_multiple(head_size, 8)
//   k, v          - batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
//   out_          - optional tensor with the same shape as q; presumably a
//                   caller-provided output buffer -- TODO confirm.
//   alibi_slopes_ - optional tensor of shape num_heads or batch_size x num_heads.
//   p_dropout     - presumably the dropout probability -- TODO confirm.
//   softmax_scale - presumably the scale applied to the scores before softmax
//                   -- TODO confirm.
//   is_causal     - presumably selects causal masking -- TODO confirm. It is the
//                   only scalar parameter not declared const here.
//   window_size_left, window_size_right
//                 - presumably local-attention window bounds; sentinel values
//                   (if any) are not visible here -- TODO confirm.
//   softcap       - presumably a soft cap on attention scores -- TODO confirm.
//   return_softmax- presumably requests the softmax probabilities as an extra
//                   output -- TODO confirm.
//   gen_          - optional random generator. NOTE(review): passed by value,
//                   unlike the other optionals, which are const references.
//
// Returns a std::vector<at::Tensor>; the number and order of the tensors are
// not visible from this declaration -- verify against the definition.
//
// NOTE(review): the optional parameters are spelled torch::Tensor while q/k/v
// use at::Tensor; likely the same type, but worth making consistent.
| std::vector<at::Tensor> | |
| mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) | |
| const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) | |
| const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) | |
| const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) | |
| const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads | |
| const double p_dropout, | |
| const double softmax_scale, | |
| bool is_causal, | |
| const int64_t window_size_left, | |
| const int64_t window_size_right, | |
| const double softcap, | |
| const bool return_softmax, | |
| const c10::optional<at::Generator> gen_); |