add separate quantization primitives for float8 #1597

Merged
merged 12 commits on Jan 25, 2025
70 changes: 70 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -9,16 +9,21 @@
import unittest

import torch
from parameterized import parameterized

from torchao.dtypes.utils import is_device
from torchao.float8.float8_utils import EPS as float8_eps
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
    choose_qparams_affine_float8,
    dequantize_affine,
    dequantize_affine_float8,
    fake_quantize_affine,
    fake_quantize_affine_cachemask,
    quantize_affine,
    quantize_affine_float8,
)

# TODO: remove test for utils?
@@ -838,6 +843,71 @@ def test_fake_quantize_affine_cachemask(self):
        torch.testing.assert_close(dequantized, fake_quantized)
        torch.testing.assert_close(expected_mask, mask)

    @parameterized.expand(
        [
            (
                torch.float32,
                torch.float8_e4m3fn,
            ),
            (
                torch.float32,
                torch.float8_e5m2,
            ),
            (
                torch.bfloat16,
                torch.float8_e4m3fn,
            ),
            (
                torch.bfloat16,
                torch.float8_e5m2,
            ),
        ]
    )
    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
        input = torch.randn(10, 10)

        # float8 quantization primitives
        scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
        quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
        dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)

        # reference implementation using generic primitives
        expected_scale, _ = choose_qparams_affine(
            input,
            MappingType.SYMMETRIC,
            input.shape,
            float8_dtype,
            eps=float8_eps,  # use same EPS as float8 training
            scale_dtype=torch.float32,
            quant_min=torch.finfo(float8_dtype).min,
            quant_max=torch.finfo(float8_dtype).max,
        )
        expected_quantized = quantize_affine(
            input,
            input.shape,
            scale,
            output_dtype=float8_dtype,
            quant_min=torch.finfo(float8_dtype).min,
            quant_max=torch.finfo(float8_dtype).max,
            zero_point=None,
            zero_point_domain=None,
        )
        expected_dequantized = dequantize_affine(
            expected_quantized,
            input.shape,
            scale,
            input_dtype=float8_dtype,
            output_dtype=hp_dtype,
            quant_min=torch.finfo(float8_dtype).min,
            quant_max=torch.finfo(float8_dtype).max,
            zero_point=None,
            zero_point_domain=None,
        )

        self.assertTrue(torch.equal(expected_scale, scale))
        torch.testing.assert_close(expected_quantized, quantized)
        torch.testing.assert_close(expected_dequantized, dequantized)


if __name__ == "__main__":
    unittest.main()
67 changes: 67 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -39,6 +39,9 @@
"MappingType",
"ZeroPointDomain",
"TorchAODType",
"choose_qparams_affine_float8",
"quantize_affine_float8",
"dequantize_affine_float8",
]


@@ -1300,3 +1303,67 @@ def dequantize_affine_floatx(
    tensor = tensor * scale.float().view(-1, 1)
    tensor = tensor.to(dtype=output_dtype)
    return tensor


Contributor

@drisspg Do these look good?

Contributor Author
danielvegamyhre commented on Jan 24, 2025

Alternatively, I could just make these functions wrappers around the generic primitives that pass in the appropriate params for float8, as that may be more maintainable, although it does add a layer of indirection and hides how the scale is actually computed. Any thoughts?

For example:

def choose_qparams_float8(input: torch.Tensor, float8_dtype: torch.dtype):
    scale, _ = choose_qparams_affine(
        input,
        MappingType.SYMMETRIC,
        input.shape,  # only tensorwise scaling is supported at the moment
        float8_dtype,
        eps=float8_eps,  # use same EPS as float8 training
        scale_dtype=torch.float32,
        quant_min=torch.finfo(float8_dtype).min,
        quant_max=torch.finfo(float8_dtype).max,
    )
    return scale
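If we went that route, the other two primitives could wrap the generic ops the same way. A rough sketch (the wrapper names are hypothetical; the argument pass-through simply mirrors the reference calls in the new test above):

def quantize_float8(input: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype):
    # hypothetical wrapper around the generic affine quantize, tensorwise granularity
    return quantize_affine(
        input,
        input.shape,
        scale,
        output_dtype=float8_dtype,
        quant_min=torch.finfo(float8_dtype).min,
        quant_max=torch.finfo(float8_dtype).max,
        zero_point=None,
        zero_point_domain=None,
    )

def dequantize_float8(
    input: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype,
    output_dtype: torch.dtype = torch.float32,
):
    # hypothetical wrapper around the generic affine dequantize
    return dequantize_affine(
        input,
        input.shape,
        scale,
        input_dtype=float8_dtype,
        output_dtype=output_dtype,
        quant_min=torch.finfo(float8_dtype).min,
        quant_max=torch.finfo(float8_dtype).max,
        zero_point=None,
        zero_point_domain=None,
    )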

def choose_qparams_affine_float8(
    tensor: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    """
    Calculates the float8 scaling factor for the given high precision tensor, using tensorwise granularity.

    Args:
        tensor (torch.Tensor): Input tensor to be quantized.
        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
    """
    # only tensorwise scaling is supported for now:
    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
    min_val_neg = torch.min(tensor)
    max_val_pos = torch.max(tensor)
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
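    # e.g. torch.finfo(torch.float8_e4m3fn) has max = 448.0 and min = -448.0, so for e4m3fn this
    # reduces to amax / 448.0 (and to amax / 57344.0 for torch.float8_e5m2)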
    return scale.to(dtype=torch.float32)


def quantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    """
    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.

    Args:
        tensor (torch.Tensor): Input tensor to be quantized.
        scale (torch.Tensor): Scaling factor for the quantization.
        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
    """
    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
    # upcasted to `float32` to divide by the scale, since scale is an fp32 tensor in float8 quantization.
    # In order to match numerics between eager and compile, we upcast manually here.
    tensor_scaled = tensor.to(torch.float32) / scale
    max_value = torch.finfo(float8_dtype).max
    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
    fp8_tensor = tensor_clamped.to(float8_dtype)
    return fp8_tensor


def dequantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Dequantizes the float8 tensor to a high precision tensor.

    Args:
        tensor (torch.Tensor): Input float8 tensor to be dequantized.
        scale (torch.Tensor): Scaling factor for the dequantization.
        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
    """
    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
    # upcasted to `float32` to multiply with the scale, since scale is an fp32 tensor for float8 quantization.
    # In order to match numerics between eager and compile, we upcast manually here.
    fp8_tensor = tensor.to(torch.float32)
    hp_tensor = fp8_tensor * scale
    return hp_tensor.to(output_dtype)
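Taken together, a minimal usage sketch of the three new primitives (not part of the diff; it just mirrors what the new test exercises, with tensorwise scaling and the default e4m3fn dtype):

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(10, 10, dtype=torch.bfloat16)
scale = choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn)    # tensorwise fp32 scale
x_fp8 = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)   # float8_e4m3fn tensor
x_hat = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.bfloat16)  # back to bf16
# the roundtrip is lossy, but the error is bounded by float8 resolution
print((x - x_hat).abs().max())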