diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py new file mode 100644 index 00000000000..c3f80475356 --- /dev/null +++ b/sgl-kernel/benchmark/bench_fp8_gemm.py @@ -0,0 +1,164 @@ +import argparse +import copy +import itertools + +import torch +import triton +from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant + +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=[ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ], + line_names=[ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ], + styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], + ylabel="GB/s", + plot_name="fp8 scaled matmul", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + # M, N, K = batch_size, 4096, 8192 + M = batch_size + a = torch.ones((M, K), device="cuda") * 5.0 + b = torch.ones((N, K), device="cuda") * 5.0 + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + b_fp8 = b_fp8.t() + quantiles = [0.5, 0.2, 0.8] + + dtype = torch.float16 if "fp16" in provider else torch.bfloat16 + + if "vllm-fp8" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype), + quantiles=quantiles, + ) + elif "sglang-fp8" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sgl_scaled_mm( + a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None + ), + quantiles=quantiles, + ) + + gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, 
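        # cartesian product: benchmark every (model, tp_size) pair requested on the CLI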
args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 56a42ae4759..c8469dc1c0e 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -56,6 +56,7 @@ def _get_version(): turbomind.resolve(), turbomind.resolve() / "src", ] + nvcc_flags = [ "-DNDEBUG", f"-DOPERATOR_NAMESPACE={operator_namespace}", @@ -82,6 +83,7 @@ def _get_version(): "src/sgl-kernel/csrc/trt_reduce_kernel.cu", "src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu", + "src/sgl-kernel/csrc/fp8_gemm_kernel.cu", "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", "src/sgl-kernel/csrc/rotary_embedding.cu", "3rdparty/flashinfer/csrc/activation.cu", diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index c7fcd274259..df141dee1d0 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -2,6 +2,7 @@ bmm_fp8, custom_dispose, custom_reduce, + fp8_scaled_mm, fused_add_rmsnorm, gelu_and_mul, gelu_tanh_and_mul, @@ -27,6 +28,7 @@ "bmm_fp8", "custom_dispose", "custom_reduce", + "fp8_scaled_mm", "fused_add_rmsnorm", "gelu_and_mul", "gelu_tanh_and_mul", diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu new file mode 100644 index 00000000000..3e33e143c0c --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -0,0 +1,624 @@ +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +using namespace cute; + +#if defined CUDA_VERSION && CUDA_VERSION >= 12040 +template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, + typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> +struct DeviceGemmFp8RowwiseSm89 { + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + using ElementA = ElementType; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA 
= 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementType; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = OutElementType; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementOutput = OutElementType; + using LayoutOutput = cutlass::layout::RowMajor; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = AccumElementType; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm89; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + // Number of epilogue stages in EVT + static constexpr int EVTEpilogueStages = 1; + + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; + + // Definition of EVT + using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>; + using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeAScale = + cutlass::epilogue::threadblock::VisitorCompute; + using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>; + using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; + + // With bias + using biasSrc = + cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using ComputeAScaleWithBias = + cutlass::epilogue::threadblock::VisitorCompute; + using EpilogueAScaleWithBias = + cutlass::epilogue::threadblock::Sm80EVT; + + using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride>; + using EpilogueStore = + typename cutlass::platform::conditional, + cutlass::epilogue::threadblock::Sm80EVT>::type; + + using EpilogueOp = EpilogueStore; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, + cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator, + ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp, + ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = 
reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); + + typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + ptr_a, // a pointer + ptr_b, // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) + if constexpr (WithBias) { + args.epilogue = {{ + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } else { + args.epilogue = {{ + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } + + return args; +} + +template +void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess) +} + +template +void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + if (bias) { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } +} + +template +void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + uint32_t const n = out.size(1); + + if (m == 1) { + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 16) { + // M in (1, 16] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { 
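+ // N in (16384, inf)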
+ return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + // M in (16, 64] + if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 128) { + // M in (64, 128] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 256) { + // M in (128, 256] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 512) { + // M in (256, 512) + if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else { + // M in (512, inf) + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + } + } +} +#endif + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +template +struct DeviceGemmFp8RowwiseSm90 { + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + // A matrix configuration + using ElementA = ElementType; // Element type for A matrix operand + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = ElementType; // Element type for B matrix operand + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = void; // Element type for C matrix operands + using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in + // units of elements (up to 16 bytes) + + // Output matrix configuration + using ElementOutput = OutElementType; // Element type for output matrix operands + using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // // Auxiliary matrix configuration and other fusion types + // using ElementBias = float; + + // 
Multiply-accumulate blocking/pipelining details + using ElementAccumulator = AccumElementType; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + + static constexpr bool PONG = false; + static constexpr bool FAST_ACCUM = true; + static constexpr bool USE_BIAS = false; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default + // setting in the Collective Builder + // Implement rowwise scaling epilogue. + using XScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + // With bias + using ComputeWithBias = + cutlass::epilogue::fusion::Sm90Compute; + using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = typename cutlass::platform::conditional::type; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC, + AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + + using SlowAccum = DefaultSchedule; + using FastAccum = FastPongSchedule; // Default apply Pingpong + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduleType>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const 
c10::optional& bias) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); + StrideC stride_c; + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {ptr_a, stride_a, ptr_b, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + ptr_d, + stride_d}}; + if constexpr (WithBias) { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, + {}, // Accumulator + {} // Multiplies + }, + {ptr_bias}, + {}, // Multiplies + }; + } else { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + return args; +} + +template +void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op.run(args, workspace.data_ptr(), stream); + + TORCH_CHECK(status == cutlass::Status::kSuccess) +} + +template +void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias, bool fast_accum = true, + bool use_persistent = false) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + + if (bias) { + using Gemm = + typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = + typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } +} + +template +void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& 
bias) { + uint32_t const m = a.size(0); + using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using PersistentTileScheduler = cutlass::gemm::PersistentScheduler; + using BasicTileScheduler = void; + if (m <= 1) { + return sm90_fp8_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler, + BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); + } + if (m <= 64) { + // m in [1, 64] + return sm90_fp8_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else if (m <= 256) { + // m in (64, 256] + return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else if (m <= 1024) { + // m in (256, 1024] + return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else { + // m in (1024, inf) + return sm90_fp8_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } +} +#endif + +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias) { + TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); + TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); + + TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, + "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, + "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); + TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); + TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); + + TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched"); + TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched"); + TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous"); + TORCH_CHECK(scales_b.is_contiguous(), "scales_b msut be contiguous"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32"); + + if (bias) { + TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched"); + TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous"); + TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype"); + } + + torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); + TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment"); + + auto sm_version = getSMVersion(); + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + if (sm_version >= 90) { + if (out_dtype == 
torch::kBFloat16) { + sm90_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm90_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + return out; + } +#endif + +#if defined CUDA_VERSION && CUDA_VERSION >= 12040 + if (sm_version == 89) { + if (out_dtype == torch::kBFloat16) { + sm89_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm89_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + return out; + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version); +} diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index b29d30ac557..93c53c1e9e4 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -40,6 +40,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional& bias); +// fp8_scaled_mm +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); + // lightning_attention_decode void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 3a21ced875a..ced0dafa9d7 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -71,6 +71,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): ) +def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): + return torch.ops.sgl_kernels.fp8_scaled_mm( + mat_a, + mat_b, + scales_a, + scales_b, + out_dtype, + bias, + ) + + def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): torch.ops.sgl_kernels.lightning_attention_decode( q, k, v, past_kv, slope, output, new_kv diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc index 099a03a5601..caf4f1269b6 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension.cc +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { "bias) -> Tensor"); m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); + // fp8_scaled_mm + m.def( + "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm); + // lightning_attention_decode m.def( "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! 
" diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py new file mode 100644 index 00000000000..1a731865944 --- /dev/null +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -0,0 +1,67 @@ +import unittest + +import torch +from sgl_kernel import fp8_scaled_mm + + +def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): + o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + + o = o.to(torch.float32) + temp1 = o * scale_a.view(-1, 1) + temp2 = temp1 * scale_b.view(1, -1) + final = temp2.to(out_dtype) + if bias is not None: + final = final + bias.view(1, -1) + + return final + + +class TestFp8Gemm(unittest.TestCase): + def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + ) + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + b_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + ) + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 + scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 + if with_bias: + bias = torch.randn((N,), device=device, dtype=out_dtype) + else: + bias = None + o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16) + b_fp8 = b_fp8.t() + o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + rtol = 0.02 + atol = 1 + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") + + def test_accuracy(self): + Ms = [1, 128, 512, 1024, 4096] + Ns = [16, 128, 512, 1024, 4096] + Ks = [512, 1024, 4096, 8192, 16384] + bias_opts = [True, False] + out_dtypes = [torch.bfloat16, torch.float16] + for M in Ms: + for N in Ns: + for K in Ks: + for with_bias in bias_opts: + for out_dtype in out_dtypes: + self._test_accuracy_once( + M, N, K, with_bias, out_dtype, "cuda" + ) + + +if __name__ == "__main__": + unittest.main()