GeminiFan207 committed
Commit 9310065 · verified · 1 Parent(s): 6ae562b

Create optimized_matmul.cu

Files changed (1)
  1. core/kernels/optimized_matmul.cu +123 -0
core/kernels/optimized_matmul.cu ADDED
@@ -0,0 +1,123 @@
+ #include <cuda_fp16.h>
+ #include <cuda_runtime.h>
+ #include <torch/extension.h>
+ #include <mma.h> // Tensor Core WMMA API
+
+ using namespace nvcuda;
+
+ // FP16 type used on the Tensor Cores
+ using half_t = __half;
+
+ // WMMA tile sizes (fixed at 16x16x16 for FP16 Tensor Cores)
+ #define WMMA_M 16
+ #define WMMA_N 16
+ #define WMMA_K 16
+
+ // One warp (32 threads) per block
+ #define WARP_SIZE 32
+
+ // Optimized GEMM kernel using Tensor Cores (requires SM 7.0+).
+ // Launch with one warp (32 threads) per block; each block computes one
+ // WMMA_M x WMMA_N tile of C.
+ __global__ void optimized_matmul_kernel(
+     const half_t* __restrict__ a,  // Matrix A [m, k], row-major
+     const half_t* __restrict__ b,  // Matrix B [k, n], row-major
+     half_t* __restrict__ c,        // Matrix C [m, n], row-major
+     int m, int n, int k)           // Dimensions: A[m,k], B[k,n], C[m,n]
+ {
+     // Shared-memory staging buffers (32-byte aligned, as WMMA requires)
+     __shared__ alignas(32) half_t shmem_a[WMMA_M * WMMA_K];
+     __shared__ alignas(32) half_t shmem_b[WMMA_K * WMMA_N];
+
+     // WMMA fragments; both tiles are staged row-major in shared memory
+     wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> a_frag;
+     wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> b_frag;
+     wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half_t> c_frag;
+
+     int lane = threadIdx.x;               // Lane index within the warp (0..31)
+     int tile_row = blockIdx.y * WMMA_M;   // Top row of this block's C tile
+     int tile_col = blockIdx.x * WMMA_N;   // Left column of this block's C tile
+
+     // Initialize accumulator
+     wmma::fill_fragment(c_frag, __float2half(0.0f));
+
+     // Loop over the K dimension in WMMA-sized tiles
+     for (int tile_k = 0; tile_k < k; tile_k += WMMA_K) {
+         // Cooperatively stage the A and B tiles into shared memory,
+         // zero-padding out-of-range elements so partial tiles are handled
+         for (int i = lane; i < WMMA_M * WMMA_K; i += WARP_SIZE) {
+             int gr = tile_row + i / WMMA_K;  // global row in A
+             int gc = tile_k   + i % WMMA_K;  // global column in A
+             shmem_a[i] = (gr < m && gc < k) ? a[gr * k + gc] : __float2half(0.0f);
+         }
+         for (int i = lane; i < WMMA_K * WMMA_N; i += WARP_SIZE) {
+             int gr = tile_k   + i / WMMA_N;  // global row in B
+             int gc = tile_col + i % WMMA_N;  // global column in B
+             shmem_b[i] = (gr < k && gc < n) ? b[gr * n + gc] : __float2half(0.0f);
+         }
+         __syncthreads();
+
+         // Load WMMA fragments from shared memory and run the Tensor Core MMA
+         wmma::load_matrix_sync(a_frag, shmem_a, WMMA_K);
+         wmma::load_matrix_sync(b_frag, shmem_b, WMMA_N);
+         wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+
+         __syncthreads();
+     }
+
+     // Store the result tile: stage through shared memory, then copy with
+     // bounds checks so edge tiles and arbitrary n are handled correctly
+     wmma::store_matrix_sync(shmem_a, c_frag, WMMA_N, wmma::mem_row_major);
+     __syncthreads();
+     for (int i = lane; i < WMMA_M * WMMA_N; i += WARP_SIZE) {
+         int gr = tile_row + i / WMMA_N;
+         int gc = tile_col + i % WMMA_N;
+         if (gr < m && gc < n) {
+             c[gr * n + gc] = shmem_a[i];
+         }
+     }
+ }
+
+ // PyTorch binding
+ torch::Tensor optimized_matmul(
+     torch::Tensor a,  // [m, k]
+     torch::Tensor b)  // [k, n]
+ {
+     // Ensure inputs are FP16, 2D, on CUDA, and dimensionally compatible
+     TORCH_CHECK(a.dtype() == torch::kFloat16, "Matrix A must be FP16");
+     TORCH_CHECK(b.dtype() == torch::kFloat16, "Matrix B must be FP16");
+     TORCH_CHECK(a.is_cuda(), "Matrix A must be on CUDA");
+     TORCH_CHECK(b.is_cuda(), "Matrix B must be on CUDA");
+     TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors");
+     TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match");
+
+     // The kernel assumes densely packed row-major inputs
+     a = a.contiguous();
+     b = b.contiguous();
+
+     // Dimensions
+     int m = static_cast<int>(a.size(0));
+     int k = static_cast<int>(a.size(1));
+     int n = static_cast<int>(b.size(1));
+
+     // Output tensor
+     auto c = torch::empty({m, n},
+         torch::TensorOptions().dtype(torch::kFloat16).device(a.device()));
+
+     // Grid and block dimensions: one warp per block, one block per 16x16 tile of C
+     dim3 block(WARP_SIZE);
+     dim3 grid((n + WMMA_N - 1) / WMMA_N, (m + WMMA_M - 1) / WMMA_M);
+
+     // Launch kernel
+     optimized_matmul_kernel<<<grid, block>>>(
+         reinterpret_cast<const half_t*>(a.data_ptr<at::Half>()),
+         reinterpret_cast<const half_t*>(b.data_ptr<at::Half>()),
+         reinterpret_cast<half_t*>(c.data_ptr<at::Half>()),
+         m, n, k);
+
+     cudaError_t err = cudaGetLastError();
+     TORCH_CHECK(err == cudaSuccess, "CUDA error: ", cudaGetErrorString(err));
+
+     return c;
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("optimized_matmul", &optimized_matmul, "Tensor Core-optimized GEMM");
+ }
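
For context, a minimal sketch (not part of the commit) of how this extension could be JIT-compiled and sanity-checked from Python using torch.utils.cpp_extension.load. The module name, architecture flag, and test sizes are illustrative assumptions; only the source path and the optimized_matmul entry point come from the file above.

# Sketch only: build the extension from this commit and compare against a reference matmul.
import torch
from torch.utils.cpp_extension import load

# JIT-compile the CUDA source added in this commit.
# "optimized_matmul_ext" and the -arch flag are assumptions (any SM 7.0+ GPU works).
ext = load(
    name="optimized_matmul_ext",
    sources=["core/kernels/optimized_matmul.cu"],
    extra_cuda_cflags=["-arch=sm_70"],
    verbose=True,
)

# Arbitrary sizes; none need to be multiples of 16 thanks to the zero-padded tiles.
m, k, n = 120, 72, 100
a = torch.randn(m, k, device="cuda", dtype=torch.float16)
b = torch.randn(k, n, device="cuda", dtype=torch.float16)

c = ext.optimized_matmul(a, b)

# Loose check: the kernel accumulates in FP16, so expect larger error than cuBLAS.
ref = (a.float() @ b.float()).half()
print("max abs diff:", (c - ref).abs().max().item())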