
Refactor MX4 Kernel to operate on flat tensors (pytorch#2836)
Summary:
Pull Request resolved: pytorch#2836

Rather than trying to reshape inputs into 2D matrices with each thread operating on one row, this refactor uses 1D inputs and has each thread operate on an offset into the flat array.

The main benefit is that this avoids ragged tensors when an input can't be divided into even-sized rows, which should make the kernel compatible with more shapes.
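
To make the scheme concrete, here is a minimal sketch of the chunking logic (illustrative helper, not part of the kernel): the flat input is split into per-thread chunks of K elements, and the last, possibly short, chunk is handled by masking rather than reshaping.

import math

def chunk_bounds(numel: int, K: int):
    # Each program id owns one contiguous span of K elements.
    for pid in range(math.ceil(numel / K)):
        start = pid * K
        end = min(start + K, numel)  # masking covers the ragged tail
        yield pid, start, end

# e.g. 10_000 elements with K=4096 -> (0, 0, 4096), (1, 4096, 8192), (2, 8192, 10000)
print(list(chunk_bounds(10_000, 4096)))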

Differential Revision: D59653809

Reviewed By: sryap
Josh Fromm authored and facebook-github-bot committed Jul 12, 2024
1 parent ceaa089 commit 46847b1
Showing 2 changed files with 85 additions and 85 deletions.
13 changes: 3 additions & 10 deletions fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -8,7 +8,6 @@
# pyre-strict

import logging
import math

import torch

@@ -45,15 +44,9 @@ def fp32_to_mx4(
output: MX4 tensor packed into int8 values with total elements (M / 2 + M / group_size)
"""
# Accelerated MX4 is only available on CUDA; if the input is on CPU, use the Python implementation.
# For CPU and triton, set the second dim to 2048 or the nearest power of 2.
dim = (
2048 if tensor.numel() >= 2048 else 2 ** (math.floor(math.log2(tensor.numel())))
)
input = (
tensor.view(-1)
if (tensor.is_cuda and not use_triton) or tensor.numel() % dim != 0
else tensor.view(-1, dim)
)
# Operate on flattened input.
input = tensor.flatten()

if not tensor.is_cuda:
return py_quantize_mx4(input, group_size)

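For example, an input whose element count is not divisible by a power-of-two row width previously hit the ragged-row path; with flattened inputs it quantizes directly. A quick sketch, assuming the fp32_to_mx4 signature shown above (3008 chosen so it is divisible by the group size but not by 2048):

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4

x = torch.randn(3008, device="cuda")
packed = fp32_to_mx4(x, group_size=32)
# Packed payload per the docstring: numel/2 mx4 bytes + numel/group_size exponents.
assert packed.numel() == 3008 // 2 + 3008 // 32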
157 changes: 82 additions & 75 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -15,15 +15,31 @@
from triton import Config # @manual


def prune_configs(configs, named_args, **kwargs):
"""Helper function to remove invalid configurations."""
group_size = kwargs["GROUP_SIZE"]
pruned_configs = []
for config in configs:
block_size = config.kwargs["BLOCK_SIZE"]
# Don't use block sizes that are smaller than the group size.
if group_size <= block_size:
pruned_configs.append(config)
# Return only the valid configurations.
return pruned_configs
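# A worked example of the pruning rule above (illustrative values): with
# GROUP_SIZE=2048, every block smaller than 2048 is dropped, since a block
# must cover at least one full exponent group:
#   candidates = [32, 64, 512, 1024, 2048, 4096, 8192]
#   surviving = [b for b in candidates if 2048 <= b]  # -> [2048, 4096, 8192]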


@triton.autotune(
configs=[
Config({"BLOCK_SIZE": 32}),
Config({"BLOCK_SIZE": 64}),
Config({"BLOCK_SIZE": 512}),
Config({"BLOCK_SIZE": 1024}),
Config({"BLOCK_SIZE": 2048}),
Config({"BLOCK_SIZE": 4096}),
Config({"BLOCK_SIZE": 8192}),
],
key=["K"],
prune_configs_by={"early_config_prune": prune_configs},
)
@triton.jit
def _kernel_quantize_mx4(
@@ -32,37 +48,30 @@ def _kernel_quantize_mx4(
out,
M,
K,
stride_am,
stride_ak,
stride_exp_m,
stride_exp_k,
stride_out_m,
stride_out_k,
GROUP_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
) -> None:
"""Quantize a float tensor into a packed MX4 tensor.
"""Quantize a 1D float tensor into a packed MX4 tensor.
Args:
A (Tensor): [M, K] float tensor to be quantized.
A (Tensor): [M] float tensor to be quantized.
shared_exp (Tensor): [M / group_size] output containing shared exponent.
out (Tensor): [M, K / 2] output containing packed mx4 values.
M (int): Number of rows.
K (int): Number of columns.
stride_am (int): Stride of m dimension of A.
stride_ak (int): Stride of k dimension of A.
stride_exp_m (int): Stride of shared exponent in m dimension.
stride_exp_k (int): Stride of shared exponent in k dimension.
stride_out_m (int): Stride of output in m dimension.
stride_out_k (int): Stride of output in k dimension.
out (Tensor): [M / 2] output containing packed mx4 values.
M (int): Total number of elements.
K (int): Number of elements to process in each thread.
GROUP_SIZE (int): Size of chunks that use the same shared exponent.
BLOCK_SIZE (int): Size of each block.
"""
# Get the current thread number.
pid = tl.program_id(0)
# Find starting offsets for this thread.
input_start = pid * K
packed_start = pid * K // 2
group_start = pid * K // GROUP_SIZE
# Initialize offset ranges used in the kernel.
k_offset = tl.arange(0, BLOCK_SIZE)
packed_offset = tl.arange(0, BLOCK_SIZE // 2)
group_offset = tl.arange(0, BLOCK_SIZE // GROUP_SIZE)
input_offset = tl.arange(0, BLOCK_SIZE) + input_start
packed_offset = tl.arange(0, BLOCK_SIZE // 2) + packed_start
group_offset = tl.arange(0, BLOCK_SIZE // GROUP_SIZE) + group_start
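# Worked example (illustrative numbers): with K=4096, GROUP_SIZE=32 and
# BLOCK_SIZE=1024, program pid=2 gets input_start=8192, packed_start=4096
# and group_start=256; each loop iteration below advances all three offset
# ranges by one block.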

# Define Constant Expressions.
FP32_EXP_MASK: tl.constexpr = 0x7F800000 # type: ignore[Incompatible variable type]
@@ -79,12 +88,13 @@ def _kernel_quantize_mx4(
OVERFLOW_THRESHOLD: tl.constexpr = 4 # type: ignore[Incompatible variable type]
FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126) # type: ignore[Incompatible variable type]

# First we need to compute shared exponent.
# Load and process blocks of values for this chunk.
for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
# Load a block of values.
a = tl.load(
A + pid * stride_am + k_offset * stride_ak,
mask=k_offset < K,
A + input_offset,
# Mask values out of range for both the main array and this chunk.
mask=(input_offset < M) & (input_offset < (K * (pid + 1))),
other=0,
)
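# e.g. with M=10_000 total elements and K=4096 per program, pid=2 loads only
# offsets in [8192, 10_000); the tail past M is masked and padded with zeros.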

@@ -117,9 +127,11 @@ def _kernel_quantize_mx4(
# We're done with group_exp now so we can write it out.
# We re-add fp32_exp_bias for compatibility with cuda dequant.
tl.store(
shared_exp + pid * stride_exp_m + stride_exp_k * group_offset,
shared_exp + group_offset,
(group_exp + FP32_EXP_BIAS).to(tl.int8),
mask=group_offset < K // GROUP_SIZE,
# Prevent writing outside this chunk or the main array.
mask=(group_offset < M // GROUP_SIZE)
& (group_offset < ((K // GROUP_SIZE) * (pid + 1))),
)

# Quantization step
@@ -188,13 +200,14 @@ def _kernel_quantize_mx4(

# Next step is packing: write out the packed mx4 values.
tl.store(
out + pid * stride_out_m + packed_offset * stride_out_k,
out + packed_offset,
packed_mx4,
mask=packed_offset < K // 2,
# Prevent writing outside this chunk or the main array.
mask=(packed_offset < M // 2) & (packed_offset < ((K // 2) * (pid + 1))),
)

# Update offsets so we work on the next block.
k_offset += BLOCK_SIZE
input_offset += BLOCK_SIZE
group_offset += BLOCK_SIZE // GROUP_SIZE
packed_offset += BLOCK_SIZE // 2

@@ -236,27 +249,26 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
# If K is less than group_size, we compute a single group per row.
if K == 0:
K = group_size
a = a.view(-1, K)
M, K = a.shape
# We want to divide the input into chunks of size K. If that can't be done
# evenly, it's ok for one chunk to be smaller.
M = int(math.ceil(a.numel() / K))
# Flatten input.
a = a.flatten()

# Create output tensors.
shared_exp = torch.empty([M, K // group_size], device=a.device, dtype=torch.uint8)
out = torch.empty([M, K // 2], device=a.device, dtype=torch.uint8)
shared_exp = torch.empty(
[a.numel() // group_size], device=a.device, dtype=torch.uint8
)
out = torch.empty([a.numel() // 2], device=a.device, dtype=torch.uint8)

# Invoke triton quantization kernel over rows.
grid = (M,)
_kernel_quantize_mx4[grid](
a,
shared_exp,
out,
M,
a.numel(),
K,
a.stride(0),
a.stride(1),
shared_exp.stride(0),
shared_exp.stride(1),
out.stride(0),
out.stride(1),
GROUP_SIZE=group_size,
)
# Ravel together output and shared exponent.
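# Size check for the raveled result, assuming the layout documented in
# fp32_to_mx4 (numel/2 packed bytes plus numel/group_size shared exponents):
#   M, group_size = 1_048_576, 32
#   M // 2 + M // group_size  # -> 557_056 total uint8 elements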
@@ -275,13 +287,16 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:

@triton.autotune(
configs=[
Config({"BLOCK_SIZE": 32}),
Config({"BLOCK_SIZE": 64}),
Config({"BLOCK_SIZE": 512}),
Config({"BLOCK_SIZE": 1024}),
Config({"BLOCK_SIZE": 2048}),
Config({"BLOCK_SIZE": 4096}),
Config({"BLOCK_SIZE": 8192}),
],
key=["K"],
prune_configs_by={"early_config_prune": prune_configs},
)
@triton.jit
def _kernel_dequantize_mx4(
@@ -291,35 +306,28 @@ def _kernel_dequantize_mx4(
out,
M,
K,
stride_am,
stride_ak,
stride_exp_m,
stride_exp_k,
stride_out_m,
stride_out_k,
GROUP_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
) -> None:
"""Dequantize a packed MX4 tensor and apply scaling.
Args:
A (Tensor): [M, K] MX4 tensor packed into int8.
A (Tensor): [M] MX4 tensor packed into int8.
shared_exp (Tensor): Int8 tensor representing group exponent.
mx4_lookup_table (Tensor): Map from mx4 integer value to floating point.
M (int): Number of rows.
K (int): Number of columns.
stride_am (int): Stride of m dimension of A.
stride_ak (int): Stride of k dimension of A.
stride_exp_m (int): Stride of m dimension of shared exponent tensor.
stride_exp_k (int): Stride of k dimension of shared exponent tensor.
stride_out_m (int): Stride of m dimension of output.
stride_out_k (int): Stride of k dimension of output.
M (int): Total number of elements in input.
K (int): Number of elements each thread should operate on.
GROUP_SIZE (int): Size of chunks that use the same shared exponent.
BLOCK_SIZE (int): Size of each block.
"""
# Get the current thread number.
pid = tl.program_id(0)
k_offset = tl.arange(0, BLOCK_SIZE)
output_offset = tl.arange(0, 2 * BLOCK_SIZE)
# Find the starting offsets for this thread.
input_start = pid * K
output_start = pid * K * 2
# Initialize offset ranges used in this thread.
input_offset = tl.arange(0, BLOCK_SIZE) + input_start
output_offset = tl.arange(0, 2 * BLOCK_SIZE) + output_start
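# Worked example (illustrative numbers): with K=2048 packed bytes per program
# and pid=1, input_start=2048 and output_start=4096; each packed byte expands
# into two mx4 values, so the output range is twice the input range.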

# Define constants.
MX4_BIT_MASK: tl.constexpr = 0xF # type: ignore[Incompatible variable type]
@@ -328,7 +336,10 @@ def _kernel_dequantize_mx4(
# Iterate over input tensor and unpack mx4 values.
for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
a = tl.load(
A + pid * stride_am + k_offset * stride_ak, mask=k_offset < K, other=0.0
A + input_offset,
# Mask values that are out of this chunk or the main array.
mask=(input_offset < M) & (input_offset < (K * (pid + 1))),
other=0.0,
)
# Extract high and low values from loaded mx4 tile.
low_mx4 = a & MX4_BIT_MASK
@@ -339,8 +350,8 @@ def _kernel_dequantize_mx4(
high_fp32 = tl.load(mx4_lookup_table + high_mx4)

# Get proper shared exponent and convert it to a float scale.
group_offset = (2 * k_offset) // GROUP_SIZE
exp = tl.load(shared_exp + pid * stride_exp_m + group_offset * stride_exp_k)
group_offset = (2 * input_offset) // GROUP_SIZE
exp = tl.load(shared_exp + group_offset)
# Remove fp32 exponent bias.
exp = exp.to(tl.uint8, bitcast=True) - FP32_EXP_BIAS

@@ -356,12 +367,13 @@ def _kernel_dequantize_mx4(

# Write final outputs.
tl.store(
out + pid * stride_out_m + output_offset * stride_out_k,
out + output_offset,
scaled_fp32,
mask=output_offset < 2 * K,
# Mask values that are out of this chunk or the main array.
mask=(output_offset < 2 * M) & (output_offset < ((2 * K) * (pid + 1))),
)

k_offset += BLOCK_SIZE
input_offset += BLOCK_SIZE
output_offset += 2 * BLOCK_SIZE


@@ -396,9 +408,11 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
K = (K // packed_group_size) * packed_group_size
if K == 0:
K = packed_group_size
packed_input = packed_input.reshape(-1, K)
shared_exp = shared_exp.reshape(-1, K // packed_group_size)
M, K_2 = packed_input.shape
# Try to divide the input evenly into chunks of size K, allowing the last chunk to be smaller.
M = int(math.ceil(packed_input.numel() / K))
# Flatten inputs.
packed_input = packed_input.flatten().contiguous()
shared_exp = shared_exp.flatten().contiguous()

# Use a lookup table to convert mx4 values to their float equivalents.
mx4_to_fp_values = torch.tensor(
@@ -408,23 +422,16 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
)

# Create output tensor.
out = torch.empty([M, 2 * K_2], device=a.device, dtype=torch.float)

out = torch.empty([2 * packed_input.numel()], device=a.device, dtype=torch.float)
# Invoke triton dequantization kernel over rows.
grid = (M,)
_kernel_dequantize_mx4[grid](
packed_input,
shared_exp,
mx4_to_fp_values,
out,
M,
K_2,
packed_input.stride(0),
packed_input.stride(1),
shared_exp.stride(0),
shared_exp.stride(1),
out.stride(0),
out.stride(1),
packed_input.numel(),
K,
GROUP_SIZE=group_size,
)

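Taken together, the refactored entry points round-trip flat tensors of group-aligned length. A quick end-to-end sketch, assuming mx4_to_fp32 is the matching dequantization helper in quantize_utils (loss is expected since MX4 is a 4-bit format):

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32

x = torch.randn(3008, device="cuda")
packed = fp32_to_mx4(x, group_size=32)
y = mx4_to_fp32(packed, group_size=32)
# Expect coarse agreement only; MX4 keeps ~4 bits per value plus a shared exponent.
print(x.shape, y.shape, (x - y).abs().mean())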
