diff --git a/fbgemm_gpu/fbgemm_gpu/quantize/__init__.py b/fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
index 0580cf863b..df64c82ca7 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
@@ -49,7 +49,8 @@ def op_registeration(
         int elem_ebits,
         int elem_mbits,
         float elem_max_norm,
-        int mx_group_size
+        int mx_group_size,
+        int rounding_mode
     ) -> Tensor
     """
 )
diff --git a/fbgemm_gpu/fbgemm_gpu/quantize/quantize_ops.py b/fbgemm_gpu/fbgemm_gpu/quantize/quantize_ops.py
index c788341435..3da107b615 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize/quantize_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize/quantize_ops.py
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# # pyre-unsafe
+# pyre-unsafe
+from typing import Union
 
 import torch
 
-from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32
+from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
 
 
 def quantize_mx(
@@ -18,6 +19,7 @@ def quantize_mx(
     elem_mbits: int = 3,
     elem_max_norm: float = 6.0,
     mx_group_size: int = 32,
+    rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
 ) -> torch.Tensor:
     """
     Registered quantize_mx ops for E2E comm.
@@ -31,6 +33,7 @@ def quantize_mx(
         i.e., 3 for MX4 e2m1)
         elem_max_norm: max value of the float (i.e., 6.0 for MX4 e2m1)
         mx_group_size: num elements that share the max shared_exponent
+        rounding_mode: Which type of rounding to use when calculating shared exponent.
 
     Return:
         output: MX4 tensor packed into int8 values with size
@@ -38,7 +41,9 @@ def quantize_mx(
         the shared exponent of each group is stored at the last byte
         of output of each group
     """
-    return fp32_to_mx4(input, mx_group_size, use_triton=True)
+    return fp32_to_mx4(
+        input, mx_group_size, rounding_mode=rounding_mode, use_triton=True
+    )
 
 
 def dequantize_mx(
diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
index 5e5c9505e4..65539a7a73 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -11,7 +11,7 @@
 
 import torch
 
-from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4
+from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
 logger: logging.Logger = logging.getLogger()
@@ -31,13 +31,18 @@
 
 
 def fp32_to_mx4(
-    tensor: torch.Tensor, group_size: int = 32, use_triton: bool = True
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
+    use_triton: bool = True,
 ) -> torch.Tensor:
     """Quantize an FP32 tensor to MX4 with triton or native cuda impl.
 
     Args:
         tensor (torch.Tensor): FP32 tensor to quantize with M total elements.
         group_size (int): Compute scale in chunks of group_size.
+        rounding_mode (RoundingMode or int): Which type of rounding to use when computing exponent.
+            Ignored by the non-Triton CUDA kernel (use_triton=False on GPU tensors).
         use_triton (bool): If set, use triton quantization, otherwise cuda.
 
     Return:
@@ -48,10 +53,10 @@ def fp32_to_mx4(
     input = tensor.flatten()
 
     if not tensor.is_cuda:
-        return py_quantize_mx4(input, group_size)
+        return py_quantize_mx4(input, group_size, rounding_mode=rounding_mode)
 
     if use_triton:
-        return quantize_mx4(input, group_size)
+        return quantize_mx4(input, group_size, rounding_mode=rounding_mode)
     else:
         out = torch.ops.fbgemm.quantize_mx_cuda(
             input,
diff --git a/fbgemm_gpu/fbgemm_gpu/triton/__init__.py b/fbgemm_gpu/fbgemm_gpu/triton/__init__.py
index 66dcd46b4c..4e483d2670 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/__init__.py
@@ -10,6 +10,7 @@
 # Attempt to import triton kernels, fallback to reference if we cannot.
 try:
     from .quantize import (
+        RoundingMode,
         triton_dequantize_mx4 as dequantize_mx4,
         triton_quantize_mx4 as quantize_mx4,
     )
@@ -17,4 +18,5 @@
     from .quantize_ref import (  # noqa: F401, E402
         py_dequantize_mx4 as dequantize_mx4,
         py_quantize_mx4 as quantize_mx4,
+        RoundingMode,
     )
diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
index 64dbc85587..9a40d211ac 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -7,6 +7,8 @@
 # pyre-unsafe
 
 import math
+from enum import IntEnum
+from typing import Union
 
 import torch
 import triton  # @manual
@@ -15,6 +17,81 @@
 from triton import Config  # @manual
 
 
+class RoundingMode(IntEnum):
+    nearest = 0
+    floor = 1
+    even = 2
+    stochastic = 3
+    ceil = 4
+
+
+@triton.jit
+def _floor_log2(x):
+    """Helper function to efficiently compute floor(log2(x))
+
+    Args:
+        x (Tensor): FP32 Input tensor to operate on.
+
+    Returns:
+        Tensor: Floor of log2(x).
+    """
+    # Helpful bit constants.
+    FP32_EXP_MASK: tl.constexpr = 0x7F800000  # type: ignore[Incompatible variable type]
+    FP32_EXP_OFFSET: tl.constexpr = 23  # type: ignore[Incompatible variable type]
+    FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
+
+    # View x as an integer and extract its exponent.
+    x = x.to(tl.int32, bitcast=True) & FP32_EXP_MASK
+    # Shift exponent down to bottom bits.
+    x = x >> FP32_EXP_OFFSET
+    # Remove FP32 exponent bias and return.
+    return (x - FP32_EXP_BIAS).to(tl.float32)
+
+
+@triton.jit
+def _compute_exp(
+    group_max,
+    rounding_mode,
+    rand_bits,
+):
+    """Compute shared exponent of group using specified rounding mode.
+
+    Args:
+        group_max (Tensor): Group of values to compute exponent of.
+        rounding_mode (int or RoundingMode): Which rounding mode to use.
+        rand_bits (int): Random integer values used for stochastic rounding.
+
+    Returns:
+        Tensor: Shared exponent of group.
+    """
+    # TODO(jwfromm) Enable hip libdevice support as well.
+    # Nearest rounding mode.
+    if rounding_mode == 0:
+        return tl.extra.cuda.libdevice.round(tl.log2(group_max))
+    # Floor rounding mode. This can be done with fast bit ops.
+    if rounding_mode == 1:
+        return _floor_log2(group_max)
+    # Even pre-rounding mode.
+    elif rounding_mode == 2:
+        # First round to nearest even integer.
+        group_max = tl.extra.cuda.libdevice.rint(group_max)
+        # Then perform floor rounding of log.
+        return _floor_log2(group_max)
+    # Stochastic rounding mode.
+    elif rounding_mode == 3:
+        # Define constants needed for stochastic rounding.
+        MBITS_FP32: tl.constexpr = 23  # type: ignore[Incompatible variable type]
+        MBITS_E2M1: tl.constexpr = 1  # type: ignore[Incompatible variable type]
+        RAND_MASK: tl.constexpr = (1 << (MBITS_FP32 - MBITS_E2M1)) - 1  # type: ignore[Incompatible variable type]
+        # Use random bits to add noise to mantissa that would otherwise
+        # be rounded away.
+        group_max = group_max.to(tl.int32, bitcast=True) + (RAND_MASK & rand_bits)
+        # Now compute log and truncate.
+        return _floor_log2(group_max)
+    else:
+        return tl.ceil(tl.log2(group_max))
+
+
 @triton.autotune(
     configs=[
         Config({"GROUP_LOAD": 1}),
@@ -31,6 +108,8 @@ def _kernel_quantize_mx4(
     out,
     M,
     K,
+    rand_bits,
+    ROUNDING_MODE: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
     GROUP_LOAD: tl.constexpr,
 ) -> None:
@@ -42,6 +121,8 @@ def _kernel_quantize_mx4(
         out (Tensor): [M / 2 + M / GROUP_SIZE] output containing packed mx4 values.
         M (int): Total number of elements.
        K (int): Number of elements to process in each thread.
+        rand_bits (Optional Tensor): [M / GROUP_SIZE] random integers used for stochastic rounding.
+        ROUNDING_MODE (int): Which rounding method to use when calculating shared exponent.
         GROUP_SIZE (int): Size of chunks that use the same shared exponent.
         GROUP_LOAD (int): Number of groups to process simultaneously.
     """
@@ -73,6 +154,7 @@ def _kernel_quantize_mx4(
     # Initiate offset ranges used in kernel.
     input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + input_start
     output_offset = tl.arange(0, GROUP_LOAD * (GROUP_SIZE // 2))
+    rand_bits_offset = tl.arange(0, GROUP_LOAD) + pid * K // GROUP_SIZE
     # We need to shift output offsets to make space for shared exponent storage.
     output_offset += output_offset // (GROUP_SIZE // 2) + output_start
     # Now create offsets for writing the shared exponent.
@@ -97,9 +179,18 @@ def _kernel_quantize_mx4(
         group_max = tl.max(tl.abs(a_groups), axis=1)
         # Prevent infinite values in log.
         group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
-        # Convert max to exponent via direct log computation and ceiling
-        # rounding to minimize errors.
-        group_exp = tl.ceil(tl.log2(group_max))
+        # Load relevant random values if doing stochastic rounding.
+        if ROUNDING_MODE == 3:
+            group_rand_bits = tl.load(
+                rand_bits + rand_bits_offset,
+                mask=rand_bits_offset < K // GROUP_SIZE,
+                other=0,
+            )
+            rand_bits_offset += GROUP_LOAD
+        else:
+            group_rand_bits = None
+        # Compute shared exponent using specified rounding mode.
+        group_exp = _compute_exp(group_max, ROUNDING_MODE, group_rand_bits)
         # Subtract largest exponent in target datatype and remove bias.
         group_exp = group_exp - 2
         # Make sure exponent is in valid range.
@@ -204,13 +295,19 @@ def _kernel_quantize_mx4(
         output_offset += GROUP_LOAD * PACKED_GROUP_SIZE
 
 
-def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
+def triton_quantize_mx4(
+    a: torch.Tensor,
+    group_size: int = 32,
+    rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
+) -> torch.Tensor:
     """
     Quantize a tensor to mx4 format using efficient triton kernels.
 
     Args:
         a (Tensor): [M] higher precision input tensor.
         group_size (int): Size of chunks that will use the same shared exponent.
+        rounding_mode (Union[RoundingMode, int]): Which type of rounding to use
+            when calculating shared exponent. Defaults to pre-rounding to nearest even int.
 
     Returns:
         torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into in8
@@ -252,6 +349,19 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
         [a.numel() // 2 + a.numel() // group_size], device=a.device, dtype=torch.uint8
     )
 
+    # If using stochastic rounding, create random noise for each group.
+    if rounding_mode == RoundingMode.stochastic:
+        # Each group will need a seed.
+        rand_bits = torch.randint(
+            low=0,
+            high=2**31 - 1,
+            size=(a.numel() // group_size,),
+            dtype=torch.int32,
+            device=a.device,
+        )
+    else:
+        rand_bits = None
+
     # Invoke triton quantization kernel over rows.
     grid = (M,)
     _kernel_quantize_mx4[grid](
@@ -259,6 +369,8 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
         out,
         a.numel(),
         K,
+        rand_bits=rand_bits,
+        ROUNDING_MODE=rounding_mode,
         GROUP_SIZE=group_size,
     )
     # Inputs are now fully quantized and ready to return.
diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
index 09e2ea1ead..bd076e2ed3 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -6,16 +6,69 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
+from enum import IntEnum
+from typing import Union
+
 import torch
 
 
-def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
+class RoundingMode(IntEnum):
+    nearest = 0
+    floor = 1
+    even = 2
+    stochastic = 3
+    ceil = 4
+
+
+def _compute_exp(
+    group_max,
+    rounding_mode,
+):
+    """Compute shared exponent of group using specified rounding mode.
+
+    Args:
+        group_max (Tensor): Group of values to compute exponent of.
+        rounding_mode (int or RoundingMode): Which rounding mode to use.
+
+    Returns:
+        Tensor: Shared exponent of group.
+    """
+    if rounding_mode == 0:
+        return torch.round(torch.log2(group_max))
+    # Floor rounding mode.
+    if rounding_mode == 1:
+        return torch.floor(torch.log2(group_max))
+    # Even pre-rounding mode.
+    elif rounding_mode == 2:
+        # First round to nearest even integer.
+        group_max = torch.round(group_max)
+        # Then perform floor rounding of log.
+        return torch.floor(torch.log2(group_max))
+    # Stochastic rounding mode.
+    elif rounding_mode == 3:
+        # Create random noise.
+        noise = torch.rand_like(group_max)
+        # Add noise to group max and round down.
+        group_max = group_max + noise
+        # Now compute log and truncate.
+        return torch.floor(torch.log2(group_max))
+    else:
+        return torch.ceil(torch.log2(group_max))
+
+
+def py_quantize_mx4(
+    a: torch.Tensor,
+    group_size: int = 32,
+    rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
+) -> torch.Tensor:
     """
     Quantize a tensor to mx4 format.
 
     Args:
         a (Tensor): [M] higher precision input tensor.
         group_size (int): Size of chunks that will use the same shared exponent.
+        rounding_mode (int or RoundingMode): Which type of rounding to use when
+            calculating shared exponent.
 
     Returns:
         torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into in8
@@ -43,10 +96,8 @@ def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
     # Replace zero values with the minimum expressible normal value.
     FP32_MIN_NORMAL = 2 ** (-126)
     shared_exp = torch.where(shared_exp == 0, FP32_MIN_NORMAL, shared_exp)
-    # Convert max into an intger exponent.
-    # Note this can be more efficient by just shifting and masking exp bits.
-    # We can even use those directly.
-    shared_exp = torch.ceil(torch.log2(shared_exp))
+    # Convert max into an integer exponent.
+    shared_exp = _compute_exp(shared_exp, rounding_mode)
     # Offset exponent by largest exponent in target datatype.
     shared_exp = shared_exp - 2
     # Restrict to range expressible as int8.
diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py
index d88be72c12..5477b6ebba 100644
--- a/fbgemm_gpu/test/quantize/mx4_test.py
+++ b/fbgemm_gpu/test/quantize/mx4_test.py
@@ -15,7 +15,7 @@
 
 import torch
 
-from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32
+from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
 from hypothesis import given, settings, Verbosity
@@ -185,11 +185,15 @@ def test_mx4(self, power: int, sizes: int) -> None:
         # Test intercompatibility between implementations.
 
         # Test CPU implementation
-        quantized_cpu = py_quantize_mx4(input, group_size)
+        quantized_cpu = py_quantize_mx4(
+            input, group_size, rounding_mode=RoundingMode.floor
+        )
         output_cpu = py_dequantize_mx4(quantized_cpu, group_size)
 
         # Test Triton implementation
-        quantized_triton = fp32_to_mx4(input, group_size, use_triton=True)
+        quantized_triton = fp32_to_mx4(
+            input, group_size, rounding_mode=RoundingMode.floor, use_triton=True
+        )
         output_triton = mx4_to_fp32(quantized_triton, group_size, use_triton=True)
 
         # Test shim functions
@@ -205,6 +209,7 @@ def test_mx4(self, power: int, sizes: int) -> None:
             mbits,
             max_norm,
             mx_group_size=group_size,
+            rounding_mode=RoundingMode.floor,
         )
         output_from_ops = torch.ops.fbgemm.dequantize_mx(
             quantized_from_ops,
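A minimal usage sketch of the new parameter (not part of the patch itself): it assumes an FBGEMM-GPU build that includes this change, a CUDA device with Triton available, and an arbitrary illustrative tensor shape and group size.

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode

# Random FP32 activations; the element count should be divisible by group_size.
x = torch.randn(4, 1024, device="cuda")

# Quantize with stochastic shared-exponent rounding on the Triton path,
# then dequantize back to FP32.
packed = fp32_to_mx4(x, group_size=32, rounding_mode=RoundingMode.stochastic, use_triton=True)
restored = mx4_to_fp32(packed, 32, use_triton=True)

# Both tensors hold the same number of elements; compare them flattened.
print((restored.flatten() - x.flatten()).abs().max())

On CPU tensors the same fp32_to_mx4 call dispatches to py_quantize_mx4, which also honors rounding_mode; only the non-Triton CUDA kernel ignores it.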