#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <mma.h>

using namespace nvcuda;

using half_t = __half;

// WMMA tile dimensions: one 16x16x16 tensor-core tile per thread block.
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
#define WARP_SIZE 32

// One 16x16 output tile of C per thread block. The block is launched with
// (WMMA_N, WMMA_M) = (16, 16) threads: each thread stages exactly one element
// of the A and B tiles into shared memory, and the block's first warp performs
// the warp-wide WMMA operations on the staged tiles.
__global__ void optimized_matmul_kernel(
    const half_t* __restrict__ a,
    const half_t* __restrict__ b,
    half_t* __restrict__ c,
    int m, int n, int k)
{
    __shared__ half_t shmem_a[WMMA_M * WMMA_K];
    __shared__ half_t shmem_b[WMMA_K * WMMA_N];
    __shared__ half_t shmem_c[WMMA_M * WMMA_N];

    // Both tiles are staged in row-major order, so both operand fragments use row_major.
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half_t> c_frag;

    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int bx = blockIdx.x;
    int by = blockIdx.y;

    // Element of C that this thread stages to/from shared memory.
    int row = by * WMMA_M + ty;
    int col = bx * WMMA_N + tx;

    // WMMA operations are warp-wide; restrict them to the block's first warp.
    bool first_warp = (ty * blockDim.x + tx) < WARP_SIZE;

    wmma::fill_fragment(c_frag, __float2half(0.0f));

    for (int tile_k = 0; tile_k < k; tile_k += WMMA_K) {

        // Stage the A tile (rows of C x current K slice), zero-padding out-of-range elements.
        if (row < m && tile_k + tx < k) {
            shmem_a[ty * WMMA_K + tx] = a[row * k + tile_k + tx];
        } else {
            shmem_a[ty * WMMA_K + tx] = __float2half(0.0f);
        }

        // Stage the B tile (current K slice x columns of C), zero-padding out-of-range elements.
        if (tile_k + ty < k && col < n) {
            shmem_b[ty * WMMA_N + tx] = b[(tile_k + ty) * n + col];
        } else {
            shmem_b[ty * WMMA_N + tx] = __float2half(0.0f);
        }

        __syncthreads();

        // Tensor-core multiply-accumulate on the staged tiles (first warp only).
        if (first_warp) {
            wmma::load_matrix_sync(a_frag, shmem_a, WMMA_K);
            wmma::load_matrix_sync(b_frag, shmem_b, WMMA_N);
            wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        }

        __syncthreads();
    }

    // Stage the accumulator tile through shared memory so partial tiles at the
    // matrix edges can be written back with per-thread bounds checks.
    if (first_warp) {
        wmma::store_matrix_sync(shmem_c, c_frag, WMMA_N, wmma::mem_row_major);
    }
    __syncthreads();

    if (row < m && col < n) {
        c[row * n + col] = shmem_c[ty * WMMA_N + tx];
    }
}

torch::Tensor optimized_matmul(
    torch::Tensor a,
    torch::Tensor b)
{
    TORCH_CHECK(a.dtype() == torch::kFloat16, "Matrix A must be FP16");
    TORCH_CHECK(b.dtype() == torch::kFloat16, "Matrix B must be FP16");
    TORCH_CHECK(a.is_cuda(), "Matrix A must be on CUDA");
    TORCH_CHECK(b.is_cuda(), "Matrix B must be on CUDA");
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors");
    TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match");

    // The kernel indexes raw row-major storage, so make both inputs contiguous.
    a = a.contiguous();
    b = b.contiguous();

    int m = a.size(0);
    int k = a.size(1);
    int n = b.size(1);

    auto c = torch::empty({m, n},
        torch::TensorOptions().dtype(torch::kFloat16).device(a.device()));

    // One (WMMA_N x WMMA_M)-thread block per 16x16 output tile.
    dim3 block(WMMA_N, WMMA_M);
    dim3 grid((n + WMMA_N - 1) / WMMA_N, (m + WMMA_M - 1) / WMMA_M);

    optimized_matmul_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
        reinterpret_cast<const half_t*>(a.data_ptr<at::Half>()),
        reinterpret_cast<const half_t*>(b.data_ptr<at::Half>()),
        reinterpret_cast<half_t*>(c.data_ptr<at::Half>()),
        m, n, k);

    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "CUDA error: ", cudaGetErrorString(err));

    return c;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("optimized_matmul", &optimized_matmul, "Tensor Core-optimized GEMM");
}
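
// Usage note: once this file is built as a PyTorch C++/CUDA extension (for
// example via torch.utils.cpp_extension.load or a setup.py using CUDAExtension;
// the exact build setup is up to the caller), the resulting module exposes
// optimized_matmul(a, b), which takes two 2-D FP16 CUDA tensors with matching
// inner dimensions and returns their FP16 product.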