MX4 ops front-end API - rebase+fix comm test
Differential Revision: D58627286
1 parent b903979, commit df1a7bd
Showing 5 changed files with 157 additions and 29 deletions.
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32

lib = torch.library.Library("fbgemm", "FRAGMENT")
lib.define(
    """
    quantize_mx(
        Tensor input,
        int scale_bits,
        int elem_ebits,
        int elem_mbits,
        float elem_max_norm,
        int mx_group_size
    ) -> Tensor
    """
)

lib.define(
    """
    dequantize_mx(
        Tensor input,
        int mx_group_size
    ) -> Tensor
    """
)


@torch.library.impl(lib, "quantize_mx", "CPU")
@torch.library.impl(lib, "quantize_mx", "CUDA")
def quantize_mx(
    input: torch.Tensor,
    scale_bits: int = 8,
    elem_ebits: int = 2,
    elem_mbits: int = 3,
    elem_max_norm: float = 6.0,
    mx_group_size: int = 32,
) -> torch.Tensor:
    """
    Registered quantize_mx op for E2E comm.
    Uses the Triton implementation for quantization.

    Args:
        input: FP32 tensor of size total_elems to be quantized
        scale_bits: number of bits of the shared exponent (e.g., 8 for MX4 e2m1)
        elem_ebits: number of bits of the exponent (e.g., 2 for MX4 e2m1)
        elem_mbits: number of bits of the mantissa incl. sign and implicit bits
            (e.g., 3 for MX4 e2m1)
        elem_max_norm: max representable element value (e.g., 6.0 for MX4 e2m1)
        mx_group_size: number of elements that share the same shared exponent

    Return:
        output: MX4 tensor packed into int8 values with size
            (total_elems / 2 + total_elems / mx_group_size);
            the shared exponent of each group is stored in the last byte
            of that group's output
    """
    return fp32_to_mx4(input, mx_group_size, use_triton=True)


@torch.library.impl(lib, "dequantize_mx", "CPU")
@torch.library.impl(lib, "dequantize_mx", "CUDA")
def dequantize_mx(
    input: torch.Tensor,
    mx_group_size: int = 32,
) -> torch.Tensor:
    """
    Registered dequantize_mx op for E2E comm.
    Uses the CUDA implementation for dequantization.

    Args:
        input: MX4 tensor packed into int8 values (the output of quantize_mx)
        mx_group_size: number of elements that share the same shared exponent

    Return:
        output: FP32 tensor with total_elems elements
    """
    return mx4_to_fp32(input, mx_group_size, use_triton=False)
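For reference, a minimal round-trip through the registered ops might look like the sketch below. This example is not part of the diff: the tensor size (1024) and device are illustrative, and the packed-size arithmetic assumes the default mx_group_size of 32 from the docstring formula (total_elems / 2 packed nibbles plus total_elems / mx_group_size shared-exponent bytes).

import torch

# Assumes the module above has been imported so the "fbgemm" ops are
# registered, and that the element count is divisible by mx_group_size.
x = torch.randn(1024, device="cuda")  # 1024 FP32 elements

# All schema arguments are passed explicitly (the op schema defines no
# defaults). Packed size: 1024 / 2 = 512 packed e2m1 bytes, plus
# 1024 / 32 = 32 shared-exponent bytes -> 544 int8 values total.
packed = torch.ops.fbgemm.quantize_mx(
    x,
    scale_bits=8,
    elem_ebits=2,
    elem_mbits=3,
    elem_max_norm=6.0,
    mx_group_size=32,
)
assert packed.numel() == 1024 // 2 + 1024 // 32  # 544

# Dequantize back to FP32; the round trip is lossy (4-bit elements).
x_hat = torch.ops.fbgemm.dequantize_mx(packed, mx_group_size=32)
assert x_hat.shape == x.shape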