std::vector<at::Tensor> | |
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) | |
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) | |
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) | |
const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) | |
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads | |
const double p_dropout, | |
const double softmax_scale, | |
bool is_causal, | |
const int64_t window_size_left, | |
const int64_t window_size_right, | |
const double softcap, | |
const bool return_softmax, | |
const c10::optional<at::Generator> gen_); |