From ceaa0896ba5425d55b4a653988e9df63a14ccbf9 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 12 Jul 2024 13:04:00 -0700
Subject: [PATCH] Use better exponent rounding in Triton MX4 quantize kernel
 (#2816)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/20

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2816

As noted in [this doc](https://docs.google.com/document/d/156Du0hBRH6umG_i-OrYC574XhpQMUU5SJYG0RTS2tTg/edit#heading=h.akfcp7xpg8cr), using ceiling rounding for the scale calculation does a better job of not truncating some mantissa bits. This diff switches Triton's floor rounding to ceiling rounding.

Note that mx4_test currently doesn't pass, as the CUDA kernel now behaves differently than Triton. Once we rebase this diff onto a similar change to the CUDA kernel, we should see exactly matching outputs again.

Differential Revision: D59527463

Reviewed By: jianyuh
---
 fbgemm_gpu/fbgemm_gpu/quantize/__init__.py   |  1 +
 fbgemm_gpu/fbgemm_gpu/triton/quantize.py     | 53 +++++++++++---------
 fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py |  4 +-
 3 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/quantize/__init__.py b/fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
index 2ab4ea8fb2..0580cf863b 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import torch
+
 from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx  # noqa F401
 
 
diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
index 910d7077df..d8adbb4a52 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -77,6 +77,7 @@ def _kernel_quantize_mx4(
     MAX_FP32_MANTISSA_BITS: tl.constexpr = 24  # type: ignore[Incompatible variable type]
     IMPLIED_1_BIT: tl.constexpr = 1 << 23  # type: ignore[Incompatible variable type]
     OVERFLOW_THRESHOLD: tl.constexpr = 4  # type: ignore[Incompatible variable type]
+    FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]
 
     # First we need to compute shared exponent.
     for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
@@ -84,7 +85,7 @@ def _kernel_quantize_mx4(
         a = tl.load(
             A + pid * stride_am + k_offset * stride_ak,
             mask=k_offset < K,
-            other=-float("inf"),
+            other=0,
         )
 
         # Scaling step
@@ -94,13 +95,15 @@ def _kernel_quantize_mx4(
         a_groups = tl.reshape(a, [BLOCK_SIZE // GROUP_SIZE, GROUP_SIZE])
         # Compute the shared exponent of each group.
         group_max = tl.max(tl.abs(a_groups), axis=1)
-        # Convert max to exponent via bit operations.
-        group_exp = group_max.to(tl.int32, bitcast=True) & FP32_EXP_MASK
-        group_exp = group_exp >> FP32_EXP_OFFSET
+        # Prevent infinite values in log.
+        group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
+        # Convert max to exponent via direct log computation and ceiling
+        # rounding to minimize errors.
+        group_exp = tl.ceil(tl.log2(group_max))
         # Subtract largest exponent in target datatype and remove bias.
-        group_exp = group_exp - 2 - FP32_EXP_BIAS
-        # Clamp to valid int8 range.
-        group_exp = tl.maximum(group_exp, -127)
+        group_exp = group_exp - 2
+        # Make sure exponent is in valid range.
+        group_exp = tl.clamp(group_exp, -127, 125)
 
         # Next we scale A in preparation for quantization.
         scale = tl.exp2(group_exp.to(tl.float64)).to(tl.float32)
@@ -113,10 +116,9 @@
 
         # We're done with group_exp now so we can write it out.
        # We readd fp32_exp_bias for compatibility with cuda dequant.
-        group_exp = group_exp.to(tl.int8)
         tl.store(
             shared_exp + pid * stride_exp_m + stride_exp_k * group_offset,
-            group_exp + FP32_EXP_BIAS,
+            (group_exp + FP32_EXP_BIAS).to(tl.int8),
             mask=group_offset < K // GROUP_SIZE,
         )
 
@@ -228,16 +230,15 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
     # We do this by finding the power of two that is closest to
     # the sqrt of the number of elements.
     num_threads = int(2 ** round(math.log2(math.sqrt(a.numel()))))
-    # Make sure that num_threads is a multiple of group_size.
-    num_threads = (num_threads // group_size) * group_size
-    if num_threads == 0:
-        num_threads = a.numel() // group_size
-    a = a.view(num_threads, -1)
-    M, K = a.shape
+    # Make sure that the number of elements per row is a multiple of group_size.
+    K = a.numel() // num_threads
+    K = (K // group_size) * group_size
     # If K is less than group_size, we compute a single group per row.
-    if K < group_size:
-        a = a.view(-1, group_size)
-        M, K = a.shape
+    if K == 0:
+        K = group_size
+    a = a.view(-1, K)
+    M, K = a.shape
+
     # Create output tensors.
     shared_exp = torch.empty([M, K // group_size], device=a.device, dtype=torch.uint8)
     out = torch.empty([M, K // 2], device=a.device, dtype=torch.uint8)
@@ -382,19 +383,21 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
     # View a as 2D for simplicity.
     orig_shape = a.shape
     # Unravel packed inputs from shared exponents.
-    a = a.view(-1, (group_size // 2) + 1)
+    packed_group_size = group_size // 2
+    a = a.view(-1, packed_group_size + 1)
     packed_input = a[:, :-1]
     shared_exp = a[:, -1:]
     # Find a shape that distributes work evenly over threads.
     # We do this by finding the power of two that is closest to
     # the sqrt of the number of elements.
     num_threads = int(2 ** round(math.log2(math.sqrt(packed_input.numel()))))
-    # Make sure that num_threads is a multiple of group_size.
-    num_threads = (num_threads // group_size) * group_size
-    if num_threads == 0:
-        num_threads = packed_input.numel() // group_size
-    packed_input = packed_input.reshape(num_threads, -1)
-    shared_exp = shared_exp.reshape(num_threads, -1)
+    # Make sure that the number of elements per row is a multiple of packed group_size.
+    K = packed_input.numel() // num_threads
+    K = (K // packed_group_size) * packed_group_size
+    if K == 0:
+        K = packed_group_size
+    packed_input = packed_input.reshape(-1, K)
+    shared_exp = shared_exp.reshape(-1, K // packed_group_size)
     M, K_2 = packed_input.shape
 
     # Use a lookup table to convert
diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
index a4999440f2..09e2ea1ead 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -46,11 +46,11 @@ def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
     # Convert max into an intger exponent.
     # Note this can be more efficient by just shifting and masking exp bits.
     # We can even use those directly.
-    shared_exp = torch.floor(torch.log2(shared_exp))
+    shared_exp = torch.ceil(torch.log2(shared_exp))
     # Offset exponent by largest exponent in target datatype.
     shared_exp = shared_exp - 2
     # Restrict to range expressible as int8.
-    shared_exp = torch.clamp(shared_exp, min=-127, max=127)
+    shared_exp = torch.clamp(shared_exp, min=-127, max=125)
     # Convert exponent to scale and apply to input.
     # Need to do this calculation on cpu for accuracy.
    _shared_exp = shared_exp.cpu()
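
For intuition, the effect of the rounding change can be reproduced outside the kernel. The sketch below is not part of the patch; it mirrors the shared-exponent arithmetic of `py_quantize_mx4` above, and `FP4_MAX` and `shared_scale` are names introduced here purely for illustration. With floor rounding, a group whose max magnitude sits just above a power of two scales to a value larger than fp4's maximum of 6.0 and must be clipped; with ceiling rounding the scaled max always fits.

```python
# Illustrative sketch (assumed standalone, not part of this patch): compare
# floor vs. ceiling rounding of the shared exponent for one MX4 group.
import torch

FP4_MAX = 6.0  # largest magnitude representable by an e2m1 (fp4) element


def shared_scale(group_max: torch.Tensor, rounding: str) -> torch.Tensor:
    """Per-group power-of-two scale using the chosen exponent rounding."""
    log2_max = torch.log2(group_max)
    exp = torch.ceil(log2_max) if rounding == "ceil" else torch.floor(log2_max)
    # Offset by the largest exponent in the target datatype and clamp,
    # following the same arithmetic as the reference quantizer above.
    exp = torch.clamp(exp - 2, min=-127, max=125)
    return torch.exp2(exp)


group_max = torch.tensor([7.0])  # absolute max of one group
for rounding in ("floor", "ceil"):
    scale = shared_scale(group_max, rounding)
    scaled = group_max / scale
    stored = torch.clamp(scaled, max=FP4_MAX)  # what fp4 can actually hold
    print(f"{rounding:5s} scale={scale.item():.1f} "
          f"scaled_max={scaled.item():.2f} stored={stored.item():.2f}")
# floor: scale=1.0, so 7.0 exceeds fp4's max of 6.0 and is clipped.
# ceil:  scale=2.0, so 3.5 fits and no mantissa information is clipped away.
```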