#pragma once

// clang-format will break include orders
// clang-format off
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on

/*
   Epilogues defined in
   csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
   must contain a public type named EVTCompute of type Sm90EVT, as well as a
   static prepare_args function that constructs an EVTCompute::Arguments
   struct.
*/

using namespace cute;

namespace vllm {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  // int8 inputs accumulate in int32; everything else accumulates in float.
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  // D is row-major; the Int<0> batch stride broadcasts over the single batch.
  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;  // no source C tensor is read by the epilogue
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  // Let the builder pick a stage count that fits in shared memory after
  // carving out the epilogue's shared-storage requirement.
  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};
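// A minimal sketch of an instantiation (hypothetical tile/cluster shapes and
// schedules; ScaledEpilogue is assumed to come from
// csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp and to satisfy
// the EVTCompute/prepare_args contract described above):
//
//   using TileShape    = Shape<_128, _128, _128>;
//   using ClusterShape = Shape<_2, _1, _1>;
//   using Gemm = cutlass_3x_gemm<
//       int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
//       cutlass::gemm::KernelTmaWarpSpecializedPingpong,
//       cutlass::epilogue::TmaWarpSpecializedCooperative>;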
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  // A is row-major and B is column-major; the trailing element is the
  // (unused) batch stride.
  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, 0};
  StrideB b_stride{ldb, Int<1>{}, 0};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

}  // namespace vllm
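// Example call site (a sketch; Gemm is the hypothetical alias from the
// comment above, and a_scales/b_scales stand in for whatever arguments the
// chosen epilogue's prepare_args expects):
//
//   vllm::cutlass_gemm_caller<Gemm>(out, a, b, a_scales, b_scales);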