Triton MX4 Quantize Rounding Mode Support. (#2821)
Summary:
Pull Request resolved: #2821

X-link: facebookresearch/FBGEMM#22

This diff adds the `rounding_mode` argument to triton quantize. We support all the rounding modes described in the [best practices doc](https://docs.google.com/document/d/156Du0hBRH6umG_i-OrYC574XhpQMUU5SJYG0RTS2tTg/edit#heading=h.akfcp7xpg8cr).
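An illustrative usage sketch (not part of this diff; assumes fbgemm_gpu is installed with its Triton kernels and a CUDA device is available):

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode

x = torch.randn(1024, device="cuda", dtype=torch.float32)
# Quantize with stochastic rounding of the shared exponent.
packed = fp32_to_mx4(x, group_size=32, rounding_mode=RoundingMode.stochastic)
# Plain ints work too, since RoundingMode is an IntEnum (3 == stochastic).
packed_alt = fp32_to_mx4(x, group_size=32, rounding_mode=3)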

Differential Revision: D59562029
jwfromm authored and facebook-github-bot committed Jul 12, 2024
1 parent 973c0f1 commit 1a56f8a
Showing 7 changed files with 202 additions and 20 deletions.
3 changes: 2 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
@@ -48,7 +48,8 @@ def op_registeration(
int elem_ebits,
int elem_mbits,
float elem_max_norm,
int mx_group_size
int mx_group_size,
int rounding_mode
) -> Tensor
"""
)
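Since the new schema argument is a plain `int`, the Python-side `RoundingMode` enum (an `IntEnum`, added in the Triton module below) coerces to it directly. A small sketch of that mapping (illustrative only):

from fbgemm_gpu.triton import RoundingMode

# IntEnum members behave like ints, so either the enum or its value can be
# passed wherever the op schema expects `int rounding_mode`.
assert int(RoundingMode.even) == 2
assert RoundingMode.stochastic == 3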
11 changes: 8 additions & 3 deletions fbgemm_gpu/fbgemm_gpu/quantize/quantize_ops.py
@@ -4,11 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# # pyre-unsafe
# pyre-unsafe
from typing import Union

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode


def quantize_mx(
@@ -18,6 +19,7 @@ def quantize_mx(
elem_mbits: int = 3,
elem_max_norm: float = 6.0,
mx_group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
) -> torch.Tensor:
"""
Registered quantize_mx ops for E2E comm.
@@ -31,14 +33,17 @@ def quantize_mx(
i.e., 3 for MX4 e2m1)
elem_max_norm: max value of the float (i.e., 6.0 for MX4 e2m1)
mx_group_size: num elements that share the max shared_exponent
rounding_mode: Which type of rounding to use when calculating shared exponent.
Return:
output: MX4 tensor packed into int8 values with size
(total_elems / 2 + total_elems / groupsize)
the shared exponent of each group is stored at the last byte
of output of each group
"""
return fp32_to_mx4(input, mx_group_size, use_triton=True)
return fp32_to_mx4(
input, mx_group_size, rounding_mode=rounding_mode, use_triton=True
)


def dequantize_mx(
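To make the packed layout in the docstring above concrete: each group of mx_group_size FP32 values becomes mx_group_size / 2 packed MX4 bytes plus one trailing shared-exponent byte. A sketch (packed_mx4_size is a hypothetical helper, not part of this diff):

def packed_mx4_size(total_elems: int, group_size: int = 32) -> int:
    # Two MX4 values per byte, plus one shared-exponent byte per group.
    return total_elems // 2 + total_elems // group_size

# 64 FP32 values in groups of 32 -> 2 groups of (16 packed bytes + 1 exponent byte).
assert packed_mx4_size(64, 32) == 34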
14 changes: 10 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -9,10 +9,11 @@

import logging
import math

Check failure on line 11 in fbgemm_gpu/fbgemm_gpu/quantize_utils.py (GitHub Actions / run-lint (3.11)): F401 'math' imported but unused
from typing import Union

import torch

from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4
from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

logger: logging.Logger = logging.getLogger()
@@ -32,13 +33,18 @@


def fp32_to_mx4(
tensor: torch.Tensor, group_size: int = 32, use_triton: bool = True
tensor: torch.Tensor,
group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
use_triton: bool = True,
) -> torch.Tensor:
"""Quantize an FP32 tensor to MX4 with triton or native cuda impl.
Args:
tensor (torch.Tensor): FP32 tensor to quantize with M total elements.
group_size (int): Compute scale in chunks of group_size.
rounding_mode (RoundingMode or int): Which type of rounding to use when computing exponent.
Only supported with use_triton=True.
use_triton (bool): If set, use triton quantization, otherwise cuda.
Return:
@@ -49,10 +55,10 @@ def fp32_to_mx4(
input = tensor.flatten()

if not tensor.is_cuda:
return py_quantize_mx4(input, group_size)
return py_quantize_mx4(input, group_size, rounding_mode=rounding_mode)

if use_triton:
return quantize_mx4(input, group_size)
return quantize_mx4(input, group_size, rounding_mode=rounding_mode)
else:
out = torch.ops.fbgemm.quantize_mx_cuda(
input,
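A sketch of the dispatch above (illustrative only): non-CUDA tensors take the pure-PyTorch reference path, while CUDA tensors use the Triton kernel when use_triton=True.

import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode

x = torch.randn(256, dtype=torch.float32)
# CPU tensor: dispatches to py_quantize_mx4 under the hood.
packed_cpu = fp32_to_mx4(x, group_size=32, rounding_mode=RoundingMode.floor)

if torch.cuda.is_available():
    # CUDA tensor with use_triton=True: dispatches to the Triton kernel.
    packed_gpu = fp32_to_mx4(x.cuda(), group_size=32, rounding_mode=RoundingMode.even)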
2 changes: 2 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/triton/__init__.py
@@ -10,11 +10,13 @@
# Attempt to import triton kernels, fallback to reference if we cannot.
try:
from .quantize import (
RoundingMode,
triton_dequantize_mx4 as dequantize_mx4,
triton_quantize_mx4 as quantize_mx4,
)
except ImportError:
from .quantize_ref import ( # noqa: F401, E402
py_dequantize_mx4 as dequantize_mx4,
py_quantize_mx4 as quantize_mx4,
RoundingMode,
)
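Because of this fallback, downstream code can import the same names whether or not the Triton kernels are importable (sketch, not part of this diff):

# Resolves to triton_quantize_mx4 / RoundingMode when Triton imports cleanly,
# and to the quantize_ref implementations otherwise.
from fbgemm_gpu.triton import quantize_mx4, RoundingMode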
120 changes: 116 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -7,6 +7,8 @@

# pyre-unsafe
import math
from enum import IntEnum
from typing import Union

import torch
import triton # @manual
@@ -15,6 +17,81 @@
from triton import Config # @manual


class RoundingMode(IntEnum):
nearest = 0
floor = 1
even = 2
stochastic = 3
ceil = 4


@triton.jit
def _floor_log2(x):
"""Helper function to efficiently compute floor(log2(x))
Args:
x (Tensor): FP32 Input tensor to operate on.
Returns:
Tensor: Floor of log2(x).
"""
# Helpful bit constants.
FP32_EXP_MASK: tl.constexpr = 0x7F800000 # type: ignore[Incompatible variable type]
FP32_EXP_OFFSET: tl.constexpr = 23 # type: ignore[Incompatible variable type]
FP32_EXP_BIAS: tl.constexpr = 127 # type: ignore[Incompatible variable type]

# View x as an integer and extract its exponent.
x = x.to(tl.int32, bitcast=True) & FP32_EXP_MASK
# Shift exponent down to bottom bits.
x = x >> FP32_EXP_OFFSET
# Remove FP32 exponent bias and return.
return (x - FP32_EXP_BIAS).to(tl.float32)
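# Illustration (not part of this diff): a PyTorch equivalent of the bit trick
# above, valid for positive normal FP32 inputs. `floor_log2_torch` is a
# hypothetical helper name.
def floor_log2_torch(x: torch.Tensor) -> torch.Tensor:
    bits = x.float().view(torch.int32)  # Reinterpret the FP32 bits as int32.
    exp = (bits & 0x7F800000) >> 23  # Extract the biased exponent field.
    return (exp - 127).float()  # Remove the bias to get floor(log2(x)).

# floor_log2_torch(torch.tensor([0.75, 5.0, 1024.0])) -> tensor([-1., 2., 10.])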


@triton.jit
def _compute_exp(
group_max,
rounding_mode,
rand_bits,
):
"""Compute shared exponent of group using specified rounding mode.
Args:
group_max (Tensor): Group of values to compute exponent of.
rounding_mode (int or RoundingMode): Which rounding mode to use.
rand_bits (int): Random integer values used for stochastic rounding.
Returns:
Tensor: Shared exponent of group.
"""
# TODO(jwfromm) Enable hip libdevice support as well.
# Nearest rounding mode.
if rounding_mode == 0:
return tl.extra.cuda.libdevice.round(tl.log2(group_max))
# Floor rounding mode. This can be done with fast bit ops.
if rounding_mode == 1:
return _floor_log2(group_max)
# Even pre-rounding mode.
elif rounding_mode == 2:
# First round to nearest even integer.
group_max = tl.extra.cuda.libdevice.rint(group_max)
# Then perform floor rounding of log.
return _floor_log2(group_max)
# Stochastic rounding mode.
elif rounding_mode == 3:
# Define constants needed for stochastic rounding.
MBITS_FP32: tl.constexpr = 23 # type: ignore[Incompatible variable type]
MBITS_E2M1: tl.constexpr = 1 # type: ignore[Incompatible variable type]
RAND_MASK: tl.constexpr = 1 << (MBITS_FP32 - MBITS_E2M1) - 1 # type: ignore[Incompatible variable type]
# Use random bits to add noise to mantissa that would otherwise
# be rounded away.
group_max = group_max.to(tl.int32, bitcast=True) + (RAND_MASK & rand_bits)
# Now compute log and truncate.
return _floor_log2(group_max)
else:
return tl.ceil(tl.log2(group_max))


@triton.autotune(
configs=[
Config({"GROUP_LOAD": 1}),
@@ -31,6 +108,8 @@ def _kernel_quantize_mx4(
out,
M,
K,
rand_bits,
ROUNDING_MODE: tl.constexpr,
GROUP_SIZE: tl.constexpr,
GROUP_LOAD: tl.constexpr,
) -> None:
@@ -42,6 +121,8 @@ def _kernel_quantize_mx4(
out (Tensor): [M / 2 + M / GROUP_SIZE] output containing packed mx4 values.
M (int): Total number of elements.
K (int): Number of elements to process in each thread.
rand_bits (Optional Tensor): [M, K / 2] random integers used for stochastic rounding.
ROUNDING_MODE (int): Which rounding method to use when calculating shared exponent.
GROUP_SIZE (int): Size of chunks that use the same shared exponent.
GROUP_LOAD (int): Number of groups to process simultaneously.
"""
@@ -73,6 +154,7 @@ def _kernel_quantize_mx4(
# Initiate offset ranges used in kernel.
input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + input_start
output_offset = tl.arange(0, GROUP_LOAD * (GROUP_SIZE // 2))
rand_bits_offset = tl.arange(0, GROUP_LOAD) + pid * K // GROUP_SIZE
# We need to shift output offsets to make space for shared exponent storage.
output_offset += output_offset // (GROUP_SIZE // 2) + output_start
# Now create offsets for writing the shared exponent.
@@ -97,9 +179,18 @@ def _kernel_quantize_mx4(
group_max = tl.max(tl.abs(a_groups), axis=1)
# Prevent infinite values in log.
group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
# Convert max to exponent via direct log computation and ceiling
# rounding to minimize errors.
group_exp = tl.ceil(tl.log2(group_max))
# Load relevant random values if doing stochastic rounding.
if ROUNDING_MODE == 3:
group_rand_bits = tl.load(
rand_bits + rand_bits_offset,
mask=rand_bits_offset < K // GROUP_SIZE,
other=0,
)
rand_bits_offset += GROUP_LOAD
else:
group_rand_bits = None
# Compute shared exponent using specified rounding mode.
group_exp = _compute_exp(group_max, ROUNDING_MODE, group_rand_bits)
# Subtract largest exponent in target datatype and remove bias.
group_exp = group_exp - 2
# Make sure exponent is in valid range.
@@ -204,13 +295,19 @@ def _kernel_quantize_mx4(
output_offset += GROUP_LOAD * PACKED_GROUP_SIZE


def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
def triton_quantize_mx4(
a: torch.Tensor,
group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
) -> torch.Tensor:
"""
Quantize a tensor to mx4 format using efficient triton kernels.
Args:
a (Tensor): [M] higher precision input tensor.
group_size (int): Size of chunks that will use the same shared exponent.
rounding_mode (Union[RoundingMode, int]): Which type of rounding to use
when calculating shared exponent. Defaults to pre-rounding to nearest even int.
Returns:
torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into int8
@@ -252,13 +349,28 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
[a.numel() // 2 + a.numel() // group_size], device=a.device, dtype=torch.uint8
)

# If using stochastic rounding, create random noise for each group.
if rounding_mode == RoundingMode.stochastic:
# Each group will need a seed.
rand_bits = torch.randint(
low=0,
high=2**31 - 1,
size=(a.numel() // group_size,),
dtype=torch.int32,
device=a.device,
)
else:
rand_bits = None

# Invoke triton quantization kernel over rows.
grid = (M,)
_kernel_quantize_mx4[grid](
a,
out,
a.numel(),
K,
rand_bits=rand_bits,
ROUNDING_MODE=rounding_mode,
GROUP_SIZE=group_size,
)
# Inputs are now fully quantized and ready to return.
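An end-to-end sketch of the Triton path (illustrative, not part of this diff; requires a CUDA device):

import torch
from fbgemm_gpu.triton import quantize_mx4, RoundingMode

a = torch.randn(4096, device="cuda", dtype=torch.float32)
packed = quantize_mx4(a, group_size=32, rounding_mode=RoundingMode.stochastic)
# Two MX4 values per byte plus one shared-exponent byte per group of 32.
assert packed.dtype == torch.uint8
assert packed.numel() == a.numel() // 2 + a.numel() // 32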
61 changes: 56 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -6,16 +6,69 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from enum import IntEnum
from typing import Union

import torch


def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
class RoundingMode(IntEnum):
nearest = 0
floor = 1
even = 2
stochastic = 3
ceil = 4


def _compute_exp(
group_max,
rounding_mode,
):
"""Compute shared exponent of group using specified rounding mode.
Args:
group_max (Tensor): Group of values to compute exponent of.
rounding_mode (int or RoundingMode): Which rounding mode to use.
Returns:
Tensor: Shared exponent of group.
"""
if rounding_mode == 0:
return torch.round(torch.log2(group_max))
# Floor rounding mode.
if rounding_mode == 1:
return torch.floor(torch.log2(group_max))
# Even pre-rounding mode.
elif rounding_mode == 2:
# First round to nearest even integer.
group_max = torch.round(group_max)
# Then perform floor rounding of log.
return torch.floor(torch.log2(group_max))
# Stochastic rounding mode.
elif rounding_mode == 3:
# Create random noise.
noise = torch.rand_like(group_max)
# Add noise to group max and round down.
group_max = group_max + noise
# Now compute log and truncate.
return torch.floor(torch.log2(group_max))
else:
return torch.ceil(torch.log2(group_max))
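# Illustration (not part of this diff): the modes can pick different shared
# exponents for the same group max. With group_max = 6.5 (log2(6.5) ~= 2.70):
_gm = torch.tensor([6.5])
assert _compute_exp(_gm, RoundingMode.nearest).item() == 3.0  # round(2.70)
assert _compute_exp(_gm, RoundingMode.floor).item() == 2.0  # floor(2.70)
assert _compute_exp(_gm, RoundingMode.even).item() == 2.0  # floor(log2(round(6.5)))
assert _compute_exp(_gm, RoundingMode.ceil).item() == 3.0  # ceil(2.70)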


def py_quantize_mx4(
a: torch.Tensor,
group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.even,
) -> torch.Tensor:
"""
Quantize a tensor to mx4 format.
Args:
a (Tensor): [M] higher precision input tensor.
group_size (int): Size of chunks that will use the same shared exponent.
rounding_mode (int or RoundingMode): Which type of rounding to use when
calculating shared exponent.
Returns:
torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into int8
@@ -43,10 +96,8 @@ def py_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
# Replace zero values with the minimum expressible normal value.
FP32_MIN_NORMAL = 2 ** (-126)
shared_exp = torch.where(shared_exp == 0, FP32_MIN_NORMAL, shared_exp)
# Convert max into an intger exponent.
# Note this can be more efficient by just shifting and masking exp bits.
# We can even use those directly.
shared_exp = torch.ceil(torch.log2(shared_exp))
# Convert max into an integer exponent.
shared_exp = _compute_exp(shared_exp, rounding_mode)
# Offset exponent by largest exponent in target datatype.
shared_exp = shared_exp - 2
# Restrict to range expressible as int8.
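A CPU-only sketch of the reference path (illustrative, not part of this diff):

import torch
from fbgemm_gpu.triton.quantize_ref import py_quantize_mx4, RoundingMode

a = torch.randn(64, dtype=torch.float32)
packed = py_quantize_mx4(a, group_size=32, rounding_mode=RoundingMode.even)
# 64 values -> 32 packed bytes + 2 shared-exponent bytes.
assert packed.numel() == 64 // 2 + 64 // 32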