Skip to content

Commit

Permalink
support w8a8 fp8 kernel with CUTLASS (#3047)
Browse files Browse the repository at this point in the history
Co-authored-by: yych0745 <[email protected]>
  • Loading branch information
HandH1998 and yych0745 authored Jan 26, 2025
1 parent 95f789a commit 82392da
Show file tree
Hide file tree
Showing 8 changed files with 881 additions and 0 deletions.
164 changes: 164 additions & 0 deletions sgl-kernel/benchmark/bench_fp8_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
line_names=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
ylabel="GB/s",
plot_name="fp8 scaled matmul",
args={},
)
)
def benchmark(batch_size, provider, N, K):
# M, N, K = batch_size, 4096, 8192
M = batch_size
a = torch.ones((M, K), device="cuda") * 5.0
b = torch.ones((N, K), device="cuda") * 5.0
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
quantiles = [0.5, 0.2, 0.8]

dtype = torch.float16 if "fp16" in provider else torch.bfloat16

if "vllm-fp8" in provider:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
quantiles=quantiles,
)
elif "sglang-fp8" in provider:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_scaled_mm(
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
),
quantiles=quantiles,
)

gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()

KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
)

print("Benchmark finished!")
2 changes: 2 additions & 0 deletions sgl-kernel/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def _get_version():
turbomind.resolve(),
turbomind.resolve() / "src",
]

nvcc_flags = [
"-DNDEBUG",
f"-DOPERATOR_NAMESPACE={operator_namespace}",
Expand All @@ -82,6 +83,7 @@ def _get_version():
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"3rdparty/flashinfer/csrc/activation.cu",
Expand Down
2 changes: 2 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
bmm_fp8,
custom_dispose,
custom_reduce,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
Expand All @@ -27,6 +28,7 @@
"bmm_fp8",
"custom_dispose",
"custom_reduce",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
"gelu_tanh_and_mul",
Expand Down
Loading

0 comments on commit 82392da

Please sign in to comment.