From 46847b1845fe24ba83eb75faf9487aaf8183404a Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 12 Jul 2024 13:04:20 -0700
Subject: [PATCH] Refactor MX4 Kernel to operate on flat tensors (#2836)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2836

Rather than trying to reshape inputs into 2D matrices with each thread
operating on one row, this refactor uses 1D inputs and has each thread
operate on an offset into the array. The main benefit of this is that it
avoids ragged tensors where we can't divide an input into evenly sized
rows. This should enable us to be compatible with more shapes.

Differential Revision: D59653809

Reviewed By: sryap
---
 fbgemm_gpu/fbgemm_gpu/quantize_utils.py  |  13 +-
 fbgemm_gpu/fbgemm_gpu/triton/quantize.py | 157 ++++++++++++-----------
 2 files changed, 85 insertions(+), 85 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
index 300ab8829d..5e5c9505e4 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -8,7 +8,6 @@
 # pyre-strict

 import logging
-import math

 import torch

@@ -45,15 +44,9 @@ def fp32_to_mx4(
         output: MX4 tensor packed into int8 values with total elements (M / 2 + M / groupsize)
     """
     # Accelerated MX4 is only available on cuda, if input is on cpu, use python.
-    # For CPU and triton, set the second dim to 2048 or the nearest power of 2.
-    dim = (
-        2048 if tensor.numel() >= 2048 else 2 ** (math.floor(math.log2(tensor.numel())))
-    )
-    input = (
-        tensor.view(-1)
-        if (tensor.is_cuda and not use_triton) or tensor.numel() % dim != 0
-        else tensor.view(-1, dim)
-    )
+    # Operate on flattened input.
+    input = tensor.flatten()
+
     if not tensor.is_cuda:
         return py_quantize_mx4(input, group_size)

diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
index d8adbb4a52..c03d34a4f6 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -15,8 +15,23 @@
 from triton import Config  # @manual


+def prune_configs(configs, named_args, **kwargs):
+    """Helper function to remove invalid configurations."""
+    group_size = kwargs["GROUP_SIZE"]
+    pruned_configs = []
+    for config in configs:
+        block_size = config.kwargs["BLOCK_SIZE"]
+        # Don't use block sizes that are smaller than the group size.
+        if group_size <= block_size:
+            pruned_configs.append(config)
+    # Return only the valid configurations.
+    return pruned_configs
+
+
 @triton.autotune(
     configs=[
+        Config({"BLOCK_SIZE": 32}),
+        Config({"BLOCK_SIZE": 64}),
         Config({"BLOCK_SIZE": 512}),
         Config({"BLOCK_SIZE": 1024}),
         Config({"BLOCK_SIZE": 2048}),
@@ -24,6 +39,7 @@
         Config({"BLOCK_SIZE": 8192}),
     ],
     key=["K"],
+    prune_configs_by={"early_config_prune": prune_configs},
 )
 @triton.jit
 def _kernel_quantize_mx4(
@@ -32,37 +48,30 @@ def _kernel_quantize_mx4(
     out,
     M,
     K,
-    stride_am,
-    stride_ak,
-    stride_exp_m,
-    stride_exp_k,
-    stride_out_m,
-    stride_out_k,
     GROUP_SIZE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ) -> None:
-    """Quantize a float tensor into a packed MX4 tensor.
+    """Quantize a 1D float tensor into a packed MX4 tensor.

     Args:
-        A (Tensor): [M, K] float tensor to be quantized.
+        A (Tensor): [M] float tensor to be quantized.
         shared_exp (Tensor): [M / group_size] output containing shared exponent.
-        out (Tensor): [M, K / 2] output containing packed mx4 values.
-        M (int): Number of rows.
-        K (int): Number of columns.
-        stride_am (int): Stride of m dimension of A.
-        stride_ak (int): Stride of k dimension of A.
-        stride_exp_m (int): Stride of shared exponent in m dimension.
-        stride_exp_k (int): Stride of shared exponent in k dimension.
-        stride_out_m (int): Stride of output in m dimension.
-        stride_out_k (int): Stride of output in k dimension.
+        out (Tensor): [M / 2] output containing packed mx4 values.
+        M (int): Total number of elements.
+        K (int): Number of elements to process in each thread.
         GROUP_SIZE (int): Size of chunks that use the same shared exponent.
         BLOCK_SIZE (int): Size of each block.
     """
+    # Get the current thread number.
     pid = tl.program_id(0)
+    # Find starting offsets for this thread.
+    input_start = pid * K
+    packed_start = pid * K // 2
+    group_start = pid * K // GROUP_SIZE
     # Initiate offset ranges used in kernel.
-    k_offset = tl.arange(0, BLOCK_SIZE)
-    packed_offset = tl.arange(0, BLOCK_SIZE // 2)
-    group_offset = tl.arange(0, BLOCK_SIZE // GROUP_SIZE)
+    input_offset = tl.arange(0, BLOCK_SIZE) + input_start
+    packed_offset = tl.arange(0, BLOCK_SIZE // 2) + packed_start
+    group_offset = tl.arange(0, BLOCK_SIZE // GROUP_SIZE) + group_start

     # Define Constant Expressions.
     FP32_EXP_MASK: tl.constexpr = 0x7F800000  # type: ignore[Incompatible variable type]
@@ -79,12 +88,13 @@ def _kernel_quantize_mx4(
     OVERFLOW_THRESHOLD: tl.constexpr = 4  # type: ignore[Incompatible variable type]
     FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126)  # type: ignore[Incompatible variable type]

-    # First we need to compute shared exponent.
+    # Load and process blocks of values for this chunk.
     for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
         # Load a block of values.
         a = tl.load(
-            A + pid * stride_am + k_offset * stride_ak,
-            mask=k_offset < K,
+            A + input_offset,
+            # Mask values out of range for both the main array and this chunk.
+            mask=(input_offset < M) & (input_offset < (K * (pid + 1))),
             other=0,
         )

@@ -117,9 +127,11 @@ def _kernel_quantize_mx4(
         # We're done with group_exp now so we can write it out.
         # We readd fp32_exp_bias for compatibility with cuda dequant.
         tl.store(
-            shared_exp + pid * stride_exp_m + stride_exp_k * group_offset,
+            shared_exp + group_offset,
             (group_exp + FP32_EXP_BIAS).to(tl.int8),
-            mask=group_offset < K // GROUP_SIZE,
+            # Prevent writing outside this chunk or the main array.
+            mask=(group_offset < M // GROUP_SIZE)
+            & (group_offset < ((K // GROUP_SIZE) * (pid + 1))),
         )

         # Quantization step
@@ -188,13 +200,14 @@ def _kernel_quantize_mx4(

         # Next step is packing, lets write this out to check how it looks.
         tl.store(
-            out + pid * stride_out_m + packed_offset * stride_out_k,
+            out + packed_offset,
             packed_mx4,
-            mask=packed_offset < K // 2,
+            # Prevent writing outside this chunk or the main array.
+            mask=(packed_offset < M // 2) & (packed_offset < ((K // 2) * (pid + 1))),
         )

         # Update offsets so we work on the next block.
-        k_offset += BLOCK_SIZE
+        input_offset += BLOCK_SIZE
         group_offset += BLOCK_SIZE // GROUP_SIZE
         packed_offset += BLOCK_SIZE // 2

@@ -236,12 +249,17 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
     # If K is less than group_size, we compute a single group per row.
     if K == 0:
         K = group_size
-    a = a.view(-1, K)
-    M, K = a.shape
+    # We want to divide the input into chunks of size K. If that can't be done
+    # evenly, it's ok for one chunk to be smaller.
+    M = int(math.ceil(a.numel() / K))
+    # Flatten input.
+    a = a.flatten()

     # Create output tensors.
-    shared_exp = torch.empty([M, K // group_size], device=a.device, dtype=torch.uint8)
-    out = torch.empty([M, K // 2], device=a.device, dtype=torch.uint8)
+    shared_exp = torch.empty(
+        [a.numel() // group_size], device=a.device, dtype=torch.uint8
+    )
+    out = torch.empty([a.numel() // 2], device=a.device, dtype=torch.uint8)

     # Invoke triton quantization kernel over rows.
     grid = (M,)
@@ -249,14 +267,8 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
         a,
         shared_exp,
         out,
-        M,
+        a.numel(),
         K,
-        a.stride(0),
-        a.stride(1),
-        shared_exp.stride(0),
-        shared_exp.stride(1),
-        out.stride(0),
-        out.stride(1),
         GROUP_SIZE=group_size,
     )
     # Ravel together output and shared exponent.
@@ -275,6 +287,8 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:

 @triton.autotune(
     configs=[
+        Config({"BLOCK_SIZE": 32}),
+        Config({"BLOCK_SIZE": 64}),
         Config({"BLOCK_SIZE": 512}),
         Config({"BLOCK_SIZE": 1024}),
         Config({"BLOCK_SIZE": 2048}),
@@ -282,6 +296,7 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
         Config({"BLOCK_SIZE": 8192}),
     ],
     key=["K"],
+    prune_configs_by={"early_config_prune": prune_configs},
 )
 @triton.jit
 def _kernel_dequantize_mx4(
@@ -291,35 +306,28 @@ def _kernel_dequantize_mx4(
     out,
     M,
     K,
-    stride_am,
-    stride_ak,
-    stride_exp_m,
-    stride_exp_k,
-    stride_out_m,
-    stride_out_k,
     GROUP_SIZE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ) -> None:
     """Dequantize a packed MX4 tensor and apply scaling.

     Args:
-        A (Tensor): [M, K] MX4 tensor packed into int8.
+        A (Tensor): [M] MX4 tensor packed into int8.
         shared_exp (Tensor): Int8 tensor representing group exponent.
         mx4_lookup_table (Tensor): Map from mx4 integer value to floating point.
-        M (int): Number of rows.
-        K (int): Number of columns.
-        stride_am (int): Stride of m dimension of A.
-        stride_ak (int): Stride of k dimension of A.
-        stride_exp_m (int): Stride of m dimension of shared exponent tensor.
-        stride_exp_k (int): Stride of k dimension of shared exponent tensor.
-        stride_out_m (int): Stride of m dimension of output.
-        stride_out_k (int): Stride of k dimension of output.
+        M (int): Total number of elements in input.
+        K (int): Number of elements each thread should operate on.
         GROUP_SIZE (int): Size of chunks that use the same shared exponent.
         BLOCK_SIZE (int): Size of each block.
     """
+    # Get the current thread number.
     pid = tl.program_id(0)
-    k_offset = tl.arange(0, BLOCK_SIZE)
-    output_offset = tl.arange(0, 2 * BLOCK_SIZE)
+    # Find the starting offsets for this thread.
+    input_start = pid * K
+    output_start = pid * K * 2
+    # Initiate offset ranges used in this thread.
+    input_offset = tl.arange(0, BLOCK_SIZE) + input_start
+    output_offset = tl.arange(0, 2 * BLOCK_SIZE) + output_start

     # Define constants.
     MX4_BIT_MASK: tl.constexpr = 0xF  # type: ignore[Incompatible variable type]
@@ -328,7 +336,10 @@ def _kernel_dequantize_mx4(
     # Iterate over input tensor and unpack mx4 values.
     for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
         a = tl.load(
-            A + pid * stride_am + k_offset * stride_ak, mask=k_offset < K, other=0.0
+            A + input_offset,
+            # Mask values that are out of this chunk or the main array.
+            mask=(input_offset < M) & (input_offset < (K * (pid + 1))),
+            other=0.0,
         )
         # Extract high and low values from loaded mx4 tile.
         low_mx4 = a & MX4_BIT_MASK
         high_mx4 = (a >> 4) & MX4_BIT_MASK

         low_fp32 = tl.load(mx4_lookup_table + low_mx4)
         high_fp32 = tl.load(mx4_lookup_table + high_mx4)

         # Get proper shared exponent and convert it to a float scale.
-        group_offset = (2 * k_offset) // GROUP_SIZE
-        exp = tl.load(shared_exp + pid * stride_exp_m + group_offset * stride_exp_k)
+        group_offset = (2 * input_offset) // GROUP_SIZE
+        exp = tl.load(shared_exp + group_offset)

         # Remove fp32 exponent bias.
         exp = exp.to(tl.uint8, bitcast=True) - FP32_EXP_BIAS
@@ -356,12 +367,13 @@ def _kernel_dequantize_mx4(

         # Write final outputs.
         tl.store(
-            out + pid * stride_out_m + output_offset * stride_out_k,
+            out + output_offset,
             scaled_fp32,
-            mask=output_offset < 2 * K,
+            # Mask values that are out of this chunk or the main array.
+            mask=(output_offset < 2 * M) & (output_offset < ((2 * K) * (pid + 1))),
         )

-        k_offset += BLOCK_SIZE
+        input_offset += BLOCK_SIZE
         output_offset += 2 * BLOCK_SIZE


@@ -396,9 +408,11 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
     K = (K // packed_group_size) * packed_group_size
     if K == 0:
         K = packed_group_size
-    packed_input = packed_input.reshape(-1, K)
-    shared_exp = shared_exp.reshape(-1, K // packed_group_size)
-    M, K_2 = packed_input.shape
+    # Try to evenly divide input into chunks of size K, allow last chunk to be smaller.
+    M = int(math.ceil(packed_input.numel() / K))
+    # Flatten inputs.
+    packed_input = packed_input.flatten().contiguous()
+    shared_exp = shared_exp.flatten().contiguous()

     # Use a lookup table to convert
     mx4_to_fp_values = torch.tensor(
@@ -408,8 +422,7 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
     )

     # Create output tensor.
-    out = torch.empty([M, 2 * K_2], device=a.device, dtype=torch.float)
-
+    out = torch.empty([2 * packed_input.numel()], device=a.device, dtype=torch.float)
     # Invoke triton dequantization kernel over rows.
     grid = (M,)
     _kernel_dequantize_mx4[grid](
         packed_input,
         shared_exp,
         mx4_to_fp_values,
         out,
-        M,
-        K_2,
-        packed_input.stride(0),
-        packed_input.stride(1),
-        shared_exp.stride(0),
-        shared_exp.stride(1),
-        out.stride(0),
-        out.stride(1),
+        packed_input.numel(),
+        K,
         GROUP_SIZE=group_size,
     )
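
For reference, the chunking scheme described in the summary can be illustrated outside Triton. The sketch below is not part of the patch; the helper name thread_offsets and the sample sizes are made up for illustration, and numel plays the role of the kernel argument M (total element count), while the host-side M mirrors the launch grid computed in triton_quantize_mx4.

import math

def thread_offsets(pid: int, numel: int, K: int, group_size: int):
    # Each program id (thread) owns a contiguous chunk of K input elements.
    input_start = pid * K
    packed_start = pid * K // 2          # two fp4 values per packed int8 byte
    group_start = pid * K // group_size  # one shared exponent per group
    # Loads and stores are masked against both the chunk end and the array end,
    # matching the mask expressions in _kernel_quantize_mx4.
    input_end = min(numel, K * (pid + 1))
    packed_end = min(numel // 2, (K // 2) * (pid + 1))
    group_end = min(numel // group_size, (K // group_size) * (pid + 1))
    return (input_start, input_end), (packed_start, packed_end), (group_start, group_end)

# One program per chunk; the last chunk may hold fewer than K elements.
numel, K, group_size = 10_000, 2_048, 32
M = math.ceil(numel / K)
for pid in range(M):
    print(pid, thread_offsets(pid, numel, K, group_size))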