diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
index e4d60f3c37..a650b1a534 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -23,14 +23,14 @@
     fp32_to_bf16_with_clamp,
     fp32_to_fp16_with_clamp,
     fp32_to_hfp8_with_clamp,
-    fp32_to_mx4,
     hfp8_to_fp32,
-    mx4_to_fp32,
 )
+
 from fbgemm_gpu.split_embedding_configs import SparseType
 
 from torch.autograd.profiler import record_function  # usort:skip
 from dataclasses import dataclass
+import fbgemm_gpu.quantize_ops  # noqa F401
 
 
 logger: logging.Logger = logging.getLogger()
@@ -100,7 +100,15 @@ def _quantize_tensor(
         return input_quant_all2all
     elif comm_precision == SparseType.MX4:
         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-        return fp32_to_mx4(input_tensor, mx_group_size)
+        quantized_output = torch.ops.fbgemm.quantize_mx(
+            input=input_tensor,
+            scale_bits=8,
+            elem_ebits=2,
+            elem_mbits=3,
+            elem_max_norm=6.0,
+            mx_group_size=mx_group_size,
+        )
+        return quantized_output
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")
 
@@ -141,7 +149,11 @@ def _dequantize_tensor(
         return dequant_tensor.view(-1)
     elif comm_precision == SparseType.MX4:
         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-        return mx4_to_fp32(quantized_tensor, mx_group_size)
+        dequant_tensor = torch.ops.fbgemm.dequantize_mx(
+            input=quantized_tensor,
+            mx_group_size=mx_group_size,
+        )
+        return dequant_tensor.view(-1)
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")
 
diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_ops.py b/fbgemm_gpu/fbgemm_gpu/quantize_ops.py
new file mode 100644
index 0000000000..514f747e8a
--- /dev/null
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_ops.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+
+from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32
+
+lib = torch.library.Library("fbgemm", "FRAGMENT")
+lib.define(
+    """
+    quantize_mx(
+        Tensor input,
+        int scale_bits,
+        int elem_ebits,
+        int elem_mbits,
+        float elem_max_norm,
+        int mx_group_size
+    ) -> Tensor
+    """
+)
+
+lib.define(
+    """
+    dequantize_mx(
+        Tensor input,
+        int mx_group_size
+    ) -> Tensor
+    """
+)
+
+
+@torch.library.impl(lib, "quantize_mx", "CPU")
+@torch.library.impl(lib, "quantize_mx", "CUDA")
+def quantize_mx(
+    input: torch.Tensor,
+    scale_bits: int = 8,
+    elem_ebits: int = 2,
+    elem_mbits: int = 3,
+    elem_max_norm: float = 6.0,
+    mx_group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Registered quantize_mx op for E2E comm.
+    We use the Triton implementation for quantization.
+    Args:
+        input: FP32 tensor of size total_elems to be quantized
+        scale_bits: num bits of the shared exponent (i.e., 8 for MX4 e2m1)
+        elem_ebits: num bits of the exponent (i.e., 2 for MX4 e2m1)
+        elem_mbits: num bits of the mantissa incl. sign and implicit bits (
+            i.e., 3 for MX4 e2m1)
+        elem_max_norm: max value of the float (i.e., 6.0 for MX4 e2m1)
+        mx_group_size: num elements that share the max shared_exponent
+
+    Return:
+        output: MX4 tensor packed into int8 values with size
+            (total_elems / 2 + total_elems / groupsize);
+            the shared exponent of each group is stored in the last byte
+            of each group's output
+    """
+    return fp32_to_mx4(input, mx_group_size, use_triton=True)
+
+
+@torch.library.impl(lib, "dequantize_mx", "CPU")
+@torch.library.impl(lib, "dequantize_mx", "CUDA")
+def dequantize_mx(
+    input: torch.Tensor,
+    mx_group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Registered dequantize_mx op for E2E comm.
+    We use the CUDA implementation for dequantization.
+    Args:
+        input: MX4 tensor packed into int8 values
+        mx_group_size: number of elements that share the same max shared_exponent
+
+    Return:
+        output: FP32 tensor with total elements (total_elems)
+    """
+    return mx4_to_fp32(input, mx_group_size, use_triton=False)
diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
index 212c8374e7..300ab8829d 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -8,6 +8,7 @@
 # pyre-strict
 
 import logging
+import math
 
 import torch
 
@@ -44,13 +45,23 @@ def fp32_to_mx4(
         output: MX4 tensor packed into int8 values with total elements (M / 2 + M / groupsize)
     """
     # Accelerated MX4 is only available on cuda, if input is on cpu, use python.
+    # For CPU and Triton, set the second dim to 2048 or the nearest power of 2.
+    dim = (
+        2048 if tensor.numel() >= 2048 else 2 ** (math.floor(math.log2(tensor.numel())))
+    )
+    input = (
+        tensor.view(-1)
+        if (tensor.is_cuda and not use_triton) or tensor.numel() % dim != 0
+        else tensor.view(-1, dim)
+    )
     if not tensor.is_cuda:
-        return py_quantize_mx4(tensor, group_size)
+        return py_quantize_mx4(input, group_size)
+
     if use_triton:
-        return quantize_mx4(tensor, group_size)
+        return quantize_mx4(input, group_size)
     else:
         out = torch.ops.fbgemm.quantize_mx_cuda(
-            tensor.view(-1),
+            input,
             scale_bits=8,
             elem_ebits=2,
             elem_mbits=3,
@@ -75,16 +86,14 @@ def mx4_to_fp32(
     Return:
         output: FP32 tensor with total elements (M).
     """
+    flatten_tensor = tensor.view(-1)
     # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
     if not tensor.is_cuda:
-        return py_dequantize_mx4(tensor, group_size)
+        return py_dequantize_mx4(flatten_tensor, group_size)
     if use_triton:
-        return dequantize_mx4(tensor, group_size)
+        return dequantize_mx4(flatten_tensor, group_size)
     else:
-        out = torch.ops.fbgemm.dequantize_mx_cuda(tensor.view(-1), group_size)
-        # Perserve input dimensions.
-        output_shape = list(tensor.shape[:-1]) + [-1]
-        return out.view(output_shape)
+        return torch.ops.fbgemm.dequantize_mx_cuda(flatten_tensor, group_size)
 
 
 def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
diff --git a/fbgemm_gpu/test/quantize/comm_codec_test.py b/fbgemm_gpu/test/quantize/comm_codec_test.py
index 9e2737dfc5..0c8297c976 100644
--- a/fbgemm_gpu/test/quantize/comm_codec_test.py
+++ b/fbgemm_gpu/test/quantize/comm_codec_test.py
@@ -18,7 +18,7 @@
 
 
 class QuantizedCommCodecTest(unittest.TestCase):
-    @settings(deadline=4000)
+    @settings(deadline=8000)
     # pyre-ignore
     @given(
         comm_precisions_loss_scale=st.sampled_from(
diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py
index 7927434a62..5733a42064 100644
--- a/fbgemm_gpu/test/quantize/mx4_test.py
+++ b/fbgemm_gpu/test/quantize/mx4_test.py
@@ -9,9 +9,12 @@
 import unittest
 from typing import List
 
+import fbgemm_gpu.quantize_ops  # noqa F401
+
 import hypothesis.strategies as st
 
 import torch
+
 
 from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
@@ -81,10 +84,6 @@ def fake_quantize_mx(
 ) -> torch.Tensor:
     """Function used for MX* fake quantization"""
 
-    ####################
-    # Python Quantize
-    ####################
-
     # Make sure axes is a list of non-negative numbers
     axes = [x + A.ndim if x < 0 else x for x in axes]
 
@@ -161,6 +160,7 @@ def test_mx4(self, power: int, sizes: int) -> None:
         ebits, mbits, emax, max_norm, _ = _get_format_params(element_format_str)
         scale_bits = 8
 
+        # Reference from mx_github
         output_ref = fake_quantize_mx(
             input,
             scale_bits,
@@ -172,7 +172,8 @@ def test_mx4(self, power: int, sizes: int) -> None:
             group_size=group_size,
         )
 
-        output = fake_quantize_mx_cuda(
+        # Test CUDA implementation
+        output_cuda = fake_quantize_mx_cuda(
             input,
             scale_bits,
             ebits,
@@ -183,16 +184,38 @@ def test_mx4(self, power: int, sizes: int) -> None:
         )
 
         # Test intercompatibility between implementations.
-        py_mx_q_input = py_quantize_mx4(input, group_size)
-        py_mx_output = py_dequantize_mx4(py_mx_q_input, group_size)
-        triton_mx_q_input = fp32_to_mx4(input, group_size, use_triton=True)
-        cuda_mx_output = mx4_to_fp32(triton_mx_q_input, group_size, use_triton=False)
-        triton_mx_output = mx4_to_fp32(triton_mx_q_input, group_size, use_triton=True)
-
-        check_diff_quantize(input, py_mx_output, output_ref)
-        check_diff_quantize(input, cuda_mx_output, output_ref)
-        check_diff_quantize(input, triton_mx_output, output_ref)
-        check_diff_quantize(input, output, output_ref)
+        # Test CPU implementation
+        quantized_cpu = py_quantize_mx4(input, group_size)
+        output_cpu = py_dequantize_mx4(quantized_cpu, group_size)
+
+        # Test Triton implementation
+        quantized_triton = fp32_to_mx4(input, group_size, use_triton=True)
+        output_triton = mx4_to_fp32(quantized_triton, group_size, use_triton=True)
+
+        # Test shim functions
+        output_cuda_from_quantized_triton = mx4_to_fp32(
+            quantized_triton, group_size, use_triton=False
+        )
+
+        # Test torch.ops
+        quantized_from_ops = torch.ops.fbgemm.quantize_mx(
+            input,
+            scale_bits,
+            ebits,
+            mbits,
+            max_norm,
+            mx_group_size=group_size,
+        )
+        output_from_ops = torch.ops.fbgemm.dequantize_mx(
+            quantized_from_ops,
+            mx_group_size=group_size,
+        )
+
+        check_diff_quantize(input, output_ref, output_cuda)
+        check_diff_quantize(input, output_cuda, output_cuda_from_quantized_triton)
+        check_diff_quantize(input, output_cuda_from_quantized_triton, output_triton)
+        check_diff_quantize(input, output_triton, output_cpu)
+        check_diff_quantize(input, output_cuda, output_from_ops)
 
 
 if __name__ == "__main__":
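Not part of the patch itself: a minimal usage sketch of the newly registered ops, assuming a build of fbgemm_gpu that includes this change. The device selection, tensor size, variable names, and group size below are illustrative only; the MX4 e2m1 constants mirror the values hard-coded in _quantize_tensor above.

# Illustrative usage sketch (not part of the diff).
import torch

# Importing this module registers torch.ops.fbgemm.quantize_mx / dequantize_mx.
import fbgemm_gpu.quantize_ops  # noqa F401

device = "cuda" if torch.cuda.is_available() else "cpu"  # illustrative choice
group_size = 32
x = torch.randn(4096, dtype=torch.float32, device=device)  # numel must be divisible by group_size

# MX4 e2m1: 8-bit shared exponent, 2 exponent bits, 3 mantissa bits (incl. sign/implicit), max norm 6.0.
packed = torch.ops.fbgemm.quantize_mx(
    x,
    scale_bits=8,
    elem_ebits=2,
    elem_mbits=3,
    elem_max_norm=6.0,
    mx_group_size=group_size,
)
# Per the op docstring, `packed` holds numel/2 payload bytes plus one shared-exponent byte per group.
recovered = torch.ops.fbgemm.dequantize_mx(packed, mx_group_size=group_size)
print(packed.numel(), recovered.shape)  # round trip is lossy; values are MX4-quantized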