#pragma once

#include <torch/extension.h>

std::vector<torch::Tensor>
mha_fwd(const torch::Tensor &q,                            // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        const torch::Tensor &k,                            // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        const torch::Tensor &v,                            // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
        const c10::optional<torch::Tensor> &out_,          // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
        const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const double p_dropout,
        const double softmax_scale,
        bool is_causal,
        const int64_t window_size_left,
        const int64_t window_size_right,
        const double softcap,
        const bool return_softmax,
        const c10::optional<at::Generator> gen_);

std::vector<torch::Tensor>
mha_varlen_fwd(const torch::Tensor &q,                     // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const torch::Tensor &k,                     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const torch::Tensor &v,                     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const c10::optional<torch::Tensor> &out_,   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const torch::Tensor &cu_seqlens_q,          // b+1
               const torch::Tensor &cu_seqlens_k,          // b+1
               const c10::optional<torch::Tensor> &seqused_k_,    // b. If given, only this many elements of each batch element's keys are used.
               const c10::optional<torch::Tensor> &leftpad_k_,    // batch_size
               const c10::optional<torch::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int64_t max_seqlen_q,
               const int64_t max_seqlen_k,
               const double p_dropout,
               const double softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               const int64_t window_size_left,
               const int64_t window_size_right,
               const double softcap,
               const bool return_softmax,
               const c10::optional<at::Generator> gen_);
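
// Usage sketch for the varlen entry points: cu_seqlens_q / cu_seqlens_k hold
// cumulative sequence lengths, i.e. [0, s_0, s_0 + s_1, ..., total], b+1
// entries, typically int32 on the same device as q. A minimal, illustrative
// way to build them from a hypothetical `seqlens` tensor of shape [b]
// (`seqlens` is not part of this API):
//
//   auto zero = torch::zeros({1}, torch::dtype(torch::kInt32).device(seqlens.device()));
//   auto cu_seqlens = torch::cat({zero, seqlens.cumsum(0).to(torch::kInt32)});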
std::vector<torch::Tensor>
mha_bwd(const torch::Tensor &dout,                         // batch_size x seqlen_q x num_heads x multiple_of(head_size_og, 8)
        const torch::Tensor &q,                            // batch_size x seqlen_q x num_heads x head_size
        const torch::Tensor &k,                            // batch_size x seqlen_k x num_heads_k x head_size
        const torch::Tensor &v,                            // batch_size x seqlen_k x num_heads_k x head_size
        const torch::Tensor &out,                          // batch_size x seqlen_q x num_heads x head_size
        const torch::Tensor &softmax_lse,                  // b x h x seqlen_q
        const c10::optional<torch::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
        const c10::optional<torch::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
        const c10::optional<torch::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
        const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const double p_dropout,                            // probability to drop
        const double softmax_scale,
        const bool is_causal,
        const int64_t window_size_left,
        const int64_t window_size_right,
        const double softcap,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        const c10::optional<torch::Tensor> &rng_state);

std::vector<torch::Tensor>
mha_varlen_bwd(const torch::Tensor &dout,                  // total_q x num_heads x multiple_of(head_size_og, 8), total_q := \sum_{i=0}^{b} s_i
               const torch::Tensor &q,                     // total_q x num_heads x head_size
               const torch::Tensor &k,                     // total_k x num_heads_k x head_size
               const torch::Tensor &v,                     // total_k x num_heads_k x head_size
               const torch::Tensor &out,                   // total_q x num_heads x head_size
               const torch::Tensor &softmax_lse,           // b x h x seqlen_q
               const c10::optional<torch::Tensor> &dq_,    // total_q x num_heads x head_size
               const c10::optional<torch::Tensor> &dk_,    // total_k x num_heads_k x head_size
               const c10::optional<torch::Tensor> &dv_,    // total_k x num_heads_k x head_size
               const torch::Tensor &cu_seqlens_q,          // batch_size + 1
               const torch::Tensor &cu_seqlens_k,          // batch_size + 1
               const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               const int64_t max_seqlen_q,
               const int64_t max_seqlen_k,
               const double p_dropout,
               const double softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               const int64_t window_size_left,
               const int64_t window_size_right,
               const double softcap,
               const bool deterministic,
               c10::optional<at::Generator> gen_,
               const c10::optional<torch::Tensor> &rng_state);

std::vector<torch::Tensor>
mha_fwd_kvcache(const torch::Tensor &q,                    // batch_size x seqlen_q x num_heads x head_size
                const torch::Tensor &kcache,               // batch_size_c x seqlen_k x num_heads_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                const torch::Tensor &vcache,               // batch_size_c x seqlen_k x num_heads_k x head_size, or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                const c10::optional<torch::Tensor> &k_,               // batch_size x seqlen_knew x num_heads_k x head_size
                const c10::optional<torch::Tensor> &v_,               // batch_size x seqlen_knew x num_heads_k x head_size
                const c10::optional<torch::Tensor> &seqlens_k_,       // batch_size
                const c10::optional<torch::Tensor> &rotary_cos_,      // seqlen_ro x (rotary_dim / 2)
                const c10::optional<torch::Tensor> &rotary_sin_,      // seqlen_ro x (rotary_dim / 2)
                const c10::optional<torch::Tensor> &cache_batch_idx_, // indices to index into the KV cache
                const c10::optional<torch::Tensor> &leftpad_k_,       // batch_size
                const c10::optional<torch::Tensor> &block_table_,     // batch_size x max_num_blocks_per_seq
                const c10::optional<torch::Tensor> &alibi_slopes_,    // num_heads or batch_size x num_heads
                const c10::optional<torch::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                const double softmax_scale,
                bool is_causal,
                const int64_t window_size_left,
                const int64_t window_size_right,
                const double softcap,
                bool is_rotary_interleaved,
                const int64_t num_splits);
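
// Illustrative decode-step sketch for mha_fwd_kvcache. Assumptions (not part
// of this header): `q`, `kcache`, `vcache`, `k_new`, `v_new`, and `seqlens_k`
// are hypothetical pre-built CUDA tensors with the shapes documented above;
// the cache is contiguous (no block_table), and rotary embedding / ALiBi are
// unused. window_size_{left,right} = -1 disables local attention, softcap =
// 0.0 disables logit softcapping, and num_splits = 0 leaves the split-KV
// choice to a heuristic.
//
//   auto outs = mha_fwd_kvcache(q, kcache, vcache,
//                               /*k_=*/k_new, /*v_=*/v_new,
//                               /*seqlens_k_=*/seqlens_k,
//                               /*rotary_cos_=*/c10::nullopt, /*rotary_sin_=*/c10::nullopt,
//                               /*cache_batch_idx_=*/c10::nullopt,
//                               /*leftpad_k_=*/c10::nullopt, /*block_table_=*/c10::nullopt,
//                               /*alibi_slopes_=*/c10::nullopt, /*out_=*/c10::nullopt,
//                               /*softmax_scale=*/1.0 / std::sqrt((double)q.size(3)),
//                               /*is_causal=*/true,
//                               /*window_size_left=*/-1, /*window_size_right=*/-1,
//                               /*softcap=*/0.0, /*is_rotary_interleaved=*/false,
//                               /*num_splits=*/0);
//   torch::Tensor out = outs[0];  // batch_size x seqlen_q x num_heads x head_size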