GeminiFan207 committed
Commit 0f701ea (verified) · Parent(s): 1469b78

Create fused_ops.cu

Files changed (1)
1. core/kernels/fused_ops.cu +137 -0
core/kernels/fused_ops.cu ADDED
@@ -0,0 +1,137 @@
+ #include <cuda_fp16.h>
+ #include <cuda_runtime.h>
+ #include <torch/extension.h>
+ #include <ATen/cuda/CUDAContext.h> // at::cuda::getCurrentCUDAStream
+ #include <mma.h> // Tensor Core WMMA API
+
+ using namespace nvcuda;
+
+ // FP16 element type used with the Tensor Cores
+ using half_t = __half;
+
+ // Thread block and warp sizes: one warp (32 threads) per block
+ #define BLOCK_SIZE 32
+ #define WARP_SIZE 32
+
+ // WMMA tile dimensions (16x16x16 FP16 Tensor Core tiles)
+ #define WMMA_M 16
+ #define WMMA_N 16
+ #define WMMA_K 16
+
+ // Fused sparse GEMM + bias + ReLU kernel.
+ // Computes output = ReLU(input @ (weight * mask)^T + bias), where * is the
+ // elementwise sparsity mask. Each warp produces one WMMA_M x WMMA_N output tile.
+ __global__ void fused_sparse_gemm_relu_kernel(
+     const half_t* __restrict__ input,   // Input tensor [batch_size, in_features]
+     const half_t* __restrict__ weight,  // Weight tensor [out_features, in_features]
+     const half_t* __restrict__ mask,    // Sparsity mask [out_features, in_features]
+     half_t* __restrict__ output,        // Output tensor [batch_size, out_features]
+     const half_t* __restrict__ bias,    // Bias tensor [out_features]
+     int batch_size, int in_features, int out_features)
+ {
+     // Shared-memory staging tiles (one 16x16 tile each)
+     __shared__ half_t shmem_input[WMMA_M * WMMA_K];   // A tile, row-major
+     __shared__ half_t shmem_weight[WMMA_K * WMMA_N];  // B tile, column-major
+     __shared__ half_t shmem_output[WMMA_M * WMMA_N];  // C tile, row-major
+
+     // WMMA fragments
+     wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::row_major> a_frag;
+     wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half_t, wmma::col_major> b_frag;
+     wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half_t> c_frag;
+
+     // Tile coordinates: blockIdx.y walks batch rows, blockIdx.x walks output features
+     const int tile_row = blockIdx.y * WMMA_M;  // first batch row of this tile
+     const int tile_col = blockIdx.x * WMMA_N;  // first output feature of this tile
+     const int lane = threadIdx.x;              // lane id within the single warp
+
+     // Initialize accumulator
+     wmma::fill_fragment(c_frag, __float2half(0.0f));
+
+     // Loop over the K dimension (in_features) in WMMA_K-wide tiles
+     for (int k0 = 0; k0 < in_features; k0 += WMMA_K) {
+         // Stage the input tile (A) and the masked weight tile (B) into shared
+         // memory, zero-padding out-of-range elements. The 32 lanes cooperatively
+         // copy the 16x16 = 256 elements of each tile.
+         for (int i = lane; i < WMMA_M * WMMA_K; i += WARP_SIZE) {
+             const int r = i / WMMA_K;        // A-tile row, reused as B-tile column
+             const int k = i % WMMA_K;        // position along the reduction dimension
+             const int g_row = tile_row + r;  // global batch row
+             const int g_col = tile_col + r;  // global output feature
+             const int g_k   = k0 + k;        // global reduction index
+
+             shmem_input[r * WMMA_K + k] =
+                 (g_row < batch_size && g_k < in_features)
+                     ? input[g_row * in_features + g_k]
+                     : __float2half(0.0f);
+
+             // B is stored column-major (element (k, n) at n * WMMA_K + k), which
+             // realizes the implicit transpose of the row-major weight matrix.
+             shmem_weight[r * WMMA_K + k] =
+                 (g_col < out_features && g_k < in_features)
+                     ? __hmul(weight[g_col * in_features + g_k],
+                              mask[g_col * in_features + g_k])  // apply sparsity mask
+                     : __float2half(0.0f);
+         }
+         __syncthreads();
+
+         // Load WMMA fragments from shared memory and run the Tensor Core
+         // matrix multiply-accumulate
+         wmma::load_matrix_sync(a_frag, shmem_input, WMMA_K);
+         wmma::load_matrix_sync(b_frag, shmem_weight, WMMA_K);
+         wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
+
+         __syncthreads();
+     }
+
+     // Spill the accumulator tile to shared memory, then apply bias and ReLU
+     // elementwise while writing to global memory.
+     wmma::store_matrix_sync(shmem_output, c_frag, WMMA_N, wmma::mem_row_major);
+     __syncthreads();
+
+     for (int i = lane; i < WMMA_M * WMMA_N; i += WARP_SIZE) {
+         const int m = i / WMMA_N;
+         const int n = i % WMMA_N;
+         const int row = tile_row + m;
+         const int col = tile_col + n;
+         if (row < batch_size && col < out_features) {
+             half_t result = __hadd(shmem_output[m * WMMA_N + n], bias[col]);  // add bias
+             output[row * out_features + col] =
+                 __hgt(result, __float2half(0.0f)) ? result : __float2half(0.0f);  // ReLU
+         }
+     }
+ }
+
+ // PyTorch binding
+ torch::Tensor fused_sparse_gemm_relu(
+     torch::Tensor input,   // [batch_size, in_features]
+     torch::Tensor weight,  // [out_features, in_features]
+     torch::Tensor mask,    // [out_features, in_features]
+     torch::Tensor bias)    // [out_features]
+ {
+     // Ensure inputs are FP16, contiguous, and on CUDA
+     TORCH_CHECK(input.dtype() == torch::kFloat16, "Input must be FP16");
+     TORCH_CHECK(weight.dtype() == torch::kFloat16, "Weight must be FP16");
+     TORCH_CHECK(mask.dtype() == torch::kFloat16, "Mask must be FP16");
+     TORCH_CHECK(bias.dtype() == torch::kFloat16, "Bias must be FP16");
+     TORCH_CHECK(input.is_cuda(), "Input must be on CUDA");
+     TORCH_CHECK(weight.is_cuda(), "Weight must be on CUDA");
+     TORCH_CHECK(mask.is_cuda(), "Mask must be on CUDA");
+     TORCH_CHECK(bias.is_cuda(), "Bias must be on CUDA");
+     TORCH_CHECK(input.is_contiguous() && weight.is_contiguous() &&
+                 mask.is_contiguous() && bias.is_contiguous(),
+                 "All tensors must be contiguous");
+
+     // Dimensions
+     const int batch_size = input.size(0);
+     const int in_features = input.size(1);
+     const int out_features = weight.size(0);
+     TORCH_CHECK(weight.size(1) == in_features, "Weight and input feature dims must match");
+     TORCH_CHECK(mask.sizes() == weight.sizes(), "Mask must match weight shape");
+     TORCH_CHECK(bias.size(0) == out_features, "Bias must have out_features elements");
+
+     // Output tensor
+     auto output = torch::empty({batch_size, out_features},
+         torch::TensorOptions().dtype(torch::kFloat16).device(input.device()));
+
+     // Grid and block dimensions: one warp (BLOCK_SIZE = 32 threads) per block,
+     // one WMMA_M x WMMA_N output tile per block
+     dim3 block(BLOCK_SIZE);
+     dim3 grid((out_features + WMMA_N - 1) / WMMA_N,
+               (batch_size + WMMA_M - 1) / WMMA_M);
+
+     // Launch kernel on the current PyTorch CUDA stream
+     fused_sparse_gemm_relu_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
+         reinterpret_cast<const half_t*>(input.data_ptr<at::Half>()),
+         reinterpret_cast<const half_t*>(weight.data_ptr<at::Half>()),
+         reinterpret_cast<const half_t*>(mask.data_ptr<at::Half>()),
+         reinterpret_cast<half_t*>(output.data_ptr<at::Half>()),
+         reinterpret_cast<const half_t*>(bias.data_ptr<at::Half>()),
+         batch_size, in_features, out_features);
+
+     cudaError_t err = cudaGetLastError();
+     TORCH_CHECK(err == cudaSuccess, "CUDA error: ", cudaGetErrorString(err));
+
+     return output;
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("fused_sparse_gemm_relu", &fused_sparse_gemm_relu,
+           "Fused sparse GEMM + bias + ReLU with Tensor Cores");
+ }
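
For context, a minimal sketch (not part of the commit) of how the extension might be JIT-compiled and sanity-checked from Python against a dense PyTorch reference; the module name, source path, shapes, architecture flag, and tolerances are illustrative assumptions.

# Hypothetical usage sketch -- names, shapes, and flags are assumptions.
import torch
from torch.utils.cpp_extension import load

# JIT-compile the extension; WMMA requires compute capability 7.0 or newer.
fused_ops = load(name="fused_ops",
                 sources=["core/kernels/fused_ops.cu"],
                 extra_cuda_cflags=["-arch=sm_70"])

batch_size, in_features, out_features = 8, 256, 512
dev = torch.device("cuda")

x = torch.randn(batch_size, in_features, device=dev, dtype=torch.float16)
w = torch.randn(out_features, in_features, device=dev, dtype=torch.float16)
m = (torch.rand(out_features, in_features, device=dev) > 0.5).half()  # 0/1 sparsity mask
b = torch.randn(out_features, device=dev, dtype=torch.float16)

out = fused_ops.fused_sparse_gemm_relu(x, w, m, b)

# Dense FP16 reference: ReLU(x @ (w * m)^T + b)
ref = torch.relu(x @ (w * m).t() + b)
# Loose tolerances: the kernel accumulates in FP16, the cuBLAS reference may not.
print(torch.allclose(out, ref, rtol=1e-2, atol=1e-1))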