kernel
flash-attn / torch-ext /torch_binding.h
drbh
feat: include source and enable build
a7165c8
raw
history blame
898 Bytes
#pragma once
#include <torch/torch.h>
/**
 * @brief Declaration of the `mha_fwd` entry point exposed by this binding header.
 *
 * Only the prototype lives here; the definition is provided elsewhere.
 * NOTE(review): the name suggests "multi-head attention, forward pass" --
 * confirm against the definition.
 *
 * Tensor layouts, as annotated by the original author:
 *   - q             : batch_size x seqlen_q x num_heads   x round_multiple(head_size, 8)
 *   - k             : batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
 *   - v             : batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
 *   - out_          : batch_size x seqlen_q x num_heads   x round_multiple(head_size, 8)
 *   - alibi_slopes_ : num_heads, or batch_size x num_heads
 *
 * @param q                  Query tensor (layout above).
 * @param k                  Key tensor (layout above).
 * @param v                  Value tensor (layout above).
 * @param out_               Optional tensor with the output layout above.
 *                           NOTE(review): presumably a caller-supplied output
 *                           buffer -- confirm against the definition.
 * @param alibi_slopes_      Optional ALiBi slopes tensor (layout above).
 * @param p_dropout          NOTE(review): presumably the dropout probability -- confirm.
 * @param softmax_scale      NOTE(review): presumably the scale applied to the
 *                           attention scores before softmax -- confirm.
 * @param is_causal          NOTE(review): presumably enables causal masking -- confirm.
 *                           Unlike the other by-value parameters it is not
 *                           declared `const`.
 * @param window_size_left   NOTE(review): presumably the left extent of a local
 *                           attention window -- confirm, including any sentinel value.
 * @param window_size_right  NOTE(review): presumably the right extent of a local
 *                           attention window -- confirm, including any sentinel value.
 * @param softcap            NOTE(review): presumably a soft-capping value for the
 *                           attention scores -- confirm.
 * @param return_softmax     NOTE(review): presumably requests the softmax result
 *                           in addition to the output -- confirm.
 * @param gen_               Optional random generator, taken by value.
 *
 * @return A vector of tensors; its length and ordering are determined by the
 *         definition and are not visible from this header.
 */
std::vector<at::Tensor> mha_fwd(
    const at::Tensor &q,
    const at::Tensor &k,
    const at::Tensor &v,
    const c10::optional<torch::Tensor> &out_,
    const c10::optional<torch::Tensor> &alibi_slopes_,
    const double p_dropout,
    const double softmax_scale,
    bool is_causal,
    const int64_t window_size_left,
    const int64_t window_size_right,
    const double softcap,
    const bool return_softmax,
    const c10::optional<at::Generator> gen_);