Use better exponent rounding in Triton MX4 quantize kernel (pytorch#2816)

Summary:
X-link: facebookresearch/FBGEMM#20

Pull Request resolved: pytorch#2816

As noted in [this doc](https://docs.google.com/document/d/156Du0hBRH6umG_i-OrYC574XhpQMUU5SJYG0RTS2tTg/edit#heading=h.akfcp7xpg8cr), using ceiling rounding for the scale calculation does a better job of preserving mantissa bits that floor rounding can truncate. This diff switches Triton's floor rounding to ceiling rounding.
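For a rough sense of the two rounding modes, here is a standalone Python sketch (not FBGEMM code; shared_exp is a hypothetical helper, and the -2 offset for e2m1 and the clamp to [-127, 125] are taken from the diff below):

import math

def shared_exp(group_max: float, use_ceil: bool) -> int:
    # Shared exponent for one MX4 group: round log2 of the group max, then
    # subtract the largest e2m1 exponent (2) and clamp to the kernel's range.
    rounder = math.ceil if use_ceil else math.floor
    exp = rounder(math.log2(group_max)) - 2
    return max(-127, min(exp, 125))

# A group max of 5.0 gives exponent 0 (scale 1.0) with floor rounding
# but exponent 1 (scale 2.0) with ceiling rounding.
print(shared_exp(5.0, use_ceil=False), shared_exp(5.0, use_ceil=True))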

Note that mx4_test currently does not pass, as the CUDA kernel now behaves differently from the Triton kernel. Once we rebase this diff onto a similar change to the CUDA kernel, we should see exactly matching outputs again.

Differential Revision: D59527463

Reviewed By: jianyuh
Josh Fromm authored and facebook-github-bot committed Jul 12, 2024
1 parent 88349ff commit ceaa089
Showing 3 changed files with 31 additions and 27 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
@@ -7,6 +7,7 @@
# pyre-strict

import torch

from fbgemm_gpu.quantize.quantize_ops import dequantize_mx, quantize_mx # noqa F401


53 changes: 28 additions & 25 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -77,14 +77,15 @@ def _kernel_quantize_mx4(
MAX_FP32_MANTISSA_BITS: tl.constexpr = 24 # type: ignore[Incompatible variable type]
IMPLIED_1_BIT: tl.constexpr = 1 << 23 # type: ignore[Incompatible variable type]
OVERFLOW_THRESHOLD: tl.constexpr = 4 # type: ignore[Incompatible variable type]
FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126) # type: ignore[Incompatible variable type]

# First we need to compute shared exponent.
for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
# Load a block of values.
a = tl.load(
A + pid * stride_am + k_offset * stride_ak,
mask=k_offset < K,
other=-float("inf"),
other=0,
)

# Scaling step
@@ -94,13 +95,15 @@
a_groups = tl.reshape(a, [BLOCK_SIZE // GROUP_SIZE, GROUP_SIZE])
# Compute the shared exponent of each group.
group_max = tl.max(tl.abs(a_groups), axis=1)
# Convert max to exponent via bit operations.
group_exp = group_max.to(tl.int32, bitcast=True) & FP32_EXP_MASK
group_exp = group_exp >> FP32_EXP_OFFSET
# Prevent infinite values in log.
group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
# Convert max to exponent via direct log computation and ceiling
# rounding to minimize errors.
group_exp = tl.ceil(tl.log2(group_max))
# Subtract largest exponent in target datatype and remove bias.
group_exp = group_exp - 2 - FP32_EXP_BIAS
# Clamp to valid int8 range.
group_exp = tl.maximum(group_exp, -127)
group_exp = group_exp - 2
# Make sure exponent is in valid range.
group_exp = tl.clamp(group_exp, -127, 125)

# Next we scale A in preparation for quantization.
scale = tl.exp2(group_exp.to(tl.float64)).to(tl.float32)
@@ -113,10 +116,9 @@

# We're done with group_exp now so we can write it out.
# We readd fp32_exp_bias for compatibility with cuda dequant.
group_exp = group_exp.to(tl.int8)
tl.store(
shared_exp + pid * stride_exp_m + stride_exp_k * group_offset,
group_exp + FP32_EXP_BIAS,
(group_exp + FP32_EXP_BIAS).to(tl.int8),
mask=group_offset < K // GROUP_SIZE,
)

@@ -228,16 +230,15 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
# We do this by finding the power of two that is closest to
# the sqrt of the number of elements.
num_threads = int(2 ** round(math.log2(math.sqrt(a.numel()))))
# Make sure that num_threads is a multiple of group_size.
num_threads = (num_threads // group_size) * group_size
if num_threads == 0:
num_threads = a.numel() // group_size
a = a.view(num_threads, -1)
M, K = a.shape
# Make sure that the number of elements per row is a multiple of group_size.
K = a.numel() // num_threads
K = (K // group_size) * group_size
# If K is less than group_size, we compute a single group per row.
if K < group_size:
a = a.view(-1, group_size)
M, K = a.shape
if K == 0:
K = group_size
a = a.view(-1, K)
M, K = a.shape

# Create output tensors.
shared_exp = torch.empty([M, K // group_size], device=a.device, dtype=torch.uint8)
out = torch.empty([M, K // 2], device=a.device, dtype=torch.uint8)
@@ -382,19 +383,21 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
# View a as 2D for simplicity.
orig_shape = a.shape
# Unravel packed inputs from shared exponents.
a = a.view(-1, (group_size // 2) + 1)
packed_group_size = group_size // 2
a = a.view(-1, packed_group_size + 1)
packed_input = a[:, :-1]
shared_exp = a[:, -1:]
# Find a shape that distributes work evenly over threads.
# We do this by finding the power of two that is closest to
# the sqrt of the number of elements.
num_threads = int(2 ** round(math.log2(math.sqrt(packed_input.numel()))))
# Make sure that num_threads is a multiple of group_size.
num_threads = (num_threads // group_size) * group_size
if num_threads == 0:
num_threads = packed_input.numel() // group_size
packed_input = packed_input.reshape(num_threads, -1)
shared_exp = shared_exp.reshape(num_threads, -1)
# Make sure that the number of elements per row is a multiple of packed group_size.
K = packed_input.numel() // num_threads
K = (K // packed_group_size) * packed_group_size
if K == 0:
K = packed_group_size
packed_input = packed_input.reshape(-1, K)
shared_exp = shared_exp.reshape(-1, K // packed_group_size)
M, K_2 = packed_input.shape

# Use a lookup table to convert
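The triton_quantize_mx4 hunk above also reworks how the input is laid out over threads: rather than forcing num_threads to be a multiple of group_size, it makes each row a multiple of group_size so no group straddles a row boundary. A minimal sketch of the new logic (reshape_for_quantize is a hypothetical helper; it assumes a.numel() divides evenly by the chosen row width, which the sketch does not check):

import math
import torch

def reshape_for_quantize(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    # Pick a thread count near sqrt(numel), then round the per-row element
    # count down to a multiple of group_size.
    num_threads = int(2 ** round(math.log2(math.sqrt(a.numel()))))
    K = a.numel() // num_threads
    K = (K // group_size) * group_size
    if K == 0:
        K = group_size
    return a.view(-1, K)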
4 changes: 2 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -46,11 +46,11 @@ def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
# Convert max into an integer exponent.
# Note this can be more efficient by just shifting and masking exp bits.
# We can even use those directly.
shared_exp = torch.floor(torch.log2(shared_exp))
shared_exp = torch.ceil(torch.log2(shared_exp))
# Offset exponent by largest exponent in target datatype.
shared_exp = shared_exp - 2
# Restrict to range expressible as int8.
shared_exp = torch.clamp(shared_exp, min=-127, max=127)
shared_exp = torch.clamp(shared_exp, min=-127, max=125)
# Convert exponent to scale and apply to input.
# Need to do this calculation on cpu for accuracy.
_shared_exp = shared_exp.cpu()
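Putting the pieces together, here is a self-contained PyTorch sketch of the new rounding path for a single group (not the library API; the log2(0) guard, e2m1 offset, clamp range, fp64 scale computation, and bias re-add mirror the hunks above, and FP32_EXP_BIAS is assumed to be the standard 127):

import torch

group = torch.tensor([0.07, -1.3, 2.5, 0.9])
group_max = torch.clamp(group.abs().max(), min=2.0 ** -126)  # avoid log2(0)
shared_exp = torch.ceil(torch.log2(group_max)) - 2  # ceiling rounding, e2m1 offset
shared_exp = torch.clamp(shared_exp, min=-127, max=125)  # keep exponent in valid range
scale = torch.exp2(shared_exp.double()).float()  # scale computed in fp64, as in the kernel
scaled = group / scale  # values that would then be cast to fp4 (e2m1)
stored_exp = (shared_exp + 127).to(torch.uint8)  # re-add FP32_EXP_BIAS before storing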
