rowwise only test with fast_acc=true
htyu authored and facebook-github-bot committed Sep 3, 2024
1 parent 225ac16 commit 8ceeb8c
Showing 3 changed files with 37 additions and 38 deletions.
48 changes: 24 additions & 24 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
@@ -70,38 +70,38 @@ def _run_benchmark(
     )

     shapes = [
-        (8192, 8192, 512),
+        # (8192, 8192, 512),
         (8192, 8192, 8192),
-        (65536, 8192, 7168),
-        (65536, 3584, 8192),
-        (8192, 14336, 4096),
+        # (65536, 8192, 7168),
+        # (65536, 3584, 8192),
+        # (8192, 14336, 4096),
     ]
     for shape in shapes:
-        _run_benchmark(bf16_bench, shape=shape, tag="bf16")
-        _run_benchmark(scale_row_bench, shape=shape, tag="fp8 scale + row gemm")
-        _run_benchmark(scale_block_bench, shape=shape, tag="fp8 scale + block gemm")
-        _run_benchmark(
-            row_gemm_bench,
-            shape=shape,
-            tag="fp8 row gemm only | fp8_fast_accum=True",
-        )
+        # _run_benchmark(bf16_bench, shape=shape, tag="bf16")
+        # _run_benchmark(scale_row_bench, shape=shape, tag="fp8 scale + row gemm")
+        # _run_benchmark(scale_block_bench, shape=shape, tag="fp8 scale + block gemm")
+        # _run_benchmark(
+        #     row_gemm_bench,
+        #     shape=shape,
+        #     tag="fp8 row gemm only | fp8_fast_accum=True",
+        # )
         _run_benchmark(
             row_gemm_bench_no_fast_acc,
             shape=shape,
             tag="fp8 row gemm only | fp8_fast_accum=False",
         )
-        _run_benchmark(
-            row_gemm_bench_imprecise_acc,
-            shape=shape,
-            tag="fp8 row gemm only | max_num_imprecise_acc=32",
-        )
-        _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only")
-        if rowwise_tma:
-            _run_benchmark(
-                row_gemm_bench_tma,
-                shape=shape,
-                tag="fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True",
-            )
+        # _run_benchmark(
+        #     row_gemm_bench_imprecise_acc,
+        #     shape=shape,
+        #     tag="fp8 row gemm only | max_num_imprecise_acc=32",
+        # )
+        # _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only")
+        # if rowwise_tma:
+        #     _run_benchmark(
+        #         row_gemm_bench_tma,
+        #         shape=shape,
+        #         tag="fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True",
+        #     )


 def bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
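After this hunk, the benchmark script exercises a single shape and a single kernel variant: the rowwise FP8 GEMM with fast accumulation disabled. A minimal sketch of the loop as it effectively runs after the commit (only names visible in the diff are used; the helpers' signatures are assumed unchanged):

    shapes = [
        (8192, 8192, 8192),  # the only (M, N, K) shape left uncommented
    ]
    for shape in shapes:
        # Rowwise-quantized FP8 GEMM benchmark, fp8_fast_accum disabled.
        _run_benchmark(
            row_gemm_bench_no_fast_acc,
            shape=shape,
            tag="fp8 row gemm only | fp8_fast_accum=False",
        )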
3 changes: 1 addition & 2 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -362,8 +362,7 @@ def _kernel_matmul_fp8_row(


 @triton.autotune(
-    configs=MATMUL_CONFIGS
-    + [
+    configs= [
         Config(
             {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
             num_stages=3,
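This hunk narrows the autotune search space for _kernel_matmul_fp8_row: instead of trying every entry in MATMUL_CONFIGS plus the explicitly listed config, the autotuner now only considers the single 128x128x128, SPLIT_K=1 config. A rough sketch of the before/after (here MATMUL_CONFIGS is a stand-in for the module-level default list, and the remaining arguments of the real Config entry are truncated in the diff):

    from triton import Config

    MATMUL_CONFIGS = []  # stand-in for the module-level default config list
    explicit_config = Config(
        {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
        num_stages=3,
    )

    configs_before = MATMUL_CONFIGS + [explicit_config]  # old: defaults + explicit entry
    configs_after = [explicit_config]                     # new: single-entry search space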
24 changes: 12 additions & 12 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -144,7 +144,7 @@ def get_quantize_ops() -> List[QuantizeOpBase]:
     return quantize_op_registry


-@register_quantize_op
+# @register_quantize_op
 class BF16Baseline(QuantizeOpBase):
     """
     Baseline BF16 matmul.
@@ -172,7 +172,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class ScaledMMBaseline(QuantizeOpBase):
     """
     Reference FP8 matmul implemented in native torch with cublas or hipblas.
@@ -253,7 +253,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class ScaledMMRowwise(QuantizeOpBase):
     def quantize(self, x, w):
         xq, x_scale = quantize_fp8_row(x)
@@ -292,7 +292,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class FP8TensorwiseGemm(QuantizeOpBase):
     """
     FP8 matmul with tensorwise scaling.
@@ -361,7 +361,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class FP8RowwiseGemm(QuantizeOpBase):
     """
     FP8 matmul with rowwise scaling.
@@ -430,7 +430,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class FP8TritonBlockwiseGemm(QuantizeOpBase):
     """
     FP8 matmul with block scaling.
@@ -463,7 +463,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class FP8CutlassBlockwiseGemm(QuantizeOpBase):
     """
     FP8 matmul with block scaling.
@@ -501,7 +501,7 @@ def cuda(self) -> bool:


 # CUTLASS kernel v2
-@register_quantize_op
+# @register_quantize_op
 class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase):
     """
     FP8 matmul with tensorwise scaling.
@@ -535,7 +535,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class F8I4RowwiseGemm(QuantizeOpBase):
     """
     Mixed Precision FP8 Activations with Int4 Weights.
@@ -611,7 +611,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """
     Mixed Precision BF16 Activations with Int4 Weights.
@@ -645,7 +645,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class TinyGemmBF16I4(QuantizeOpBase):
     """
     Mixed Precision BF16 Activations with Int4 Weights using tinygemm.
@@ -682,7 +682,7 @@ def cuda(self) -> bool:
         return True


-@register_quantize_op
+# @register_quantize_op
 class MarlinBF16I4(QuantizeOpBase):
     """
     Mixed Precision BF16 Activations with Int4 Weights using Marlin.
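Each @register_quantize_op commented out above removes the corresponding op class from the registry returned by get_quantize_ops(), so the benchmark harness simply never instantiates it; the class definitions themselves are untouched. A rough sketch of that registration pattern (the actual decorator body in quantize_ops.py may differ in detail; this only illustrates the mechanism):

    from typing import List, Type

    quantize_op_registry: List["QuantizeOpBase"] = []


    class QuantizeOpBase:
        """Placeholder base class; the real one defines quantize/compute/etc."""


    def register_quantize_op(op_class: Type[QuantizeOpBase]) -> Type[QuantizeOpBase]:
        # Registering appends an instance to the module-level registry;
        # commenting out the decorator leaves the class defined but unregistered.
        quantize_op_registry.append(op_class())
        return op_class


    def get_quantize_ops() -> List[QuantizeOpBase]:
        return quantize_op_registry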
