Triton MX4 Quantize Rounding Mode Support. (#2821)
Summary:
Pull Request resolved: #2821

X-link: facebookresearch/FBGEMM#22

This diff adds the `rounding_mode` argument to triton quantize. We support all of the rounding modes described in the [best practices doc](https://docs.google.com/document/d/156Du0hBRH6umG_i-OrYC574XhpQMUU5SJYG0RTS2tTg/edit#heading=h.akfcp7xpg8cr).
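
As a quick illustration of the new argument (a sketch only, assuming a CUDA device and this change installed; the call signatures match the diff below):

```python
import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode

x = torch.randn(4096, device="cuda")
# Quantize with stochastic rounding of the shared exponent instead of the
# default ceil rounding.
packed = fp32_to_mx4(x, group_size=32, rounding_mode=RoundingMode.stochastic)
# Dequantize back to fp32 to inspect the round-trip error.
x_hat = mx4_to_fp32(packed, group_size=32)
```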

Reviewed By: summerdengfb

Differential Revision: D59562029
jwfromm authored and facebook-github-bot committed Jul 12, 2024
1 parent a14b630 commit 573b3b0
Showing 8 changed files with 261 additions and 23 deletions.
3 changes: 2 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/quantize/__init__.py
@@ -49,7 +49,8 @@ def op_registeration(
int elem_ebits,
int elem_mbits,
float elem_max_norm,
int mx_group_size
int mx_group_size,
int? rounding_mode = None
) -> Tensor
"""
)
11 changes: 8 additions & 3 deletions fbgemm_gpu/fbgemm_gpu/quantize/quantize_ops.py
@@ -4,11 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# # pyre-unsafe
# pyre-unsafe
from typing import Union

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode


def quantize_mx(
@@ -18,6 +19,7 @@ def quantize_mx(
elem_mbits: int = 3,
elem_max_norm: float = 6.0,
mx_group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil,
) -> torch.Tensor:
"""
Registered quantize_mx ops for E2E comm.
@@ -31,14 +33,17 @@
i.e., 3 for MX4 e2m1)
elem_max_norm: max value of the float (i.e., 6.0 for MX4 e2m1)
mx_group_size: num elements that share the max shared_exponent
rounding_mode: Which type of rounding to use when calculating shared exponent.
Return:
output: MX4 tensor packed into int8 values with size
(total_elems / 2 + total_elems / groupsize)
the shared exponent of each group is stored at the last byte
of output of each group
"""
return fp32_to_mx4(input, mx_group_size, use_triton=True)
return fp32_to_mx4(
input, mx_group_size, rounding_mode=rounding_mode, use_triton=True
)


def dequantize_mx(
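For concreteness, the packed size described in the docstring above works out as follows (a back-of-the-envelope check, not part of the diff):

```python
# 64 fp32 elements with mx_group_size=32:
#   64 // 2  = 32 bytes of packed 4-bit values (two elements per byte)
#   64 // 32 =  2 bytes of shared exponents (one per group)
# Each group is laid out as 16 packed bytes followed by its exponent byte.
total_elems, group_size = 64, 32
out_bytes = total_elems // 2 + total_elems // group_size
assert out_bytes == 34
```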
19 changes: 14 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -8,10 +8,11 @@
# pyre-strict

import logging
from typing import Optional, Union

import torch

from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4
from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

logger: logging.Logger = logging.getLogger()
@@ -31,13 +32,18 @@


def fp32_to_mx4(
tensor: torch.Tensor, group_size: int = 32, use_triton: bool = True
tensor: torch.Tensor,
group_size: int = 32,
rounding_mode: Optional[Union[RoundingMode, int]] = RoundingMode.ceil,
use_triton: bool = True,
) -> torch.Tensor:
"""Quantize an FP32 tensor to MX4 with triton or native cuda impl.
Args:
tensor (torch.Tensor): FP32 tensor to quantize with M total elements.
group_size (int): Compute scale in chunks of group_size.
rounding_mode (RoundingMode or int): Which type of rounding to use when computing exponent.
Only supported with use_triton=True.
use_triton (bool): If set, use triton quantization, otherwise cuda.
Return:
@@ -47,11 +53,14 @@ def fp32_to_mx4(
# Operate on flattened input.
input = tensor.flatten()

if rounding_mode is None:
rounding_mode = RoundingMode.ceil

if not tensor.is_cuda:
return py_quantize_mx4(input, group_size)
return py_quantize_mx4(input, group_size, rounding_mode=rounding_mode)

if use_triton:
return quantize_mx4(input, group_size)
return quantize_mx4(input, group_size, rounding_mode=rounding_mode)
else:
out = torch.ops.fbgemm.quantize_mx_cuda(
input,
@@ -67,7 +76,7 @@


def mx4_to_fp32(
tensor: torch.Tensor, group_size: int = 32, use_triton: bool = False
tensor: torch.Tensor, group_size: int = 32, use_triton: bool = True
) -> torch.Tensor:
"""Dequantize an MX4 tensor to FP32 with triton or native cuda impl.
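The dispatch above can be summarized with a small sketch (illustrative only; it simply mirrors the branches in fp32_to_mx4 shown in this hunk):

```python
import torch
from fbgemm_gpu.quantize_utils import fp32_to_mx4, RoundingMode

x_cpu = torch.randn(128)
# CPU tensors take the py_quantize_mx4 reference path; rounding_mode is
# forwarded to it as well (see the py_quantize_mx4 call above).
packed_even = fp32_to_mx4(x_cpu, group_size=32, rounding_mode=RoundingMode.even)
# Passing rounding_mode=None falls back to the default ceil rounding.
packed_ceil = fp32_to_mx4(x_cpu, group_size=32, rounding_mode=None)
```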
2 changes: 2 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/triton/__init__.py
@@ -8,6 +8,8 @@
# pyre-unsafe

# Attempt to import triton kernels, fall back to reference if we cannot.
from .common import RoundingMode # noqa

try:
from .quantize import (
triton_dequantize_mx4 as dequantize_mx4,
19 changes: 19 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/triton/common.py
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from enum import IntEnum


class RoundingMode(IntEnum):
"""Rounding options for quantization."""

nearest = 0
floor = 1
even = 2
stochastic = 3
ceil = 4
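
Since RoundingMode is an IntEnum, the plain integers accepted by the registered op (`int? rounding_mode`) map directly onto these members; a quick illustrative check:

```python
from fbgemm_gpu.triton import RoundingMode

assert RoundingMode.stochastic == 3
assert RoundingMode(4) is RoundingMode.ceil
# The registered op can therefore be called with e.g. rounding_mode=3 to get
# stochastic rounding of the shared exponent.
```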
117 changes: 113 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -7,13 +7,87 @@

# pyre-unsafe
import math
from typing import Union

import torch
import triton # @manual

import triton.language as tl # @manual
from triton import Config # @manual

from .common import RoundingMode


@triton.jit
def _floor_log2(x):
"""Helper function to efficiently compute floor(log2(x))
Args:
x (Tensor): FP32 Input tensor to operate on.
Returns:
Tensor: Floor of log2(x).
"""
# Helpful bit constants.
FP32_EXP_MASK: tl.constexpr = 0x7F800000 # type: ignore[Incompatible variable type]
FP32_EXP_OFFSET: tl.constexpr = 23 # type: ignore[Incompatible variable type]
FP32_EXP_BIAS: tl.constexpr = 127 # type: ignore[Incompatible variable type]

# View x as an integer and extract its exponent.
x = x.to(tl.int32, bitcast=True) & FP32_EXP_MASK
# Shift exponent down to bottom bits.
x = x >> FP32_EXP_OFFSET
# Remove FP32 exponent bias and return.
return (x - FP32_EXP_BIAS).to(tl.float32)
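
The bit manipulation above can be sanity-checked against a plain-Python equivalent (a sketch of the same exponent extraction, not the Triton code path):

```python
import math
import struct

def floor_log2_bits(x: float) -> int:
    # Reinterpret the fp32 value as an integer, mask off the 8 exponent bits,
    # shift them down, and subtract the bias, mirroring _floor_log2 above.
    bits = struct.unpack("I", struct.pack("f", x))[0]
    return ((bits & 0x7F800000) >> 23) - 127

assert floor_log2_bits(6.0) == math.floor(math.log2(6.0)) == 2
assert floor_log2_bits(0.75) == -1
```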


@triton.jit
def _compute_exp(
group_max,
rounding_mode,
rand_bits,
):
"""Compute shared exponent of group using specified rounding mode.
Args:
group_max (Tensor): Group of values to compute exponent of.
rounding_mode (int or RoundingMode): Which rounding mode to use.
rand_bits (int): Random integer values used for stochastic rounding.
Returns:
Tensor: Shared exponent of group.
"""
# TODO(jwfromm) Enable hip libdevice support as well.
# Define some helpful constants.
MBITS_FP32: tl.constexpr = 23 # type: ignore[Incompatible variable type]
MBITS_E2M1: tl.constexpr = 1 # type: ignore[Incompatible variable type]
# Nearest rounding mode.
if rounding_mode == 0:
return tl.extra.cuda.libdevice.round(tl.log2(group_max))
# Floor rounding mode. This can be done with fast bit ops.
if rounding_mode == 1:
return _floor_log2(group_max)
# Even pre-rounding mode.
elif rounding_mode == 2:
# Add fixed amount of rounding to mantissa so that they are clipped
# to the closest integer.
M_ROUND: tl.constexpr = (1 << (MBITS_FP32 - MBITS_E2M1 - 1)) - 1
# Add them to the mantissa bits of the input to round during truncation.
group_max = group_max.to(tl.int32, bitcast=True) + M_ROUND
# Then perform floor rounding of log.
return _floor_log2(group_max)
# Stochastic rounding mode.
elif rounding_mode == 3:
# Define constants needed for stochastic rounding.
RAND_MASK: tl.constexpr = (1 << (MBITS_FP32 - MBITS_E2M1)) - 1  # type: ignore[Incompatible variable type]
# Use random bits to add noise to mantissa that would otherwise
# be rounded away.
group_max = group_max.to(tl.int32, bitcast=True) + (RAND_MASK & rand_bits)
# Now compute log and truncate.
return _floor_log2(group_max)
else:
return tl.ceil(tl.log2(group_max))
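
The even and stochastic branches both bias the mantissa bits of group_max before taking the floored exponent, so a value that would round up to the next power of two at e2m1 precision also gets the larger exponent. A plain-Python sketch of the even branch (illustrative, not the Triton path):

```python
import struct

MBITS_FP32, MBITS_E2M1 = 23, 1
M_ROUND = (1 << (MBITS_FP32 - MBITS_E2M1 - 1)) - 1  # 0x1FFFFF

def even_exp(x: float) -> int:
    # Bias the mantissa bits so truncation rounds rather than floors, then
    # read off the exponent field, mirroring the rounding_mode == 2 branch.
    bits = struct.unpack("I", struct.pack("f", x))[0] + M_ROUND
    return ((bits & 0x7F800000) >> 23) - 127

assert even_exp(7.9) == 3  # close enough to 8 that the bias carries over
assert even_exp(5.0) == 2  # stays at the floored exponent
```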


@triton.autotune(
configs=[
@@ -31,6 +105,8 @@ def _kernel_quantize_mx4(
out,
M,
K,
rand_bits,
ROUNDING_MODE: tl.constexpr,
GROUP_SIZE: tl.constexpr,
GROUP_LOAD: tl.constexpr,
) -> None:
@@ -42,6 +118,8 @@
out (Tensor): [M / 2 + M / GROUP_SIZE] output containing packed mx4 values.
M (int): Total number of elements.
K (int): Number of elements to process in each thread.
rand_bits (Optional Tensor): [M, K / GROUP_SIZE] random integers used for stochastic rounding.
ROUNDING_MODE (int): Which rounding method to use when calculating shared exponent.
GROUP_SIZE (int): Size of chunks that use the same shared exponent.
GROUP_LOAD (int): Number of groups to process simultaneously.
"""
@@ -73,6 +151,7 @@
# Initiate offset ranges used in kernel.
input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + input_start
output_offset = tl.arange(0, GROUP_LOAD * (GROUP_SIZE // 2))
rand_bits_offset = tl.arange(0, GROUP_LOAD) + pid * K // GROUP_SIZE
# We need to shift output offsets to make space for shared exponent storage.
output_offset += output_offset // (GROUP_SIZE // 2) + output_start
# Now create offsets for writing the shared exponent.
@@ -97,9 +176,18 @@
group_max = tl.max(tl.abs(a_groups), axis=1)
# Prevent infinite values in log.
group_max = tl.where(group_max == 0, FP32_MIN_NORMAL, group_max)
# Convert max to exponent via direct log computation and ceiling
# rounding to minimize errors.
group_exp = tl.ceil(tl.log2(group_max))
# Load relevant random values if doing stochastic rounding.
if ROUNDING_MODE == 3:
group_rand_bits = tl.load(
rand_bits + rand_bits_offset,
mask=rand_bits_offset < K // GROUP_SIZE,
other=0,
)
rand_bits_offset += GROUP_LOAD
else:
group_rand_bits = None
# Compute shared exponent using specified rounding mode.
group_exp = _compute_exp(group_max, ROUNDING_MODE, group_rand_bits)
# Subtract largest exponent in target datatype and remove bias.
group_exp = group_exp - 2
# Make sure exponent is in valid range.
@@ -204,13 +292,19 @@ def _kernel_quantize_mx4(
output_offset += GROUP_LOAD * PACKED_GROUP_SIZE


def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
def triton_quantize_mx4(
a: torch.Tensor,
group_size: int = 32,
rounding_mode: Union[RoundingMode, int] = RoundingMode.ceil,
) -> torch.Tensor:
"""
Quantize a tensor to mx4 format using efficient triton kernels.
Args:
a (Tensor): [M] higher precision input tensor.
group_size (int): Size of chunks that will use the same shared exponent.
rounding_mode (Union[RoundingMode, int]): Which type of rounding to use
when calculating shared exponent. Defaults to RoundingMode.ceil.
Returns:
torch.Tensor: [M / 2 + M / group_size] mx4 scaled tensor packed into int8
@@ -252,13 +346,28 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
[a.numel() // 2 + a.numel() // group_size], device=a.device, dtype=torch.uint8
)

# If using stochastic rounding, create random noise for each group.
if rounding_mode == RoundingMode.stochastic:
# Each group will need a seed.
rand_bits = torch.randint(
low=0,
high=2**31 - 1,
size=(a.numel() // group_size,),
dtype=torch.int32,
device=a.device,
)
else:
rand_bits = None

# Invoke triton quantization kernel over rows.
grid = (M,)
_kernel_quantize_mx4[grid](
a,
out,
a.numel(),
K,
rand_bits=rand_bits,
ROUNDING_MODE=rounding_mode,
GROUP_SIZE=group_size,
)
# Inputs are now fully quantized and ready to return.
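As a final illustrative sketch (assuming a CUDA device and a 1-D input, as in the docstring above): stochastic rounding draws fresh rand_bits on every call, so repeated quantizations of the same tensor can legitimately differ.

```python
import torch
from fbgemm_gpu.triton import quantize_mx4, RoundingMode

a = torch.randn(4096, device="cuda")
q1 = quantize_mx4(a, group_size=32, rounding_mode=RoundingMode.stochastic)
q2 = quantize_mx4(a, group_size=32, rounding_mode=RoundingMode.stochastic)
# Groups whose max falls between two powers of two may be assigned different
# shared exponents in q1 and q2.
print(torch.equal(q1, q2))
```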