From 96d67346b838b547501502454bf2f8091d9fc907 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 22 Jan 2025 20:35:41 -0800
Subject: [PATCH] replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs

ghstack-source-id: 3028fc5f84252f60353df9144ce3fda62b26fe8c
ghstack-comment-id: 2608105249
Pull Request resolved: https://github.com/pytorch/ao/pull/1599
---
 docs/source/api_ref_dtypes.rst             |  2 +-
 torchao/dtypes/__init__.py                 | 13 +++----
 torchao/dtypes/affine_quantized_tensor.py  | 13 ++++---
 .../prototype/quantization/autoquant_v2.py | 20 ++++-------
 torchao/quantization/autoquant.py          | 20 ++++-------
 torchao/quantization/quant_api.py          | 35 +++++++------------
 6 files changed, 35 insertions(+), 68 deletions(-)

diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst
index fbe680953e..c37c3a81ce 100644
--- a/docs/source/api_ref_dtypes.rst
+++ b/docs/source/api_ref_dtypes.rst
@@ -13,7 +13,7 @@ torchao.dtypes
     to_nf4
     to_affine_quantized_intx
     to_affine_quantized_intx_static
-    to_affine_quantized_floatx
+    to_affine_quantized_float8
     to_affine_quantized_floatx_static
     to_affine_quantized_fpx
     NF4Tensor
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 9cbd4cd2a0..95d1f2de32 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -1,16 +1,14 @@
 from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    to_affine_quantized_floatx,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .floatx import (
-    Float8Layout,
-)
+from .floatx import Float8Layout
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
@@ -24,10 +22,7 @@
     UintxLayout,
     to_marlinqqq_quantized_intx,
 )
-from .utils import (
-    Layout,
-    PlainLayout,
-)
+from .utils import Layout, PlainLayout
 
 __all__ = [
     "NF4Tensor",
@@ -36,8 +31,8 @@
     "to_affine_quantized_intx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_fpx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_floatx_static",
+    "to_affine_quantized_float8",
     "to_marlinqqq_quantized_intx",
     "Layout",
     "PlainLayout",
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 506e8f0174..dca74dd948 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -14,7 +14,7 @@
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
-    dequantize_affine_floatx,
+    dequantize_affine_float8,
     quantize_affine,
     quantize_affine_float8,
     quantize_affine_floatx,
@@ -28,8 +28,8 @@
     "AffineQuantizedTensor",
     "register_layout",
     "to_affine_quantized_intx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_intx_static",
+    "to_affine_quantized_float8",
     "to_affine_quantized_floatx_static",
     "to_affine_quantized_fpx",
 ]
@@ -124,12 +124,13 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         from torchao.dtypes.floatx import FloatxTensorCoreLayout
 
         if isinstance(self._layout, FloatxTensorCoreLayout):
+            # TODO(danielvegamyhre): I think this dequantize method will be used for both float8 and fpx.
+            # If there's no way to distinguish which it is here, we need to implement float8 and fpx subclasses
+            # of AQT so we have separate dequantization procedures for each.
             int_data, scale = self.tensor_impl.get_plain()
-            return dequantize_affine_floatx(
+            return dequantize_affine_float8(
                 int_data,
                 scale,
-                self._layout.ebits,
-                self._layout.mbits,
                 output_dtype=output_dtype,
             )
         else:
@@ -430,7 +431,6 @@ def from_hp_to_float8(
         scale = choose_qparams_affine_float8(
             input_float,
             target_dtype,
-            target_dtype,
         )
         fp8_data = quantize_affine_float8(
             input_float,
@@ -500,7 +500,6 @@ def _apply_fn_to_data(self, fn):
 
 to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
-to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
 to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
 # experimental will be merged in to floatx
diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py
index 0780eb3a84..4db8227492 100644
--- a/torchao/prototype/quantization/autoquant_v2.py
+++ b/torchao/prototype/quantization/autoquant_v2.py
@@ -27,14 +27,8 @@
 from torchao.quantization.autoquant import (
     AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV1,
 )
-from torchao.quantization.granularity import (
-    PerRow,
-    PerTensor,
-)
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from torchao.quantization.granularity import PerRow, PerTensor
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.quantization.subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -991,7 +985,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1015,12 +1009,11 @@ def get_per_token_block_size(x):
             activation_dtype=input_target_dtype,
         )
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls
@@ -1040,7 +1033,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1058,12 +1051,11 @@ def get_weight_block_size(x):
             activation_dtype=input_target_dtype,
         )
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index d49e84e066..f9b0439de3 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -18,10 +18,7 @@
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.quantization.utils import (
     compute_error,
     quantize_activation_per_token_absmax,
@@ -34,10 +31,7 @@
     is_sm_at_least_90,
 )
 
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -969,7 +963,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, BFloat16Ten
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -995,12 +989,11 @@ def get_per_token_block_size(x):
         }
 
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = to_linear_activation_quantized(
             weight, input_quant_func, quant_kwargs=input_quant_kwargs
@@ -1025,7 +1018,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1043,12 +1036,11 @@ def get_weight_block_size(x):
             "activation_dtype": input_target_dtype,
         }
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index b2eff196fd..32e91a2cc6 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -36,7 +36,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
-    to_affine_quantized_floatx,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
@@ -66,21 +66,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
+from .qat import intx_quantization_aware_training
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +907,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
 
@@ -934,15 +928,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
         The actual matmul will be computed in original precision of the weight tensor.
 
     """
-    from torchao.dtypes import to_affine_quantized_floatx
 
     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )
 
@@ -1016,11 +1008,10 @@ def _input_activation_quant_func_fp8(
 
     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
             input_float=x,
             block_size=block_size,
             target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
         )
     else:
@@ -1102,11 +1093,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         ), "PerRow quantization only works for bfloat16 precision input weight"
 
         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )
 
@@ -1157,11 +1147,10 @@ def apply_float8_static_activation_quant(weight: torch.Tensor):
         if not _fp8_mm_compat(weight):
             return weight
         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )
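
For reference, below is a minimal sketch (not part of the patch) of what a call site looks like after this change. The to_affine_quantized_float8 keyword arguments mirror the hunks above; the example weight tensor, its shape, and the per-row block size are illustrative assumptions only.

import torch

from torchao.dtypes import Float8Layout, to_affine_quantized_float8

# Hypothetical bf16 weight; any 2D floating-point weight works the same way.
weight = torch.randn(256, 512, dtype=torch.bfloat16)

# One scale per row, matching the (1, weight.shape[1]) block size used by
# float8_weight_only and the PerRow paths above.
block_size = (1, weight.shape[1])

# Previously this went through to_affine_quantized_floatx with an explicit
# scale_dtype argument; the float8-specific constructor no longer takes scale_dtype.
quantized_weight = to_affine_quantized_float8(
    input_float=weight,
    block_size=block_size,
    target_dtype=torch.float8_e4m3fn,
    _layout=Float8Layout(mm_config=None),
)

Here mm_config=None matches the weight-only path; the dynamic- and static-activation paths above instead pass their mm_config through Float8Layout(mm_config=mm_config).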