#pragma once

#include <ATen/ATen.h>  // assumed include: these declarations need at::Tensor

// Matrix multiply: f32 activations against a bf16 weight matrix, with bf16 bias.
void f32_bf16w_matmul_torch(
    const at::Tensor& input, const at::Tensor& weight_bf16,
    const at::Tensor& bias_bf16, at::Tensor& output,
    int64_t num_tokens, int64_t num_cols, int64_t num_rows,
    int64_t threadgroup_size);

// Embedding lookup: gathers bf16 weight rows by token id into an f32 output.
void bf16_f32_embeddings_torch(
    const at::Tensor& token_ids, const at::Tensor& weight_bf16,
    at::Tensor& output, int64_t threadgroup_size);

// RMSNorm over f32 activations with a bf16 gain vector.
void f32_bf16w_rmsnorm_torch(
    const at::Tensor& input, const at::Tensor& weight_bf16,
    at::Tensor& output, double epsilon);

// Dense-matmul specializations for the QKV, attention-output, and MLP-gate
// projections; shapes are implied by the projection rather than passed in.
void f32_bf16w_dense_matmul_qkv_torch(
    const at::Tensor& input, const at::Tensor& weight_bf16,
    const at::Tensor& bias_bf16, at::Tensor& output);

void f32_bf16w_dense_matmul_attn_output_torch(
    const at::Tensor& input, const at::Tensor& weight_bf16,
    const at::Tensor& bias_bf16, at::Tensor& output);

void f32_bf16w_dense_matmul_mlp_gate_torch(
    const at::Tensor& input, const at::Tensor& weight_bf16,
    const at::Tensor& bias_bf16, at::Tensor& output);

// In-place rotary position embedding (RoPE) with YaRN scaling parameters.
void f32_rope_torch(
    at::Tensor& activations, double rope_base, double interpolation_scale,
    double yarn_offset, double yarn_scale, double yarn_multiplier,
    int64_t num_tokens, int64_t num_q_heads, int64_t num_kv_heads,
    int64_t attn_head_dim, int64_t token_offset, int64_t threadgroup_size);

// Fused QKV projection: bf16-weight matmul, RoPE, and KV-cache write.
void f32_bf16w_matmul_qkv_torch(
    const at::Tensor& input, const at::Tensor& weight_bf16,
    const at::Tensor& bias_bf16, at::Tensor& output, at::Tensor& kv_cache,
    int64_t kv_cache_offset_bytes, int64_t num_tokens, int64_t num_cols,
    int64_t num_q_heads, int64_t num_kv_heads, int64_t attn_head_dim,
    int64_t token_offset, int64_t max_tokens, double rope_base,
    double interpolation_scale, double yarn_offset, double yarn_scale,
    double yarn_multiplier, int64_t threadgroup_size);

// Windowed scaled dot-product attention over the KV cache.
void f32_sdpa_torch(
    const at::Tensor& q, int64_t q_offset_bytes,
    const at::Tensor& kv, int64_t kv_offset_bytes,
    const at::Tensor& s_bf16, int64_t s_offset_bytes,
    at::Tensor& output, int64_t output_offset_bytes,
    int64_t window, int64_t kv_stride,
    int64_t num_q_tokens, int64_t num_kv_tokens,
    int64_t num_q_heads, int64_t num_kv_heads, int64_t head_dim);

// Top-k expert selection for MoE routing.
void f32_topk_torch(
    const at::Tensor& scores, at::Tensor& expert_ids,
    at::Tensor& expert_scores, int64_t num_tokens, int64_t num_experts,
    int64_t num_active_experts);

// Computes per-expert and intra-expert offsets for the routed tokens.
void expert_routing_metadata_torch(
    const at::Tensor& expert_ids, const at::Tensor& expert_scores,
    at::Tensor& expert_offsets, at::Tensor& intra_expert_offsets,
    int64_t num_tokens, int64_t num_experts);

// Scatters token activations into expert-sorted order using the routing offsets.
void f32_scatter_torch(
    const at::Tensor& input, const at::Tensor& expert_ids,
    const at::Tensor& expert_scores, const at::Tensor& expert_offsets,
    const at::Tensor& intra_expert_offsets, at::Tensor& output,
    int64_t num_channels, int64_t num_tokens, int64_t num_active_experts);

// Same matmul as f32_bf16w_matmul_torch, but accumulates into output.
void f32_bf16w_matmul_add_torch(
    const at::Tensor& input, const at::Tensor& weight_bf16,
    const at::Tensor& bias_bf16, at::Tensor& output,
    int64_t num_tokens, int64_t num_cols, int64_t num_rows,
    int64_t threadgroup_size);
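
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only): driving the RMSNorm and matmul bindings
// from C++. The shapes, dtypes, `hidden_dim`/`out_dim` sizes, and the
// threadgroup size below are assumptions inferred from the parameter names,
// not values mandated by this header; device placement is likewise omitted.
//
//   const int64_t num_tokens = 4, hidden_dim = 2880, out_dim = 2880;
//
//   at::Tensor input  = at::randn({num_tokens, hidden_dim});       // f32 activations
//   at::Tensor gain   = at::randn({hidden_dim}).to(at::kBFloat16); // bf16 gain
//   at::Tensor normed = at::empty_like(input);
//   f32_bf16w_rmsnorm_torch(input, gain, normed, /*epsilon=*/1e-5);
//
//   at::Tensor weight = at::randn({out_dim, hidden_dim}).to(at::kBFloat16);
//   at::Tensor bias   = at::randn({out_dim}).to(at::kBFloat16);
//   at::Tensor out    = at::empty({num_tokens, out_dim});
//   f32_bf16w_matmul_torch(normed, weight, bias, out, num_tokens,
//                          /*num_cols=*/hidden_dim, /*num_rows=*/out_dim,
//                          /*threadgroup_size=*/256);
// ---------------------------------------------------------------------------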