Optimize FBGEMM Triton MX4 Quantize
Summary:
We apply the same technique used for dequantization in D59661776 to MX4 quantization. Specifically, we use fancy indexing to write both the shared exponents and the packed values to a single output tensor within the Triton kernel. This lets us allocate only one output and avoid extra copies, giving a sizeable ~40% performance boost.
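
For context, the packed layout stores each group's GROUP_SIZE // 2 value bytes immediately followed by that group's single shared-exponent byte, so one uint8 buffer of numel // 2 + numel // GROUP_SIZE bytes holds everything. Below is a small illustrative sketch of that layout and of the offset shift the kernel's fancy indexing performs (variable names here are illustrative, not the kernel's):

```
# Illustrative layout check for a single thread (output_start = 0).
GROUP_SIZE = 32
num_groups = 4
numel = num_groups * GROUP_SIZE

# Byte positions of the packed nibbles before accounting for exponent bytes.
value_offsets = list(range(numel // 2))
# Shifting each offset by the number of exponent bytes that precede it mirrors
# the kernel's fancy indexing: output_offset += output_offset // (GROUP_SIZE // 2).
shifted_value_offsets = [o + o // (GROUP_SIZE // 2) for o in value_offsets]
# The exponent byte of group g lands right after that group's packed bytes.
exp_offsets = [g * (GROUP_SIZE // 2 + 1) + GROUP_SIZE // 2 for g in range(num_groups)]

# Together the two index sets tile the output exactly, with no gaps or overlaps.
assert sorted(shifted_value_offsets + exp_offsets) == list(
    range(numel // 2 + numel // GROUP_SIZE)
)
```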

Before this change:
```
INFO:root:input size: 1073741824 group size: 32
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 MX4 quantized time per iter: 7563us
input_size=1073741824 MX4 dequantized time per iter: 2756us
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 MX4 triton quantized time per iter: 5110us
input_size=1073741824 MX4 triton dequantized time per iter: 2417us
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 FP8 quantized time per iter: 6274us
input_size=1073741824 FP8 dequantized time per iter: 4223us
```

After this change:
```
INFO:root:input size: 1073741824 group size: 32
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 MX4 quantized time per iter: 7560us
input_size=1073741824 MX4 dequantized time per iter: 2758us
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 MX4 triton quantized time per iter: 3138us
input_size=1073741824 MX4 triton dequantized time per iter: 2418us
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 FP8 quantized time per iter: 6274us
input_size=1073741824 FP8 dequantized time per iter: 4226us
```
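
The caller-visible contract is unchanged; only the internal allocation and copying differ. A minimal usage sketch (import path assumed from this file's location in the repo; requires a GPU):

```
import torch

# Assumed import path based on this file's location; adjust if the public API
# re-exports the function elsewhere.
from fbgemm_gpu.triton.quantize import triton_quantize_mx4

group_size = 32
a = torch.randn(4, 1024, device="cuda", dtype=torch.float32)

packed = triton_quantize_mx4(a, group_size=group_size)
# A single uint8 tensor holds both packed values and shared exponents:
# 1024 // 2 value bytes + 1024 // 32 exponent bytes per row.
assert packed.dtype == torch.uint8
assert packed.shape == (4, 1024 // 2 + 1024 // group_size)  # (4, 544)
```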

Differential Revision: D59688150
Josh Fromm authored and facebook-github-bot committed Jul 12, 2024
1 parent 4b0534e commit 973c0f1
Showing 1 changed file with 49 additions and 63 deletions.
112 changes: 49 additions & 63 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -15,64 +15,36 @@
from triton import Config # @manual


def prune_configs(configs, named_args, **kwargs):
"""Helper function to remove invalid configurations."""
group_size = kwargs["GROUP_SIZE"]
pruned_configs = []
for config in configs:
block_size = config.kwargs["BLOCK_SIZE"]
# Don't use block sizes that are smaller than the group size.
if group_size <= block_size:
pruned_configs.append(config)
# Return only the valid configurations.
return pruned_configs


@triton.autotune(
configs=[
Config({"BLOCK_SIZE": 32}),
Config({"BLOCK_SIZE": 64}),
Config({"BLOCK_SIZE": 512}),
Config({"BLOCK_SIZE": 1024}),
Config({"BLOCK_SIZE": 2048}),
Config({"BLOCK_SIZE": 4096}),
Config({"BLOCK_SIZE": 8192}),
Config({"GROUP_LOAD": 1}),
Config({"GROUP_LOAD": 4}),
Config({"GROUP_LOAD": 8}),
Config({"GROUP_LOAD": 16}),
Config({"GROUP_LOAD": 32}),
],
key=["K"],
prune_configs_by={"early_config_prune": prune_configs},
)
@triton.jit
def _kernel_quantize_mx4(
A,
shared_exp,
out,
M,
K,
GROUP_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
GROUP_LOAD: tl.constexpr,
) -> None:
"""Quantize a 1D float tensor into a packed MX4 tensor.
Args:
A (Tensor): [M] float tensor to be quantized.
shared_exp (Tensor): [M / group_size] output containing shared exponent.
out (Tensor): [M / 2] output containing packed mx4 values.
out (Tensor): [M / 2 + M / GROUP_SIZE] output containing packed mx4 values and shared exponents.
M (int): Total number of elements.
K (int): Number of elements to process in each thread.
GROUP_SIZE (int): Size of chunks that use the same shared exponent.
BLOCK_SIZE (int): Size of each block.
GROUP_LOAD (int): Number of groups to process simultaneously.
"""
# Get the current thread number.
pid = tl.program_id(0)
# Find starting offsets for this thread.
input_start = pid * K
packed_start = pid * K // 2
group_start = pid * K // GROUP_SIZE
# Initiate offset ranges used in kernel.
input_offset = tl.arange(0, BLOCK_SIZE) + input_start
packed_offset = tl.arange(0, BLOCK_SIZE // 2) + packed_start
group_offset = tl.arange(0, BLOCK_SIZE // GROUP_SIZE) + group_start

# Define Constant Expressions.
FP32_EXP_MASK: tl.constexpr = 0x7F800000 # type: ignore[Incompatible variable type]
FP32_EXP_OFFSET: tl.constexpr = 23 # type: ignore[Incompatible variable type]
@@ -87,9 +59,27 @@ def _kernel_quantize_mx4(
IMPLIED_1_BIT: tl.constexpr = 1 << 23 # type: ignore[Incompatible variable type]
OVERFLOW_THRESHOLD: tl.constexpr = 4 # type: ignore[Incompatible variable type]
FP32_MIN_NORMAL: tl.constexpr = 2 ** (-126) # type: ignore[Incompatible variable type]
# Boundaries for writing to output tensor.
OUTPUT_LIMIT: tl.constexpr = K // 2 + K // GROUP_SIZE # type: ignore[Incompatible variable type]
OUTPUT_SIZE: tl.constexpr = M // 2 + M // GROUP_SIZE # type: ignore[Incompatible variable type]
PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1 # type: ignore[Incompatible variable type]

# Get the current thread number.
pid = tl.program_id(0)
# Find starting offsets for this thread.
input_start = pid * K
output_start = pid * (K // 2 + K // GROUP_SIZE)
exp_start = output_start + GROUP_SIZE // 2
# Initiate offset ranges used in kernel.
input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + input_start
output_offset = tl.arange(0, GROUP_LOAD * (GROUP_SIZE // 2))
# We need to shift output offsets to make space for shared exponent storage.
output_offset += output_offset // (GROUP_SIZE // 2) + output_start
# Now create offsets for writing the shared exponent.
exp_offset = tl.arange(0, GROUP_LOAD) * PACKED_GROUP_SIZE + exp_start

# Load and process blocks of values for this chunk.
for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
for _k in range(0, tl.cdiv(K, GROUP_LOAD * GROUP_SIZE)):
# Load a block of values.
a = tl.load(
A + input_offset,
@@ -102,7 +92,7 @@
##############

# View the block in terms of groups.
a_groups = tl.reshape(a, [BLOCK_SIZE // GROUP_SIZE, GROUP_SIZE])
a_groups = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
# Compute the shared exponent of each group.
group_max = tl.max(tl.abs(a_groups), axis=1)
# Prevent infinite values in log.
@@ -118,20 +108,19 @@
# Next we scale A in preparation for quantization.
scale = tl.exp2(group_exp.to(tl.float64)).to(tl.float32)
# Apply scale to input. We do this by broadcasting scale.
scaled_a = tl.reshape(a, [BLOCK_SIZE // GROUP_SIZE, GROUP_SIZE]) / tl.reshape(
scale, [BLOCK_SIZE // GROUP_SIZE, 1]
scaled_a = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE]) / tl.reshape(
scale, [GROUP_LOAD, 1]
)
# Reshape back to a flat array.
scaled_a = tl.reshape(scaled_a, [BLOCK_SIZE])
scaled_a = tl.reshape(scaled_a, [GROUP_LOAD * GROUP_SIZE])

# We're done with group_exp now so we can write it out.
# We re-add fp32_exp_bias for compatibility with cuda dequant.
tl.store(
shared_exp + group_offset,
out + exp_offset,
(group_exp + FP32_EXP_BIAS).to(tl.int8),
# Prevent writing outside this chunk or the main array.
mask=(group_offset < M // GROUP_SIZE)
& (group_offset < ((K // GROUP_SIZE) * (pid + 1))),
mask=(exp_offset < OUTPUT_SIZE) & (exp_offset < (OUTPUT_LIMIT * (pid + 1))),
)

# Quantization step
@@ -194,22 +183,25 @@ def _kernel_quantize_mx4(
mx4_value = (sign_bit << 3) | mx4_value

# Extract low and high bits from values.
low_mx4, high_mx4 = tl.split(tl.reshape(mx4_value, [BLOCK_SIZE // 2, 2]))
low_mx4, high_mx4 = tl.split(
tl.reshape(mx4_value, [(GROUP_LOAD * GROUP_SIZE) // 2, 2])
)
# Shift mx4 values together so they are packed into int8.
packed_mx4 = ((high_mx4 << 4) | (low_mx4)).to(tl.int8)

# Next step is packing, let's write this out to check how it looks.
# Write out packed values to output tensor.
tl.store(
out + packed_offset,
out + output_offset,
packed_mx4,
# Prevent writing outside this chunk or the main array.
mask=(packed_offset < M // 2) & (packed_offset < ((K // 2) * (pid + 1))),
mask=(output_offset < OUTPUT_SIZE)
& (output_offset < (OUTPUT_LIMIT * (pid + 1))),
)

# Update offsets so we work on the next block.
input_offset += BLOCK_SIZE
group_offset += BLOCK_SIZE // GROUP_SIZE
packed_offset += BLOCK_SIZE // 2
input_offset += GROUP_LOAD * GROUP_SIZE
exp_offset += GROUP_LOAD * PACKED_GROUP_SIZE
output_offset += GROUP_LOAD * PACKED_GROUP_SIZE


def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
@@ -255,34 +247,28 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
# Flatten input.
a = a.flatten()

# Create output tensors.
shared_exp = torch.empty(
[a.numel() // group_size], device=a.device, dtype=torch.uint8
# Create output tensor.
out = torch.empty(
[a.numel() // 2 + a.numel() // group_size], device=a.device, dtype=torch.uint8
)
out = torch.empty([a.numel() // 2], device=a.device, dtype=torch.uint8)

# Invoke triton quantization kernel over rows.
grid = (M,)
_kernel_quantize_mx4[grid](
a,
shared_exp,
out,
a.numel(),
K,
GROUP_SIZE=group_size,
)
# Ravel together output and shared exponent.
packed_mx4 = torch.concat(
[out.view(-1, group_size // 2), shared_exp.view(-1, 1)], dim=1
)
# Inputs are now fully quantized and ready to return.
# Try to return in the original shape if possible.
if orig_shape[-1] % group_size == 0:
output_shape = list(orig_shape[:-1]) + [-1]
return packed_mx4.view(output_shape)
return out.view(output_shape)
# If we can't, return as a flat array.
else:
return packed_mx4.view(-1)
return out.view(-1)


@triton.autotune(
@@ -320,6 +306,7 @@ def _kernel_dequantize_mx4(
MX4_BIT_MASK: tl.constexpr = 0xF # type: ignore[Incompatible variable type]
FP32_EXP_BIAS: tl.constexpr = 127 # type: ignore[Incompatible variable type]
PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1 # type: ignore[Incompatible variable type]
# Boundaries for writing to output tensor.
OUTPUT_LIMIT: tl.constexpr = (K // PACKED_GROUP_SIZE) * GROUP_SIZE # type: ignore[Incompatible variable type]
OUTPUT_SIZE: tl.constexpr = (M // PACKED_GROUP_SIZE) * GROUP_SIZE # type: ignore[Incompatible variable type]

@@ -412,7 +399,6 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
a = a.flatten()
# Find number of groups.
packed_group_size = group_size // 2 + 1
# Unravel packed inputs from shared exponents.
# Find a shape that distributes work evenly over threads.
# We do this by finding the power of two that is closest to
# the sqrt of the number of elements.
