// CharmAGX_G1/core/kernels/optimized_matmul.cu
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h> // current CUDA stream for kernel launches
#include <mma.h> // Tensor Core WMMA API
using namespace nvcuda;
// Define FP16 type for Tensor Cores
using half_t = __half;
// WMMA tile shape (16x16x16, the canonical FP16 Tensor Core tile)
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
#define WARP_SIZE 32 // one warp (32 threads) per block; each warp owns one C tile
// Tensor Core GEMM kernel: one warp per block, each warp computing one
// WMMA_M x WMMA_N tile of C. Partial edge tiles are zero-padded through
// shared memory so arbitrary m, n, k are handled correctly.
__global__ void optimized_matmul_kernel(
const half_t* __restrict__ a, // Matrix A [m, k], row-major
const half_t* __restrict__ b, // Matrix B [k, n], row-major
half_t* __restrict__ c, // Matrix C [m, n], row-major
int m, int n, int k) // Dimensions: A[m,k], B[k,n], C[m,n]
{
// Shared-memory staging buffers for one 16x16 tile of A, B, and C
__shared__ half_t shmem_a[WMMA_M * WMMA_K];
__shared__ half_t shmem_b[WMMA_K * WMMA_N];
__shared__ half_t shmem_c[WMMA_M * WMMA_N];
// WMMA fragments; both staged tiles are stored row-major in shared memory
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half_t> c_frag;
// Warp-local indices and this block's output tile origin
const int lane = threadIdx.x;             // lane id within the warp (0..31)
const int tile_row = blockIdx.y * WMMA_M; // first row of this C tile
const int tile_col = blockIdx.x * WMMA_N; // first column of this C tile
// Initialize accumulator
wmma::fill_fragment(c_frag, __float2half(0.0f));
// Loop over the K dimension in WMMA_K-wide slices
for (int tile_k = 0; tile_k < k; tile_k += WMMA_K) {
// Each lane stages 8 elements of the A tile and 8 of the B tile,
// zero-padding anything that falls outside the matrix bounds
for (int i = lane; i < WMMA_M * WMMA_K; i += WARP_SIZE) {
int r = i / WMMA_K, col = i % WMMA_K;
int gr = tile_row + r, gc = tile_k + col;
shmem_a[i] = (gr < m && gc < k) ? a[gr * k + gc] : __float2half(0.0f);
}
for (int i = lane; i < WMMA_K * WMMA_N; i += WARP_SIZE) {
int r = i / WMMA_N, col = i % WMMA_N;
int gr = tile_k + r, gc = tile_col + col;
shmem_b[i] = (gr < k && gc < n) ? b[gr * n + gc] : __float2half(0.0f);
}
__syncwarp(); // single warp per block, so a warp-level barrier suffices
// Warp-wide Tensor Core multiply-accumulate on the staged tiles
wmma::load_matrix_sync(a_frag, shmem_a, WMMA_K);
wmma::load_matrix_sync(b_frag, shmem_b, WMMA_N);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncwarp(); // make sure all lanes are done before the tiles are overwritten
}
// Stage the accumulator through shared memory so edge tiles can be
// written back with per-element bounds checks
wmma::store_matrix_sync(shmem_c, c_frag, WMMA_N, wmma::mem_row_major);
__syncwarp();
for (int i = lane; i < WMMA_M * WMMA_N; i += WARP_SIZE) {
int r = i / WMMA_N, col = i % WMMA_N;
int gr = tile_row + r, gc = tile_col + col;
if (gr < m && gc < n) c[gr * n + gc] = shmem_c[i];
}
}
// PyTorch binding
torch::Tensor optimized_matmul(
torch::Tensor a, // [m, k]
torch::Tensor b) // [k, n]
{
// Ensure inputs are FP16 and on CUDA
TORCH_CHECK(a.dtype() == torch::kFloat16, "Matrix A must be FP16");
TORCH_CHECK(b.dtype() == torch::kFloat16, "Matrix B must be FP16");
TORCH_CHECK(a.is_cuda(), "Matrix A must be on CUDA");
TORCH_CHECK(b.is_cuda(), "Matrix B must be on CUDA");
TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors");
TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match");
// The kernel indexes densely packed row-major data
a = a.contiguous();
b = b.contiguous();
// Dimensions
const int m = a.size(0);
const int k = a.size(1);
const int n = b.size(1);
// Output tensor
auto c = torch::empty({m, n},
torch::TensorOptions().dtype(torch::kFloat16).device(a.device()));
// Launch configuration: one 32-thread warp per block, one block per 16x16 tile of C
dim3 block(WARP_SIZE);
dim3 grid((n + WMMA_N - 1) / WMMA_N, (m + WMMA_M - 1) / WMMA_M);
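// Example (illustrative numbers only): for m = 1000, n = 500 the launch is
//   grid  = (ceil(500 / 16), ceil(1000 / 16)) = (32, 63)
//   block = 32 threads (one warp)
// so 32 * 63 = 2016 warps each own one 16x16 tile of C, with the partial
// tiles on the right/bottom edges zero-padded inside the kernel.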
// Launch kernel on the current PyTorch CUDA stream
optimized_matmul_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<const half_t*>(a.data_ptr<at::Half>()),
reinterpret_cast<const half_t*>(b.data_ptr<at::Half>()),
reinterpret_cast<half_t*>(c.data_ptr<at::Half>()),
m, n, k);
cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess, "CUDA error: ", cudaGetErrorString(err));
return c;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("optimized_matmul", &optimized_matmul, "Tensor Core-optimized GEMM");
}
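// Build/usage sketch (assumptions: the extension is JIT-compiled with
// torch.utils.cpp_extension; the module name "optimized_matmul_ext" and the
// arch flag below are illustrative, and a GPU with compute capability >= 7.0
// is required for WMMA):
//
//   import torch
//   from torch.utils.cpp_extension import load
//
//   ext = load(name="optimized_matmul_ext",
//              sources=["optimized_matmul.cu"],
//              extra_cuda_cflags=["-gencode=arch=compute_70,code=sm_70"])
//   a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
//   b = torch.randn(64, 256, device="cuda", dtype=torch.float16)
//   c = ext.optimized_matmul(a, b)
//   torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)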