#pragma once

#include <torch/torch.h>
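
// Multi-head attention forward pass over a fixed-length batch.
// q, k, v are the query/key/value tensors; out_ optionally supplies a
// preallocated output buffer. Supports dropout (p_dropout, seeded via gen_),
// causal and sliding-window masking (window_size_left/right), ALiBi slopes,
// and logit softcapping. Returns the attention output together with auxiliary
// tensors such as the softmax log-sum-exp (and, when return_softmax is set,
// the softmax itself).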
std::vector<torch::Tensor>
mha_fwd(const torch::Tensor &q,
        const torch::Tensor &k,
        const torch::Tensor &v,
        const c10::optional<torch::Tensor> &out_,
        const c10::optional<torch::Tensor> &alibi_slopes_,
        const double p_dropout,
        const double softmax_scale,
        bool is_causal,
        const int64_t window_size_left,
        const int64_t window_size_right,
        const double softcap,
        const bool return_softmax,
        const c10::optional<at::Generator> gen_);
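
// Variable-length (packed) forward pass: sequences of different lengths are
// concatenated and delimited by the cumulative-length tensors cu_seqlens_q /
// cu_seqlens_k, with max_seqlen_q / max_seqlen_k giving the longest sequence
// in the batch. seqused_k_, leftpad_k_, and block_table_ optionally restrict,
// left-pad, or page the key/value storage.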
std::vector<torch::Tensor>
mha_varlen_fwd(
    const torch::Tensor &q,
    const torch::Tensor &k,
    const torch::Tensor &v,
    const c10::optional<torch::Tensor> &out_,
    const torch::Tensor &cu_seqlens_q,
    const torch::Tensor &cu_seqlens_k,
    const c10::optional<torch::Tensor> &seqused_k_,
    const c10::optional<torch::Tensor> &leftpad_k_,
    const c10::optional<torch::Tensor> &block_table_,
    const c10::optional<torch::Tensor> &alibi_slopes_,
    const int64_t max_seqlen_q,
    const int64_t max_seqlen_k,
    const double p_dropout,
    const double softmax_scale,
    const bool zero_tensors,
    const bool is_causal,
    const int64_t window_size_left,
    const int64_t window_size_right,
    const double softcap,
    const bool return_softmax,
    const c10::optional<at::Generator> gen_);
std::vector<torch::Tensor>
mha_bwd(const torch::Tensor &dout,
        const torch::Tensor &q,
        const torch::Tensor &k,
        const torch::Tensor &v,
        const torch::Tensor &out,
        const torch::Tensor &softmax_lse,
        const c10::optional<torch::Tensor> &dq_,
        const c10::optional<torch::Tensor> &dk_,
        const c10::optional<torch::Tensor> &dv_,
        const c10::optional<torch::Tensor> &alibi_slopes_,
        const double p_dropout,
        const double softmax_scale,
        const bool is_causal,
        const int64_t window_size_left,
        const int64_t window_size_right,
        const double softcap,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        const c10::optional<torch::Tensor> &rng_state);
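
// Backward pass for mha_varlen_fwd; uses the same cu_seqlens_q / cu_seqlens_k
// packing convention as the variable-length forward and the same gradient
// arguments (dout, dq_/dk_/dv_, rng_state) as mha_bwd.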
std::vector<torch::Tensor>
mha_varlen_bwd(
    const torch::Tensor &dout,
    const torch::Tensor &q,
    const torch::Tensor &k,
    const torch::Tensor &v,
    const torch::Tensor &out,
    const torch::Tensor &softmax_lse,
    const c10::optional<torch::Tensor> &dq_,
    const c10::optional<torch::Tensor> &dk_,
    const c10::optional<torch::Tensor> &dv_,
    const torch::Tensor &cu_seqlens_q,
    const torch::Tensor &cu_seqlens_k,
    const c10::optional<torch::Tensor> &alibi_slopes_,
    const int64_t max_seqlen_q,
    const int64_t max_seqlen_k,
    const double p_dropout,
    const double softmax_scale,
    const bool zero_tensors,
    const bool is_causal,
    const int64_t window_size_left,
    const int64_t window_size_right,
    const double softcap,
    const bool deterministic,
    c10::optional<at::Generator> gen_,
    const c10::optional<torch::Tensor> &rng_state);
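
// Forward pass against an existing key/value cache (incremental decoding).
// kcache/vcache hold the cached keys and values; k_/v_ optionally carry the
// new key/value tokens for this step, with seqlens_k_ giving the current
// cache lengths. Also supports rotary embeddings (rotary_cos_/rotary_sin_,
// with layout selected by is_rotary_interleaved), per-sample cache indexing
// (cache_batch_idx_), a paged-KV block table, and splitting the KV reduction
// across num_splits partitions. This path takes no dropout.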
std::vector<torch::Tensor>
mha_fwd_kvcache(
    const torch::Tensor &q,
    const torch::Tensor &kcache,
    const torch::Tensor &vcache,
    const c10::optional<torch::Tensor> &k_,
    const c10::optional<torch::Tensor> &v_,
    const c10::optional<torch::Tensor> &seqlens_k_,
    const c10::optional<torch::Tensor> &rotary_cos_,
    const c10::optional<torch::Tensor> &rotary_sin_,
    const c10::optional<torch::Tensor> &cache_batch_idx_,
    const c10::optional<torch::Tensor> &leftpad_k_,
    const c10::optional<torch::Tensor> &block_table_,
    const c10::optional<torch::Tensor> &alibi_slopes_,
    const c10::optional<torch::Tensor> &out_,
    const double softmax_scale,
    bool is_causal,
    const int64_t window_size_left,
    const int64_t window_size_right,
    const double softcap,
    bool is_rotary_interleaved,
    const int64_t num_splits);