support w8a8 fp8 kernel with CUTLASS #3047

Merged: 43 commits, Jan 26, 2025

Commits (43)
955a2fb
Add performance and accuracy test code for FP8 GEMM operations
yych0745 Jan 7, 2025
30bdf20
support w8a8 fp8
HandH1998 Jan 8, 2025
4cac9fb
support bias
HandH1998 Jan 9, 2025
ecc90a4
opitmize
yych0745 Jan 10, 2025
3497950
add config_profile for sm_89
yych0745 Jan 13, 2025
05eb204
fp8 sm90-H100 singleTest done
yych0745 Jan 14, 2025
8d95538
fp8 sm90-H100 singleTest done
yych0745 Jan 14, 2025
724cf62
clean code
yych0745 Jan 14, 2025
93e2d85
fix
yych0745 Jan 14, 2025
8c08dbb
clean code
yych0745 Jan 14, 2025
fb95b0e
clean code
yych0745 Jan 15, 2025
2bac342
fp8 dispatch change
yych0745 Jan 21, 2025
ba7ca85
clean code
yych0745 Jan 21, 2025
2727d7d
fix
yych0745 Jan 21, 2025
b11682e
clean code
yych0745 Jan 21, 2025
fe490cc
Add performance and accuracy test code for FP8 GEMM operations
yych0745 Jan 7, 2025
b2de73d
support w8a8 fp8
HandH1998 Jan 8, 2025
3691d68
support bias
HandH1998 Jan 9, 2025
38bcf52
fix compilation
HandH1998 Jan 21, 2025
d57f756
clean code
yych0745 Jan 22, 2025
e620244
clean code
yych0745 Jan 22, 2025
699fe9e
Merge pull request #6 from HandH1998/tmptmp
HandH1998 Jan 22, 2025
b6a88bb
Merge remote-tracking branch 'origin/main' into main_w8a8_fp8
HandH1998 Jan 22, 2025
604f4f5
format
HandH1998 Jan 22, 2025
98dc70d
format
HandH1998 Jan 22, 2025
a4025f6
Merge branch 'main' into main_w8a8_fp8
zhyncs Jan 22, 2025
8b87aad
upd
zhyncs Jan 22, 2025
b287319
Merge branch 'main' into main_w8a8_fp8
zhyncs Jan 22, 2025
6de3ad4
Merge branch 'main_w8a8_fp8' of https://github.com/HandH1998/sglang i…
yych0745 Jan 23, 2025
b4195b0
fix include
HandH1998 Jan 23, 2025
8290ba6
add more shapes for benchmark
HandH1998 Jan 23, 2025
a455233
Merge remote-tracking branch 'origin/main' into main_w8a8_fp8
HandH1998 Jan 23, 2025
42f408f
fix bug
HandH1998 Jan 23, 2025
1739631
Merge branch 'main_w8a8_fp8' of https://github.com/HandH1998/sglang i…
yych0745 Jan 24, 2025
0666d39
cutlass optimization
yych0745 Jan 24, 2025
b9980af
clean code
HandH1998 Jan 24, 2025
cd51083
fix reivew issues
HandH1998 Jan 24, 2025
4a98c75
Merge remote-tracking branch 'origin/main' into main_w8a8_fp8
HandH1998 Jan 24, 2025
8c3dc13
fix bug
HandH1998 Jan 24, 2025
a1b582e
Merge remote-tracking branch 'origin/main' into main_w8a8_fp8
HandH1998 Jan 26, 2025
248391e
Merge branch 'main' into main_w8a8_fp8
zhyncs Jan 26, 2025
62bf9a4
fix name conflict
HandH1998 Jan 26, 2025
0d7f5a0
Merge remote-tracking branch 'origin/main' into main_w8a8_fp8
HandH1998 Jan 26, 2025
164 changes: 164 additions & 0 deletions sgl-kernel/benchmark/bench_fp8_gemm.py
@@ -0,0 +1,164 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        line_names=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    # M, N, K = batch_size, 4096, 8192
    M = batch_size
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    b_fp8 = b_fp8.t()
    quantiles = [0.5, 0.2, 0.8]

    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
            ),
            quantiles=quantiles,
        )

    gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
        )

    print("Benchmark finished!")
2 changes: 2 additions & 0 deletions sgl-kernel/setup.py
@@ -56,6 +56,7 @@ def _get_version():
    turbomind.resolve(),
    turbomind.resolve() / "src",
]

nvcc_flags = [
    "-DNDEBUG",
    f"-DOPERATOR_NAMESPACE={operator_namespace}",
@@ -82,6 +83,7 @@
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"3rdparty/flashinfer/csrc/activation.cu",
2 changes: 2 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -2,6 +2,7 @@
    bmm_fp8,
    custom_dispose,
    custom_reduce,
    fp8_scaled_mm,
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
@@ -27,6 +28,7 @@
"bmm_fp8",
"custom_dispose",
"custom_reduce",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
"gelu_tanh_and_mul",