#include <torch/extension.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <mma.h>  // Tensor Core WMMA API

using namespace nvcuda;

// FP16 type used for Tensor Core inputs and outputs
using half_t = __half;

// WMMA tile sizes (fixed 16x16x16 shape for FP16 Tensor Cores)
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
#define WARP_SIZE 32

// GEMM kernel using Tensor Cores.
// One warp (one block of 32 threads) computes one 16x16 tile of C.
__global__ void optimized_matmul_kernel(
    const half_t* __restrict__ a,  // Matrix A [m, k], row-major
    const half_t* __restrict__ b,  // Matrix B [k, n], row-major
    half_t* __restrict__ c,        // Matrix C [m, n], row-major
    int m, int n, int k)           // Dimensions: A[m,k], B[k,n], C[m,n]
{
    // Shared memory staging buffers for one WMMA tile of A, B, and C
    __shared__ half_t shmem_a[WMMA_M * WMMA_K];
    __shared__ half_t shmem_b[WMMA_K * WMMA_N];
    __shared__ half_t shmem_c[WMMA_M * WMMA_N];

    // WMMA fragments; both the A and B tiles are staged row-major in shared memory
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half_t> c_frag;

    // Origin of this block's C tile
    int tile_row = blockIdx.y * WMMA_M;  // first row of C handled by this warp
    int tile_col = blockIdx.x * WMMA_N;  // first column of C handled by this warp

    // Initialize accumulator
    wmma::fill_fragment(c_frag, __float2half(0.0f));

    // Loop over the K dimension in WMMA_K-sized steps
    for (int tile_k = 0; tile_k < k; tile_k += WMMA_K) {
        // Cooperatively load the A and B tiles into shared memory,
        // zero-padding elements that fall outside the matrices.
        // (Both tiles hold WMMA_M * WMMA_K == WMMA_K * WMMA_N == 256 halfs.)
        for (int i = threadIdx.x; i < WMMA_M * WMMA_K; i += blockDim.x) {
            int r   = i / WMMA_K;  // row within the tile
            int col = i % WMMA_K;  // column within the tile

            int a_row = tile_row + r;
            int a_col = tile_k + col;
            shmem_a[i] = (a_row < m && a_col < k) ? a[a_row * k + a_col]
                                                  : __float2half(0.0f);

            int b_row = tile_k + r;
            int b_col = tile_col + col;
            shmem_b[i] = (b_row < k && b_col < n) ? b[b_row * n + b_col]
                                                  : __float2half(0.0f);
        }
        __syncthreads();

        // Load WMMA fragments from shared memory (leading dimension = tile width)
        wmma::load_matrix_sync(a_frag, shmem_a, WMMA_K);
        wmma::load_matrix_sync(b_frag, shmem_b, WMMA_N);

        // Tensor Core matrix multiply-accumulate: c_frag += a_frag * b_frag
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

        __syncthreads();
    }

    // Stage the result tile in shared memory, then copy only the in-bounds
    // elements to global memory (store_matrix_sync itself cannot be bounds-checked).
    wmma::store_matrix_sync(shmem_c, c_frag, WMMA_N, wmma::mem_row_major);
    __syncthreads();

    for (int i = threadIdx.x; i < WMMA_M * WMMA_N; i += blockDim.x) {
        int c_row = tile_row + i / WMMA_N;
        int c_col = tile_col + i % WMMA_N;
        if (c_row < m && c_col < n) {
            c[c_row * n + c_col] = shmem_c[i];
        }
    }
}

// PyTorch binding
torch::Tensor optimized_matmul(
    torch::Tensor a,  // [m, k]
    torch::Tensor b)  // [k, n]
{
    // Ensure inputs are FP16, on CUDA, 2-D, and have matching inner dimensions
    TORCH_CHECK(a.dtype() == torch::kFloat16, "Matrix A must be FP16");
    TORCH_CHECK(b.dtype() == torch::kFloat16, "Matrix B must be FP16");
    TORCH_CHECK(a.is_cuda(), "Matrix A must be on CUDA");
    TORCH_CHECK(b.is_cuda(), "Matrix B must be on CUDA");
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors");
    TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match");

    // The kernel assumes densely packed row-major inputs
    auto a_contig = a.contiguous();
    auto b_contig = b.contiguous();

    // Dimensions
    int m = a.size(0);
    int k = a.size(1);
    int n = b.size(1);

    // Output tensor
    auto c = torch::empty({m, n},
        torch::TensorOptions().dtype(torch::kFloat16).device(a.device()));

    // One warp (32 threads) per block; one block per 16x16 output tile
    dim3 block(WARP_SIZE);
    dim3 grid((n + WMMA_N - 1) / WMMA_N, (m + WMMA_M - 1) / WMMA_M);

    // Launch kernel
    optimized_matmul_kernel<<<grid, block>>>(
        reinterpret_cast<const half_t*>(a_contig.data_ptr<at::Half>()),
        reinterpret_cast<const half_t*>(b_contig.data_ptr<at::Half>()),
        reinterpret_cast<half_t*>(c.data_ptr<at::Half>()),
        m, n, k);

    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA error: ", cudaGetErrorString(err));

    return c;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("optimized_matmul", &optimized_matmul, "Tensor Core-optimized GEMM");
}
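
// ---------------------------------------------------------------------------
// Usage sketch: one way to build and call this extension from Python is
// torch.utils.cpp_extension.load. The file name ("tensor_core_matmul.cu") and
// module name ("tc_matmul") below are placeholders, not fixed by this listing.
// WMMA requires a GPU with compute capability 7.0 (Volta) or newer, and the
// tolerances are loose because the kernel accumulates in FP16.
//
//   import torch
//   from torch.utils.cpp_extension import load
//
//   tc_matmul = load(name="tc_matmul", sources=["tensor_core_matmul.cu"])
//
//   a = torch.randn(256, 512, device="cuda", dtype=torch.float16)
//   b = torch.randn(512, 128, device="cuda", dtype=torch.float16)
//   c = tc_matmul.optimized_matmul(a, b)
//   torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)
// ---------------------------------------------------------------------------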