Triton MX4 Quantize Rounding Mode Support.
Summary: This diff adds a `rounding_mode` argument to triton quantize. We support (almost) all of the rounding modes described in the [best practices doc](https://docs.google.com/document/d/156Du0hBRH6umG_i-OrYC574XhpQMUU5SJYG0RTS2tTg/edit#heading=h.akfcp7xpg8cr). Stochastic rounding is coming soon.
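A minimal usage sketch of the new argument (illustrative only; assumes a CUDA device and an fbgemm_gpu build that includes this diff):

```python
import torch

from fbgemm_gpu.triton.quantize import RoundingMode, triton_quantize_mx4

a = torch.randn(4096, device="cuda")
# rounding_mode accepts either the enum or its plain integer value.
packed = triton_quantize_mx4(a, group_size=32, rounding_mode=RoundingMode.even)
```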

Differential Revision: D59562029
jwfromm authored and facebook-github-bot committed Jul 10, 2024
1 parent f228611 commit a3684eb
Showing 2 changed files with 174 additions and 9 deletions.
116 changes: 112 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -7,6 +7,8 @@

# pyre-unsafe
import math
from enum import IntEnum
from typing import Union

import torch
import triton # @manual
@@ -15,6 +17,82 @@
from triton import Config # @manual


class RoundingMode(IntEnum):
nearest = 0
floor = 1
even = 2
stochastic = 3
ceil = 4


@triton.jit
def _floor_log2(x):
"""Helper function to efficiently compute floor(log2(x))
Args:
x (Tensor): FP32 Input tensor to operate on.
Returns:
Tensor: Floor of log2(x).
"""
# Helpful bit constants.
FP32_EXP_MASK: tl.constexpr = 0x7F800000 # type: ignore[Incompatible variable type]
FP32_EXP_OFFSET: tl.constexpr = 23 # type: ignore[Incompatible variable type]
FP32_EXP_BIAS: tl.constexpr = 127 # type: ignore[Incompatible variable type]

# View x as an integer and extract its exponent.
x = x.to(tl.int32, bitcast=True) & FP32_EXP_MASK
# Shift exponent down to bottom bits.
x = x >> FP32_EXP_OFFSET
# Remove FP32 exponent bias and return.
return (x - FP32_EXP_BIAS).to(tl.float32)
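For intuition, here is a plain-Python sketch of the same bit trick (not part of the diff; valid for positive normal FP32 values):

import math
import struct

def floor_log2_py(x: float) -> int:
    # Reinterpret the FP32 bits, mask out the 8 exponent bits,
    # shift them down, and subtract the IEEE-754 bias of 127.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return ((bits & 0x7F800000) >> 23) - 127

assert floor_log2_py(10.0) == math.floor(math.log2(10.0))  # both equal 3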


@triton.jit
def _compute_exp(
group_max,
rounding_mode,
noise,
):
"""Compute shared exponent of group using specified rounding mode.
Args:
group_max (Tensor): Group of values to compute exponent of.
rounding_mode (int or RoundingMode): Which rounding mode to use.
noise (float): Random noise between 0 and 1 for stochastic rounding.
Returns:
Tensor: Shared exponent of group.
"""
# TODO(jwfromm) Enable hip libdevice support as well.
# Nearest rounding mode.
if rounding_mode == 0:
return tl.extra.cuda.libdevice.round(tl.log2(group_max))
# Floor rounding mode. This can be done with fast bit ops.
if rounding_mode == 1:
return _floor_log2(group_max)
# Even pre-rounding mode.
elif rounding_mode == 2:
# First round to nearest even integer.
group_max = tl.extra.cuda.libdevice.rint(group_max)
# Then perform floor rounding of log.
return _floor_log2(group_max)
# Stochastic rounding mode.
elif rounding_mode == 3:
# Compute probability of rounding up or down.
group_floor = tl.floor(group_max)
# Probability of rounding is based on distance from floor or ceil.
prob = group_max - group_floor
# Compare probability to random noise for this group.
# If prob is larger than the random value, round up,
# otherwise round down.
group_max = tl.where(prob > noise, tl.ceil(group_max), group_floor)
# Now compute log and truncate.
return _floor_log2(group_max)
# Ceil rounding mode.
else:
return tl.ceil(tl.log2(group_max))
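As an illustrative worked example (not part of the diff), take group_max = 5.5, so log2(5.5) ≈ 2.46:

# nearest:    round(2.46)                        -> 2
# floor:      _floor_log2(5.5)                   -> 2
# even:       rint(5.5) = 6, then _floor_log2(6) -> 2
# stochastic: 5.5 rounds to 5 or 6 depending on noise; _floor_log2 of either -> 2
# ceil:       ceil(2.46)                         -> 3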


@triton.autotune(
configs=[
Config({"BLOCK_SIZE": 512}),
@@ -38,6 +116,8 @@ def _kernel_quantize_mx4(
stride_exp_k,
stride_out_m,
stride_out_k,
noise,
ROUNDING_MODE: tl.constexpr,
GROUP_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
) -> None:
@@ -55,6 +135,7 @@ def _kernel_quantize_mx4(
stride_exp_k (int): Stride of shared exponent in k dimension.
stride_out_m (int): Stride of output in m dimension.
stride_out_k (int): Stride of output in k dimension.
noise (Tensor): Random values in [0, 1) used when ROUNDING_MODE is stochastic.
ROUNDING_MODE (int): Which rounding method to use when calculating shared exponent.
GROUP_SIZE (int): Size of chunks that use the same shared exponent.
BLOCK_SIZE (int): Size of each block.
"""
@@ -97,9 +178,17 @@ def _kernel_quantize_mx4(
group_max = tl.max(tl.abs(a_groups), axis=1)
# Prevent infinite values in log.
group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
# Convert max to exponent via direct log computation and ceiling
# rounding to minimize errors.
group_exp = tl.ceil(tl.log2(group_max))
# Load relevant noise values if doing stochastic rounding.
if ROUNDING_MODE == 3:
group_noise = tl.load(
noise + pid * stride_exp_m + stride_exp_k * group_offset,
mask=group_offset < K // GROUP_SIZE,
other=0,
)
else:
group_noise = None
# Compute shared exponent using specified rounding mode.
group_exp = _compute_exp(group_max, ROUNDING_MODE, group_noise)
# Subtract largest exponent in target datatype and remove bias.
group_exp = group_exp - 2
# Make sure exponent is in valid range.
@@ -199,13 +288,19 @@ def _kernel_quantize_mx4(
packed_offset += BLOCK_SIZE // 2


def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
def triton_quantize_mx4(
a: torch.Tensor,
group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.floor,
) -> torch.Tensor:
"""
Quantize a tensor to mx4 format using efficient triton kernels.
Args:
a (Tensor): [M] higher precision input tensor.
group_size (int): Size of chunks that will use the same shared exponent.
rounding_mode (Union[RoundingMode, int]): Which type of rounding to use
when calculating shared exponent. Defaults to floor rounding.
Returns:
torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into int8
@@ -245,6 +340,17 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
shared_exp = torch.empty([M, K // group_size], device=a.device, dtype=torch.uint8)
out = torch.empty([M, K // 2], device=a.device, dtype=torch.uint8)

# If using stochastic rounding, create random noise for each group.
if rounding_mode == RoundingMode.stochastic:
# Each group will need a seed.
noise = torch.rand(
size=(M, K // group_size),
dtype=torch.float32,
device=a.device,
)
else:
noise = None

# Invoke triton quantization kernel over rows.
grid = (M,)
_kernel_quantize_mx4[grid](
@@ -259,6 +365,8 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
shared_exp.stride(1),
out.stride(0),
out.stride(1),
noise=noise,
ROUNDING_MODE=rounding_mode,
GROUP_SIZE=group_size,
)
# Ravel together output and shared exponent.
67 changes: 62 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -6,16 +6,75 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from enum import IntEnum
from typing import Union

import torch


def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
class RoundingMode(IntEnum):
nearest = 0
floor = 1
even = 2
stochastic = 3
ceil = 4


def _compute_exp(
group_max,
rounding_mode,
):
"""Compute shared exponent of group using specified rounding mode.
Args:
group_max (Tensor): Group of values to compute exponent of.
rounding_mode (int or RoundingMode): Which rounding mode to use.
Returns:
Tensor: Shared exponent of group.
"""
# Nearest rounding mode.
if rounding_mode == 0:
return torch.round(torch.log2(group_max))
# Floor rounding mode.
if rounding_mode == 1:
return torch.floor(torch.log2(group_max))
# Even pre-rounding mode.
elif rounding_mode == 2:
# First round to nearest even integer.
group_max = torch.round(group_max)
# Then perform floor rounding of log.
return torch.floor(torch.log2(group_max))
# Stochastic rounding mode.
elif rounding_mode == 3:
# Create random noise.
noise = torch.rand_like(group_max)
# Compute probability of rounding up or down.
group_floor = torch.floor(group_max)
# Probability of rounding is based on distance from floor or ceil.
prob = group_max - group_floor
# Compare probability to random noise for this group.
# If prob is larger than the random value, round up,
# otherwise round down.
group_max = torch.where(prob > noise, torch.ceil(group_max), group_floor)
# Now compute log and truncate.
return torch.floor(torch.log2(group_max))
# Ceil rounding mode.
else:
return torch.ceil(torch.log2(group_max))
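A quick sanity sketch of the reference rounding modes (assumed usage, not part of the diff):

import torch

gm = torch.tensor([5.5])  # log2(5.5) ~= 2.46
assert _compute_exp(gm, RoundingMode.floor).item() == 2.0  # floor(log2(5.5))
assert _compute_exp(gm, RoundingMode.even).item() == 2.0   # floor(log2(rint(5.5)))
assert _compute_exp(gm, RoundingMode.ceil).item() == 3.0   # ceil(log2(5.5))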


def py_quantize_mx4(
a: torch.Tensor,
group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.floor,
) -> torch.Tensor:
"""
Quantize a tensor to mx4 format.
Args:
a (Tensor): [M] higher precision input tensor.
group_size (int): Size of chunks that will use the same shared exponent.
rounding_mode (int or RoundingMode): Which type of rounding to use when
calculating shared exponent.
Returns:
torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into int8
@@ -43,10 +102,8 @@ def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
# Replace zero values with the minimum expressible normal value.
FP32_MIN_NORMAL = 2 ** (-126)
shared_exp = torch.where(shared_exp == 0, FP32_MIN_NORMAL, shared_exp)
# Convert max into an integer exponent.
# Note this can be more efficient by just shifting and masking exp bits.
# We can even use those directly.
shared_exp = torch.ceil(torch.log2(shared_exp))
# Convert max into an integer exponent.
shared_exp = _compute_exp(shared_exp, rounding_mode)
# Offset exponent by largest exponent in target datatype.
shared_exp = shared_exp - 2
# Restrict to range expressible as int8.
