#pragma once

#include "scaled_mm_c3x.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
 * shape.
 *
 * NOTE(review): the template parameter/argument lists in this file were
 * restored after being lost (e.g. `std::is_same()`, `static_cast(64)`,
 * `template typename Epilogue>`). The names InType/OutType, the arity of the
 * Epilogue template-template parameter, the fp8 element type used in the
 * static_asserts, and the argument order passed to cutlass_3x_gemm are
 * assumptions -- confirm them against the declarations in scaled_mm_c3x.cuh.
 */

namespace vllm {

/**
 * Kernel configuration used by cutlass_gemm_sm90_fp8_dispatch for
 * M in (128, inf).
 *
 * Exposes the assembled kernel type as the nested alias `Cutlass3xGemm`.
 */
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  // presumably restricts this config to fp8 (e4m3) inputs -- TODO confirm type
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

/**
 * Kernel configuration used by cutlass_gemm_sm90_fp8_dispatch for
 * M in (64, 128].
 *
 * Differs from the default configuration only in its TileShape
 * (first extent 64 instead of 128).
 */
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  // presumably restricts this config to fp8 (e4m3) inputs -- TODO confirm type
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

/**
 * Kernel configuration used by cutlass_gemm_sm90_fp8_dispatch for
 * M in [1, 64].
 *
 * Uses the smallest TileShape of the three configurations and a
 * ClusterShape of <1, 8, 1> instead of <2, 1, 1>.
 */
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  // presumably restricts this config to fp8 (e4m3) inputs -- TODO confirm type
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

/**
 * Selects one of the three SM90 fp8 kernel configurations above based on
 * m = a.size(0) and forwards the call to cutlass_gemm_caller.
 *
 * @param out   Output tensor, passed through to cutlass_gemm_caller.
 * @param a     First operand; must have dtype torch::kFloat8_e4m3fn. Its
 *              size along dimension 0 drives the configuration choice.
 * @param b     Second operand; must have dtype torch::kFloat8_e4m3fn.
 * @param args  Extra arguments, perfectly forwarded to cutlass_gemm_caller
 *              (presumably consumed by the Epilogue -- confirm in
 *              scaled_mm_c3x.cuh).
 *
 * Fails via TORCH_CHECK if either operand does not have the fp8 e4m3 dtype.
 */
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  // presumably restricts dispatch to fp8 (e4m3) inputs -- TODO confirm type
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  // Clamp from below at 64 so every m that maps to 64 or less lands in the
  // smallest bucket. next_pow_2 is defined outside this file.
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace vllm