From 1b35144163ea34ee00e03ba053b8b3fea6f23840 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Wed, 22 Jan 2025 11:37:10 -0800
Subject: [PATCH] replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs

ghstack-source-id: 059b6978da29d45ed55481b0c510231f2ad93303
ghstack-comment-id: 2608105249
Pull Request resolved: https://github.com/pytorch/ao/pull/1599
---
 torchao/quantization/quant_api.py | 31 +++++++++++--------------------
 1 file changed, 11 insertions(+), 20 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index b2eff196fd..cfae8ee0ac 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -36,6 +36,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
@@ -66,21 +67,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +908,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
@@ -934,15 +929,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
         The actual matmul will be computed in original precision of the weight tensor.
 
""" - from torchao.dtypes import to_affine_quantized_floatx def apply_float8wo_quant(weight): block_size = (1, weight.shape[1]) - return to_affine_quantized_floatx( + return to_affine_quantized_float8( input_float=weight, block_size=block_size, target_dtype=weight_dtype, - scale_dtype=None, _layout=Float8Layout(mm_config=None), ) @@ -1016,11 +1009,10 @@ def _input_activation_quant_func_fp8( block_size = get_block_size(x.shape, activation_granularity) if scale is None: - activation = to_affine_quantized_floatx( + activation = to_affine_quantized_float8( input_float=x, block_size=block_size, target_dtype=activation_dtype, - scale_dtype=torch.float32, _layout=Float8Layout(mm_config=None), # Config is stored on weight ) else: @@ -1102,11 +1094,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor): ), "PerRow quantization only works for bfloat16 precision input weight" block_size = get_block_size(weight.shape, weight_granularity) - quantized_weight = to_affine_quantized_floatx( + quantized_weight = to_affine_quantized_float8( input_float=weight, block_size=block_size, target_dtype=weight_dtype, - scale_dtype=torch.float32, _layout=Float8Layout(mm_config=mm_config), )