// Author: Mohamed Mekkouri — "new builds" (commit 9ffd725)
// NOTE(review): these three lines were bare text before `#pragma once` (a paste
// of commit metadata) and would not compile; converted to a comment.
#pragma once
#include <torch/torch.h>
/// Matmul of f32 activations against bf16 weights, with bf16 bias, writing f32
/// output. Semantics inferred from the naming convention (f32_bf16w) — confirm
/// against the kernel implementation.
/// @param input            activation tensor (presumably f32, num_tokens x num_cols — TODO confirm layout)
/// @param weight_bf16      weight tensor in bf16
/// @param bias_bf16        bias tensor in bf16
/// @param output           destination tensor, written in place
/// @param num_tokens       number of rows (tokens) in the input
/// @param num_cols         input feature dimension
/// @param num_rows         output feature dimension
/// @param threadgroup_size dispatch threadgroup size for the underlying (Metal?) kernel
void f32_bf16w_matmul_torch(const at::Tensor& input,
const at::Tensor& weight_bf16,
const at::Tensor& bias_bf16,
at::Tensor& output,
int64_t num_tokens,
int64_t num_cols,
int64_t num_rows,
int64_t threadgroup_size);
/// Embedding lookup: gathers rows of a bf16 embedding table by token id into an
/// f32 output tensor (inferred from the name — verify against the kernel).
/// @param token_ids        integer tensor of token indices
/// @param weight_bf16      embedding table in bf16
/// @param output           f32 destination tensor, written in place
/// @param threadgroup_size dispatch threadgroup size for the underlying kernel
void bf16_f32_embeddings_torch(const at::Tensor& token_ids,
const at::Tensor& weight_bf16,
at::Tensor& output,
int64_t threadgroup_size);
/// RMSNorm over f32 activations with a bf16 gain/weight vector (inferred from
/// the name — verify against the kernel).
/// @param input       f32 activation tensor
/// @param weight_bf16 per-channel scale in bf16
/// @param output      f32 destination tensor, written in place
/// @param epsilon     numerical-stability term added before the reciprocal sqrt
///                    (presumably — TODO confirm placement in the kernel)
void f32_bf16w_rmsnorm_torch(const at::Tensor& input,
const at::Tensor& weight_bf16,
at::Tensor& output,
double epsilon);
/// Dense matmul specialized for the QKV projection (shapes presumably baked
/// into the kernel, since no dimension arguments are taken — TODO confirm).
/// @param input       f32 activation tensor
/// @param weight_bf16 QKV projection weights in bf16
/// @param bias_bf16   QKV projection bias in bf16
/// @param output      f32 destination tensor, written in place
void f32_bf16w_dense_matmul_qkv_torch(const at::Tensor& input,
const at::Tensor& weight_bf16,
const at::Tensor& bias_bf16,
at::Tensor& output);
/// Dense matmul specialized for the attention output projection (same
/// fixed-shape pattern as the QKV variant — no dimension arguments).
/// @param input       f32 activation tensor
/// @param weight_bf16 output-projection weights in bf16
/// @param bias_bf16   output-projection bias in bf16
/// @param output      f32 destination tensor, written in place
void f32_bf16w_dense_matmul_attn_output_torch(const at::Tensor& input,
const at::Tensor& weight_bf16,
const at::Tensor& bias_bf16,
at::Tensor& output);
/// Dense matmul specialized for the MLP gate projection (same fixed-shape
/// pattern as the other dense_matmul variants — no dimension arguments).
/// @param input       f32 activation tensor
/// @param weight_bf16 gate-projection weights in bf16
/// @param bias_bf16   gate-projection bias in bf16
/// @param output      f32 destination tensor, written in place
void f32_bf16w_dense_matmul_mlp_gate_torch(const at::Tensor& input,
const at::Tensor& weight_bf16,
const at::Tensor& bias_bf16,
at::Tensor& output);
/// Applies rotary position embeddings (RoPE) with YaRN-style scaling to f32
/// activations, in place. Parameter roles inferred from names — confirm the
/// exact YaRN formula against the kernel.
/// @param activations         f32 Q/K activations, rotated in place
/// @param rope_base           base frequency for the rotary embedding (theta)
/// @param interpolation_scale position-interpolation factor for context extension
/// @param yarn_offset         YaRN ramp offset
/// @param yarn_scale          YaRN ramp scale
/// @param yarn_multiplier     YaRN attention/magnitude multiplier
/// @param num_tokens          number of tokens in `activations`
/// @param num_q_heads         number of query heads
/// @param num_kv_heads        number of key/value heads
/// @param attn_head_dim       per-head dimension
/// @param token_offset        absolute position of the first token (for decode steps)
/// @param threadgroup_size    dispatch threadgroup size for the underlying kernel
void f32_rope_torch(at::Tensor& activations,
double rope_base,
double interpolation_scale,
double yarn_offset,
double yarn_scale,
double yarn_multiplier,
int64_t num_tokens,
int64_t num_q_heads,
int64_t num_kv_heads,
int64_t attn_head_dim,
int64_t token_offset,
int64_t threadgroup_size);
/// Fused QKV projection: matmul (f32 activations x bf16 weights + bf16 bias),
/// presumably followed by RoPE (it takes the full set of rope/yarn parameters)
/// and a write of K/V into the cache — TODO confirm fusion order in the kernel.
/// @param input                f32 activation tensor (num_tokens x num_cols — TODO confirm)
/// @param weight_bf16          fused QKV weights in bf16
/// @param bias_bf16            fused QKV bias in bf16
/// @param output               f32 destination for the projected Q (and K/V? — confirm)
/// @param kv_cache             key/value cache tensor, updated in place
/// @param kv_cache_offset_bytes byte offset into `kv_cache` for this layer/slot
/// @param num_tokens           number of input tokens
/// @param num_cols             input feature dimension
/// @param num_q_heads          number of query heads
/// @param num_kv_heads         number of key/value heads
/// @param attn_head_dim        per-head dimension
/// @param token_offset         absolute position of the first token (RoPE phase + cache slot)
/// @param max_tokens           cache capacity in tokens (presumably the KV stride)
/// @param rope_base            RoPE base frequency (theta)
/// @param interpolation_scale  position-interpolation factor
/// @param yarn_offset          YaRN ramp offset
/// @param yarn_scale           YaRN ramp scale
/// @param yarn_multiplier      YaRN multiplier
/// @param threadgroup_size     dispatch threadgroup size for the underlying kernel
void f32_bf16w_matmul_qkv_torch(const at::Tensor& input,
const at::Tensor& weight_bf16,
const at::Tensor& bias_bf16,
at::Tensor& output,
at::Tensor& kv_cache,
int64_t kv_cache_offset_bytes,
int64_t num_tokens,
int64_t num_cols,
int64_t num_q_heads,
int64_t num_kv_heads,
int64_t attn_head_dim,
int64_t token_offset,
int64_t max_tokens,
double rope_base,
double interpolation_scale,
double yarn_offset,
double yarn_scale,
double yarn_multiplier,
int64_t threadgroup_size);
/// Scaled dot-product attention over f32 Q against a packed KV buffer, with an
/// optional sliding window. All buffers are addressed with explicit byte
/// offsets, suggesting sub-views into larger pre-allocated tensors — confirm.
/// @param q                   query tensor (f32)
/// @param q_offset_bytes      byte offset of the first query element
/// @param kv                  packed key/value buffer
/// @param kv_offset_bytes     byte offset into the KV buffer
/// @param s_bf16              bf16 tensor `s` — role unclear from the declaration
///                            (attention sinks? per-head scales?) — TODO confirm
/// @param s_offset_bytes      byte offset into `s_bf16`
/// @param output              f32 destination tensor, written in place
/// @param output_offset_bytes byte offset of the first output element
/// @param window              attention window size (sliding-window attention, presumably)
/// @param kv_stride           stride between tokens (or heads?) in the KV buffer — confirm units
/// @param num_q_tokens        number of query tokens
/// @param num_kv_tokens       number of key/value tokens attended over
/// @param num_q_heads         number of query heads
/// @param num_kv_heads        number of key/value heads (GQA when < num_q_heads)
/// @param head_dim            per-head dimension
void f32_sdpa_torch(const at::Tensor& q,
int64_t q_offset_bytes,
const at::Tensor& kv,
int64_t kv_offset_bytes,
const at::Tensor& s_bf16,
int64_t s_offset_bytes,
at::Tensor& output,
int64_t output_offset_bytes,
int64_t window,
int64_t kv_stride,
int64_t num_q_tokens,
int64_t num_kv_tokens,
int64_t num_q_heads,
int64_t num_kv_heads,
int64_t head_dim);
/// Top-k expert selection for MoE routing: picks the `num_active_experts`
/// highest-scoring experts per token (inferred from names — confirm whether
/// scores are returned raw or renormalized, e.g. softmaxed).
/// @param scores             f32 router scores, per token per expert
/// @param expert_ids         output: selected expert indices, written in place
/// @param expert_scores      output: scores of the selected experts, written in place
/// @param num_tokens         number of tokens
/// @param num_experts        total number of experts
/// @param num_active_experts experts selected per token (k)
void f32_topk_torch(const at::Tensor& scores,
at::Tensor& expert_ids,
at::Tensor& expert_scores,
int64_t num_tokens,
int64_t num_experts,
int64_t num_active_experts);
/// Builds routing metadata from top-k results: per-expert start offsets and
/// each (token, expert) pair's position within its expert's bucket —
/// presumably a counting/prefix-sum pass consumed by the scatter kernel below.
/// @param expert_ids           selected expert indices from the top-k step
/// @param expert_scores        selected expert scores from the top-k step
/// @param expert_offsets       output: per-expert offsets, written in place
/// @param intra_expert_offsets output: per-assignment rank within its expert, written in place
/// @param num_tokens           number of tokens
/// @param num_experts          total number of experts
void expert_routing_metadata_torch(const at::Tensor& expert_ids,
const at::Tensor& expert_scores,
at::Tensor& expert_offsets,
at::Tensor& intra_expert_offsets,
int64_t num_tokens,
int64_t num_experts);
/// Scatters token activations into expert-grouped order using the routing
/// metadata, so each expert sees a contiguous batch (inferred from names and
/// the metadata kernel above — confirm whether scores are applied here or at gather).
/// @param input                f32 token activations
/// @param expert_ids           selected expert indices per token
/// @param expert_scores        selected expert scores per token
/// @param expert_offsets       per-expert start offsets (from expert_routing_metadata_torch)
/// @param intra_expert_offsets per-assignment rank within its expert
/// @param output               f32 expert-grouped destination, written in place
/// @param num_channels         feature dimension per token
/// @param num_tokens           number of tokens
/// @param num_active_experts   experts per token (k)
void f32_scatter_torch(const at::Tensor& input,
const at::Tensor& expert_ids,
const at::Tensor& expert_scores,
const at::Tensor& expert_offsets,
const at::Tensor& intra_expert_offsets,
at::Tensor& output,
int64_t num_channels,
int64_t num_tokens,
int64_t num_active_experts);
/// Same signature as f32_bf16w_matmul_torch, with "_add" presumably meaning the
/// result is accumulated into `output` (residual add) rather than overwriting
/// it — TODO confirm against the kernel.
/// @param input            f32 activation tensor (num_tokens x num_cols — TODO confirm)
/// @param weight_bf16      weight tensor in bf16
/// @param bias_bf16        bias tensor in bf16
/// @param output           f32 destination tensor, updated in place
/// @param num_tokens       number of rows (tokens) in the input
/// @param num_cols         input feature dimension
/// @param num_rows         output feature dimension
/// @param threadgroup_size dispatch threadgroup size for the underlying kernel
void f32_bf16w_matmul_add_torch(const at::Tensor& input,
const at::Tensor& weight_bf16,
const at::Tensor& bias_bf16,
at::Tensor& output,
int64_t num_tokens,
int64_t num_cols,
int64_t num_rows,
int64_t threadgroup_size);