#include <cassert>
#include <stdexcept>
#include <string>

#include <cublas_v2.h>
#include <cuda.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

// Error-checking helper used below; a minimal definition so this excerpt
// compiles on its own.
#define CUBLAS_CHECK(condition)                                                \
  for (cublasStatus_t _status = (condition);                                   \
       _status != CUBLAS_STATUS_SUCCESS;)                                      \
    throw std::runtime_error("cuBLAS error " + std::to_string(_status) +       \
                             " at line " + std::to_string(__LINE__));

/*
  NOTE: cuBLAS GEMM is column-major by default, but we need row-major output.
  The data of a row-major matrix is byte-identical to its transpose stored
  column-major, and C = A * B  <=>  C^T = B^T * A^T.
 */
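// Example: for row-major A (2 x 3) and B (3 x 4), asking the column-major
// GEMM for B^T * A^T (a 4 x 2 column-major result) fills a buffer that reads
// back as the row-major 2 x 4 product A * B, with no transposes or copies.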
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  const auto cuda_data_type = CUDA_R_16F;
  const auto cuda_c_data_type =
      c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
  const auto compute_type = CUDA_R_32F;
  const float sp_alpha = 1.f;
  // Swap a and b and use CUBLAS_OP_N; see the note above.
  std::swap(a, b);
  const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
  const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
  // m = (B^T).size(0) = B.size(1) = a.size(1) after the swap; negative axes
  // are used so the same code also covers the batched (3-D) case.
  const int m = a.size(-1);
  const int k = a.size(-2);
  const int n = b.size(-2);
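  // In the original (pre-swap) terms: A is (n x k), B is (k x m), and the
  // output C is (n x m), all row-major.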
  // Leading dimensions of the column-major views: the row-major (k x m)
  // buffer of B reads as an (m x k) column-major matrix with ld = m, etc.
  const int cublas_lda = m;
  const int cublas_ldb = k;
  const int cublas_ldc = m;
  cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000
  // On CUDA 11+, the default heuristic can already pick tensor-core kernels.
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#else
  // Older toolkits need the explicit tensor-op algorithm for fp16 GEMM.
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
  const float sp_beta = 0.f;
  if (a.sizes().size() == 2 && b.sizes().size() == 2) {
    CUBLAS_CHECK(cublasGemmEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
        a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
        cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
        compute_type, algo));
  } else {
    // batch matmul
    assert(a.sizes().size() == 3 && b.sizes().size() == 3);
    const long long int cublas_stride_a = m * k;
    const long long int cublas_stride_b = k * n;
    const long long int cublas_stride_c = m * n;
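    // Fixed strides assume contiguous 3-D inputs whose batch entries all
    // share the same per-matrix shape; broadcasting is not handled here.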
    CUBLAS_CHECK(cublasGemmStridedBatchedEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
        a.data_ptr(), cuda_data_type, cublas_lda, cublas_stride_a,
        b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, &sp_beta,
        c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
        a.size(0), compute_type, algo));
  }
}
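// A minimal sketch of exposing this kernel to Python, assuming the file is
// built as a PyTorch C++ extension (the binding below is illustrative, not
// part of the original source). From Python one would then call, e.g.,
// ext.gemm_fp16_cublas(a, b, c) with fp16 CUDA tensors a (n, k), b (k, m)
// and a preallocated output c (n, m) in fp16 or fp32.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gemm_fp16_cublas", &gemm_fp16_cublas,
        "row-major C = A * B in fp16 via cuBLAS");
}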