Optimize FBGEMM Triton MX4 Dequantize (#2837)
Summary:
Pull Request resolved: #2837

Previously we had to use Python to unravel the packed values from their shared exponents and feed them to Triton as two separate tensors. This added a lot of overhead because it required large copies.

This diff uses some fancy indexing to operate directly on a tensor with interleaved elements and exponents. As a result, Triton dequantize is now slightly faster than the CUDA kernel. My hope is that this allows us to standardize on a single implementation.
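
To illustrate the layout being indexed (a minimal NumPy sketch of the idea, not the kernel itself): each group of `group_size` FP4 values occupies `group_size // 2` packed bytes followed by one shared-exponent byte, so element offsets have to skip one extra byte per preceding group. The helper below is hypothetical and only mirrors the offset arithmetic; the kernel performs the same skip with its `exp_indices` computation.

```python
# Hypothetical sketch of the flat MX4 layout assumed here:
# [group_size // 2 packed bytes | 1 shared-exponent byte] repeated per group.
import numpy as np


def mx4_offsets(num_elems: int, group_size: int = 32):
    packed_group_size = group_size // 2 + 1  # packed bytes plus one exponent byte
    byte_idx = np.arange(num_elems // 2)  # each byte holds two FP4 values
    group_idx = byte_idx // (group_size // 2)
    # Skip one shared-exponent byte for every preceding group.
    value_offsets = byte_idx + group_idx
    # The shared exponent sits at the end of each packed group.
    exp_offsets = group_idx * packed_group_size + group_size // 2
    return value_offsets, exp_offsets


vals, exps = mx4_offsets(num_elems=64, group_size=32)
print(vals[:4], exps[:4])  # [0 1 2 3] [16 16 16 16]
```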

I think we could probably do something similar during quantize to get a significant speedup as well.

```
INFO:root:input size: 1073741824 group size: 32
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 MX4 quantized time per iter: 7563us
input_size=1073741824 MX4 dequantized time per iter: 2756us
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 MX4 triton quantized time per iter: 5110us
input_size=1073741824 MX4 triton dequantized time per iter: 2417us
INFO:root:Start to benchmark ...
INFO:root:Start to benchmark ...
input_size=1073741824 FP8 quantized time per iter: 6274us
input_size=1073741824 FP8 dequantized time per iter: 4223us
```
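
For reference, a round trip through the two Triton ops benchmarked above could look like the sketch below; the `fbgemm_gpu.triton.quantize` import path, the CUDA device, and the input size are assumptions rather than something this diff establishes.

```python
# Hedged usage sketch; the import path and device are assumptions.
import torch

from fbgemm_gpu.triton.quantize import triton_dequantize_mx4, triton_quantize_mx4

x = torch.randn(1024 * 1024, device="cuda", dtype=torch.float32)
packed = triton_quantize_mx4(x, group_size=32)  # packed FP4 values plus shared exponents
restored = triton_dequantize_mx4(packed, group_size=32)
print(packed.numel(), restored.numel())  # one fp32 output element per input element
```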

Reviewed By: sryap

Differential Revision: D59661776
jwfromm authored and facebook-github-bot committed Jul 12, 2024
1 parent a7cd500 commit 2e38cc2
Showing 2 changed files with 51 additions and 38 deletions.
88 changes: 50 additions & 38 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py

@@ -287,27 +287,23 @@ def triton_quantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:

 @triton.autotune(
     configs=[
-        Config({"BLOCK_SIZE": 32}),
-        Config({"BLOCK_SIZE": 64}),
-        Config({"BLOCK_SIZE": 512}),
-        Config({"BLOCK_SIZE": 1024}),
-        Config({"BLOCK_SIZE": 2048}),
-        Config({"BLOCK_SIZE": 4096}),
-        Config({"BLOCK_SIZE": 8192}),
+        Config({"GROUP_LOAD": 1}),
+        Config({"GROUP_LOAD": 4}),
+        Config({"GROUP_LOAD": 8}),
+        Config({"GROUP_LOAD": 16}),
+        Config({"GROUP_LOAD": 32}),
     ],
     key=["K"],
     prune_configs_by={"early_config_prune": prune_configs},
 )
 @triton.jit
 def _kernel_dequantize_mx4(
     A,
-    shared_exp,
     mx4_lookup_table,
     out,
     M,
     K,
     GROUP_SIZE: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
+    GROUP_LOAD: tl.constexpr,
 ) -> None:
     """Dequantize a packed MX4 tensor and apply scaling.
@@ -318,23 +314,36 @@ def _kernel_dequantize_mx4(
         M (int): Total number of elements in input.
         K (int): Number of elements each thread should operate on.
         GROUP_SIZE (int): Size of chunks that use the same shared exponent.
-        BLOCK_SIZE (int): Size of each block.
+        GROUP_LOAD (int): Number of groups to process simultaneously.
     """
+    # Define constants.
+    MX4_BIT_MASK: tl.constexpr = 0xF  # type: ignore[Incompatible variable type]
+    FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
+    PACKED_GROUP_SIZE: tl.constexpr = GROUP_SIZE // 2 + 1  # type: ignore[Incompatible variable type]
+    OUTPUT_LIMIT: tl.constexpr = (K // PACKED_GROUP_SIZE) * GROUP_SIZE  # type: ignore[Incompatible variable type]
+    OUTPUT_SIZE: tl.constexpr = (M // PACKED_GROUP_SIZE) * GROUP_SIZE  # type: ignore[Incompatible variable type]
+
     # Get the current thread number.
     pid = tl.program_id(0)
     # Find the starting offsets for this thread.
     input_start = pid * K
-    output_start = pid * K * 2
+    exp_start = input_start + GROUP_SIZE // 2
+    # Remove shared exponents from output offset.
+    output_start = pid * GROUP_SIZE * (K // PACKED_GROUP_SIZE)
     # Initiate offset ranges used in this thread.
-    input_offset = tl.arange(0, BLOCK_SIZE) + input_start
-    output_offset = tl.arange(0, 2 * BLOCK_SIZE) + output_start
-
-    # Define constants.
-    MX4_BIT_MASK: tl.constexpr = 0xF  # type: ignore[Incompatible variable type]
-    FP32_EXP_BIAS: tl.constexpr = 127  # type: ignore[Incompatible variable type]
+    # This is a little complicated because we need to skip one value (the shared exponent)
+    # every group_size elements.
+    input_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE // 2)
+    # Add 1 every GROUP_SIZE / 2 steps so we skip shared exponent.
+    exp_indices = input_offset // (GROUP_SIZE // 2)
+    input_offset = input_offset + exp_indices + input_start
+    # We need to space out each group of the input by 1 since thats the shared exp.
+    output_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE) + output_start
+    # Stride exponent access across packed groups.
+    exp_offset = exp_indices * PACKED_GROUP_SIZE + exp_start

     # Iterate over input tensor and unpack mx4 values.
-    for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
+    for _k in range(0, tl.cdiv(K, GROUP_LOAD * PACKED_GROUP_SIZE)):
         a = tl.load(
             A + input_offset,
             # Mask values that are out of this chunk or the main array.
@@ -350,8 +359,11 @@ def _kernel_dequantize_mx4(
         high_fp32 = tl.load(mx4_lookup_table + high_mx4)

         # Get proper shared exponent and convert it to a float scale.
-        group_offset = (2 * input_offset) // GROUP_SIZE
-        exp = tl.load(shared_exp + group_offset)
+        exp = tl.load(
+            A + exp_offset,
+            mask=(exp_offset < M) & (exp_offset < (K * (pid + 1))),
+            other=0.0,
+        )
         # Remove fp32 exponent bias.
         exp = exp.to(tl.uint8, bitcast=True) - FP32_EXP_BIAS
@@ -370,11 +382,14 @@
             out + output_offset,
             scaled_fp32,
             # Mask values that are out of this chunk or the main array.
-            mask=(output_offset < 2 * M) & (output_offset < ((2 * K) * (pid + 1))),
+            mask=(output_offset < OUTPUT_SIZE)
+            & (output_offset < OUTPUT_LIMIT * (pid + 1)),
         )

-        input_offset += BLOCK_SIZE
-        output_offset += 2 * BLOCK_SIZE
+        # Update indices for next group.
+        input_offset += GROUP_LOAD * PACKED_GROUP_SIZE
+        exp_offset += GROUP_LOAD * PACKED_GROUP_SIZE
+        output_offset += GROUP_LOAD * GROUP_SIZE


 def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor:
@@ -394,25 +409,21 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
         return torch.empty(a.shape, device=a.device, dtype=torch.float32)
     # View a as 2D for simplicity.
     orig_shape = a.shape
-    # Unravel packed inputs from shared exponents.
-    packed_group_size = group_size // 2
-    a = a.view(-1, packed_group_size + 1)
-    packed_input = a[:, :-1]
-    shared_exp = a[:, -1:]
+    a = a.flatten()
+    # Find number of groups.
+    packed_group_size = group_size // 2 + 1
     # Find a shape that distributes work evenly over threads.
     # We do this by finding the power of two that is closest to
     # the sqrt of the number of elements.
-    num_threads = int(2 ** round(math.log2(math.sqrt(packed_input.numel()))))
+    num_threads = int(2 ** round(math.log2(math.sqrt(a.numel()))))
     # Make sure that the number of elements per row is a multiple of packed group_size.
-    K = packed_input.numel() // num_threads
+    K = a.numel() // num_threads
     K = (K // packed_group_size) * packed_group_size
     if K == 0:
         K = packed_group_size
     # Try to evenly divide input into chunks of size K, allow last chunk to be smaller.
-    M = int(math.ceil(packed_input.numel() / K))
-    # Flatten inputs.
-    packed_input = packed_input.flatten().contiguous()
-    shared_exp = shared_exp.flatten().contiguous()
+    M = int(math.ceil(a.numel() / K))

     # Use a lookup table to convert
     mx4_to_fp_values = torch.tensor(
@@ -422,15 +433,16 @@ def triton_dequantize_mx4(a: torch.Tensor, group_size: int = 32) -> torch.Tensor
     )

     # Create output tensor.
-    out = torch.empty([2 * packed_input.numel()], device=a.device, dtype=torch.float)
+    num_groups = a.numel() // packed_group_size
+    output_elems = num_groups * group_size
+    out = torch.empty([output_elems], device=a.device, dtype=torch.float)
     # Invoke triton dequantization kernel over rows.
     grid = (M,)
     _kernel_dequantize_mx4[grid](
-        packed_input,
-        shared_exp,
+        a,
         mx4_lookup_table,
         out,
-        packed_input.numel(),
+        a.numel(),
         K,
         GROUP_SIZE=group_size,
     )

1 change: 1 addition & 0 deletions fbgemm_gpu/test/quantize/mx4_test.py

@@ -212,6 +212,7 @@ def test_mx4(self, power: int, sizes: int) -> None:
         )

         check_diff_quantize(input, output_ref, output_cuda)
+        check_diff_quantize(input, output_cuda, output_triton)
         check_diff_quantize(input, output_cuda, output_cuda_from_quantized_triton)
         check_diff_quantize(input, output_cuda_from_quantized_triton, output_triton)
         check_diff_quantize(input, output_triton, output_cpu)