row-wise quant fp8 gemm perf bench | fp8_fast_accum test (pytorch#2666)
Summary:

Modify the FP8 GEMM performance benchmark to add cases with fast_accum=False and with max_num_imprecise_acc=32.

Reviewed By: sryap, htyu, jianyuh

Differential Revision: D57613375
Jiecao Yu authored and facebook-github-bot committed Jun 4, 2024
1 parent 7d4b51e commit 1dbe3ac
Showing 2 changed files with 434 additions and 32 deletions.
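
For reference, the three row-wise GEMM cases named in the new benchmark tags differ only in the accumulation arguments passed to matmul_fp8_row. The sketch below illustrates the three calls using the argument names from the diff; the import path, the standalone usage, and the tensor shapes are assumptions for illustration (the benchmark drives these calls through _run_benchmark), and an FP8-capable GPU is required.

import torch

# Assumed import path for the experimental Triton FP8 ops; only the call
# arguments below are taken from this commit's diff.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,
)

# Illustrative (m, k) x (n, k) problem; the benchmark sweeps its shapes list.
x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(14336, 4096, device="cuda", dtype=torch.bfloat16)

# Row-wise FP8 quantization is done once so that only the GEMM is timed.
x_fp8, x_scale = quantize_fp8_row(x)
w_fp8, w_scale = quantize_fp8_row(w)

common = dict(dot_out_dtype=torch.float32, allow_tf32=True)

# Existing case: fast (lower-precision) accumulation in the FP8 MMA.
y_fast = matmul_fp8_row(x_fp8, w_fp8, x_scale, w_scale, fp8_fast_accum=True, **common)

# New case 1: full-precision accumulation.
y_full = matmul_fp8_row(x_fp8, w_fp8, x_scale, w_scale, fp8_fast_accum=False, **common)

# New case 2: fast accumulation with the kernel's imprecise_acc path, tagged
# "max_num_imprecise_acc=32" in the benchmark output.
y_imprecise = matmul_fp8_row(
    x_fp8, w_fp8, x_scale, w_scale, fp8_fast_accum=True, imprecise_acc=True, **common
)

In the benchmark, these three variants are wrapped by row_gemm_bench, row_gemm_bench_no_fast_acc, and row_gemm_bench_imprecise_acc respectively, as the diff below shows.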
72 changes: 69 additions & 3 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
@@ -54,15 +54,34 @@ def _run_benchmark(
         sec = ms / 1e3
         perf_str = f"{tflops / sec:.2f}"
         print(
-            f"{(tag + ':').ljust(20)}\tshape {str(shape):<25} tflops {perf_str:<8} ms {ms:.3f}"
+            f"{(tag + ':').ljust(40)}\tshape {str(shape):<25} tflops {perf_str:<8} ms {ms:.3f}"
         )
 
-    shapes = [(8192, 8192, 8192), (65536, 8192, 7168), (65536, 3584, 8192)]
+    shapes = [
+        (8192, 8192, 8192),
+        (65536, 8192, 7168),
+        (65536, 3584, 8192),
+        (8192, 14336, 4096),
+    ]
     for shape in shapes:
         _run_benchmark(bf16_bench, shape=shape, tag="bf16")
         _run_benchmark(scale_row_bench, shape=shape, tag="fp8 scale + row gemm")
         _run_benchmark(scale_block_bench, shape=shape, tag="fp8 scale + block gemm")
-        _run_benchmark(row_gemm_bench, shape=shape, tag="fp8 row gemm only")
+        _run_benchmark(
+            row_gemm_bench,
+            shape=shape,
+            tag="fp8 row gemm only | fp8_fast_accum=True",
+        )
+        _run_benchmark(
+            row_gemm_bench_no_fast_acc,
+            shape=shape,
+            tag="fp8 row gemm only | fp8_fast_accum=False",
+        )
+        _run_benchmark(
+            row_gemm_bench_imprecise_acc,
+            shape=shape,
+            tag="fp8 row gemm only | max_num_imprecise_acc=32",
+        )
         _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only")


@@ -118,6 +137,53 @@ def run_gemm() -> Tensor:
     return run_gemm
 
 
+def row_gemm_bench_no_fast_acc(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
+    # Benchmark only row-wise gemm, caching scaling.
+    x_fp8: TensorWrapper
+    w_fp8: TensorWrapper
+    x_scale: Tensor
+    w_scale: Tensor
+    x_fp8, x_scale = quantize_fp8_row(x)
+    w_fp8, w_scale = quantize_fp8_row(w)
+
+    def run_gemm() -> Tensor:
+        return matmul_fp8_row(
+            x_fp8,
+            w_fp8,
+            x_scale,
+            w_scale,
+            dot_out_dtype=torch.float32,
+            allow_tf32=True,
+            fp8_fast_accum=False,
+        )
+
+    return run_gemm
+
+
+def row_gemm_bench_imprecise_acc(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
+    # Benchmark only row-wise gemm, caching scaling.
+    x_fp8: TensorWrapper
+    w_fp8: TensorWrapper
+    x_scale: Tensor
+    w_scale: Tensor
+    x_fp8, x_scale = quantize_fp8_row(x)
+    w_fp8, w_scale = quantize_fp8_row(w)
+
+    def run_gemm() -> Tensor:
+        return matmul_fp8_row(
+            x_fp8,
+            w_fp8,
+            x_scale,
+            w_scale,
+            dot_out_dtype=torch.float32,
+            allow_tf32=True,
+            fp8_fast_accum=True,
+            imprecise_acc=True,
+        )
+
+    return run_gemm
+
+
 def scale_block_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
     def run_gemm() -> Tensor:
         x_fp8: TensorWrapper
