diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
index 77fa04b3a2..9eb1a6c7d8 100644
--- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
+++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
@@ -70,38 +70,38 @@ def _run_benchmark(
         )
 
     shapes = [
-        (8192, 8192, 512),
+        # (8192, 8192, 512),
         (8192, 8192, 8192),
-        (65536, 8192, 7168),
-        (65536, 3584, 8192),
-        (8192, 14336, 4096),
+        # (65536, 8192, 7168),
+        # (65536, 3584, 8192),
+        # (8192, 14336, 4096),
     ]
     for shape in shapes:
-        _run_benchmark(bf16_bench, shape=shape, tag="bf16")
-        _run_benchmark(scale_row_bench, shape=shape, tag="fp8 scale + row gemm")
-        _run_benchmark(scale_block_bench, shape=shape, tag="fp8 scale + block gemm")
-        _run_benchmark(
-            row_gemm_bench,
-            shape=shape,
-            tag="fp8 row gemm only | fp8_fast_accum=True",
-        )
+        # _run_benchmark(bf16_bench, shape=shape, tag="bf16")
+        # _run_benchmark(scale_row_bench, shape=shape, tag="fp8 scale + row gemm")
+        # _run_benchmark(scale_block_bench, shape=shape, tag="fp8 scale + block gemm")
+        # _run_benchmark(
+        #     row_gemm_bench,
+        #     shape=shape,
+        #     tag="fp8 row gemm only | fp8_fast_accum=True",
+        # )
         _run_benchmark(
             row_gemm_bench_no_fast_acc,
             shape=shape,
             tag="fp8 row gemm only | fp8_fast_accum=False",
         )
-        _run_benchmark(
-            row_gemm_bench_imprecise_acc,
-            shape=shape,
-            tag="fp8 row gemm only | max_num_imprecise_acc=32",
-        )
-        _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only")
-        if rowwise_tma:
-            _run_benchmark(
-                row_gemm_bench_tma,
-                shape=shape,
-                tag="fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True",
-            )
+        # _run_benchmark(
+        #     row_gemm_bench_imprecise_acc,
+        #     shape=shape,
+        #     tag="fp8 row gemm only | max_num_imprecise_acc=32",
+        # )
+        # _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only")
+        # if rowwise_tma:
+        #     _run_benchmark(
+        #         row_gemm_bench_tma,
+        #         shape=shape,
+        #         tag="fp8 row gemm only | fp8_fast_accum=True | tma_persistent=True",
+        #     )
 
 
 def bf16_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index ce2cd12a38..8d31eacafe 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -362,8 +362,7 @@ def _kernel_matmul_fp8_row(
 
 
 @triton.autotune(
-    configs=MATMUL_CONFIGS
-    + [
+    configs=[
         Config(
             {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
             num_stages=3,
diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index 65b34a956d..079799085f 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -144,7 +144,7 @@ def get_quantize_ops() -> List[QuantizeOpBase]:
     return quantize_op_registry
 
 
-@register_quantize_op
+# @register_quantize_op
 class BF16Baseline(QuantizeOpBase):
     """
     Baseline BF16 matmul.
@@ -172,7 +172,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class ScaledMMBaseline(QuantizeOpBase):
     """
     Reference FP8 matmul implemented in native torch with cublas or hipblas.
@@ -253,7 +253,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class ScaledMMRowwise(QuantizeOpBase):
     def quantize(self, x, w):
         xq, x_scale = quantize_fp8_row(x)
@@ -292,7 +292,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class FP8TensorwiseGemm(QuantizeOpBase):
     """
     FP8 matmul with tensorwise scaling.
@@ -361,7 +361,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class FP8RowwiseGemm(QuantizeOpBase):
     """
     FP8 matmul with rowwise scaling.
@@ -430,7 +430,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class FP8TritonBlockwiseGemm(QuantizeOpBase):
     """
     FP8 matmul with block scaling.
@@ -463,7 +463,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class FP8CutlassBlockwiseGemm(QuantizeOpBase):
     """
     FP8 matmul with block scaling.
@@ -501,7 +501,7 @@ def cuda(self) -> bool:
 
 
 # CUTLASS kernel v2
-@register_quantize_op
+# @register_quantize_op
 class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase):
     """
     FP8 matmul with tensorwise scaling.
@@ -535,7 +535,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class F8I4RowwiseGemm(QuantizeOpBase):
     """
     Mixed Precision FP8 Activations with Int4 Weights.
@@ -611,7 +611,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """
     Mixed Precision BF16 Activations with Int4 Weights.
@@ -645,7 +645,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class TinyGemmBF16I4(QuantizeOpBase):
     """
     Mixed Precision BF16 Activations with Int4 Weights using tinygemm.
@@ -682,7 +682,7 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
+# @register_quantize_op
 class MarlinBF16I4(QuantizeOpBase):
     """
     Mixed Precision BF16 Activations with Int4 Weights using Marlin.