From e1627f33cc5221c8394821441370f32e34fece5b Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Thu, 23 Jan 2025 09:46:16 -0800
Subject: [PATCH] replace to_affine_quantized_floatx with
 to_affine_quantized_float8 in quantization APIs

ghstack-source-id: 75b010de13f2a6627d542965d6a9fa6f60b86bbb
ghstack-comment-id: 2608105249
Pull Request resolved: https://github.com/pytorch/ao/pull/1599
---
 docs/source/api_ref_dtypes.rst                |  2 +-
 torchao/dtypes/__init__.py                    | 14 +--
 torchao/dtypes/affine_quantized_tensor.py     | 92 +++++-------------
 torchao/dtypes/affine_quantized_tensor_ops.py | 12 +--
 torchao/dtypes/float8/__init__.py             | 20 ++++
 .../{floatx => float8}/float8_layout.py       | 78 +++++++++++++---
 torchao/dtypes/floatx/__init__.py             |  2 -
 .../prototype/quantization/autoquant_v2.py    | 20 ++--
 torchao/quantization/autoquant.py             | 20 ++--
 torchao/quantization/quant_api.py             | 35 +++----
 10 files changed, 141 insertions(+), 154 deletions(-)
 create mode 100644 torchao/dtypes/float8/__init__.py
 rename torchao/dtypes/{floatx => float8}/float8_layout.py (81%)

diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst
index fbe680953e..c37c3a81ce 100644
--- a/docs/source/api_ref_dtypes.rst
+++ b/docs/source/api_ref_dtypes.rst
@@ -13,7 +13,7 @@ torchao.dtypes
     to_nf4
     to_affine_quantized_intx
     to_affine_quantized_intx_static
-    to_affine_quantized_floatx
+    to_affine_quantized_float8
     to_affine_quantized_floatx_static
     to_affine_quantized_fpx
     NF4Tensor
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 9cbd4cd2a0..25bbc6754c 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -1,16 +1,13 @@
 from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .floatx import (
-    Float8Layout,
-)
+from .float8 import Float8Layout, to_affine_quantized_float8
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
@@ -21,13 +18,10 @@
     MarlinSparseLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
-    UintxLayout,
     to_marlinqqq_quantized_intx,
+    UintxLayout,
 )
-from .utils import (
-    Layout,
-    PlainLayout,
-)
+from .utils import Layout, PlainLayout
 
 __all__ = [
     "NF4Tensor",
@@ -36,8 +30,8 @@
     "to_affine_quantized_intx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_fpx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_floatx_static",
+    "to_affine_quantized_float8",
     "to_marlinqqq_quantized_intx",
     "Layout",
     "PlainLayout",
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 506e8f0174..6374d398f5 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -10,13 +10,10 @@
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
-    choose_qparams_affine_float8,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
-    dequantize_affine_floatx,
     quantize_affine,
-    quantize_affine_float8,
     quantize_affine_floatx,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
@@ -28,7 +25,6 @@
     "AffineQuantizedTensor",
     "register_layout",
     "to_affine_quantized_intx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx_static",
     "to_affine_quantized_fpx",
@@ -121,40 +117,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         if output_dtype is None:
             output_dtype = self.dtype
 
-        from torchao.dtypes.floatx import FloatxTensorCoreLayout
-
-        if isinstance(self._layout, FloatxTensorCoreLayout):
-            int_data, scale = self.tensor_impl.get_plain()
-            return dequantize_affine_floatx(
-                int_data,
-                scale,
-                self._layout.ebits,
-                self._layout.mbits,
-                output_dtype=output_dtype,
-            )
-        else:
-            data, scale, zero_point = self.tensor_impl.get_plain()
-            dq = dequantize_affine(
-                data,
-                self.block_size,
-                scale,
-                zero_point,
-                data.dtype,
-                self.quant_min,
-                self.quant_max,
-                self.zero_point_domain,
-                output_dtype=output_dtype,
-            )
-            from torchao.dtypes.uintx import TensorCoreTiledLayout
+        data, scale, zero_point = self.tensor_impl.get_plain()
+        dq = dequantize_affine(
+            data,
+            self.block_size,
+            scale,
+            zero_point,
+            data.dtype,
+            self.quant_min,
+            self.quant_max,
+            self.zero_point_domain,
+            output_dtype=output_dtype,
+        )
+        from torchao.dtypes.uintx import TensorCoreTiledLayout
 
-            if isinstance(self._layout, TensorCoreTiledLayout):
-                # need to return to original shape if tensor was padded
-                # in preprocessing
-                # TODO: we could add an API for this if there are more use cases
-                # (e.g. dequant_post_process) in TensorImpl or Layout
-                for dim, dim_size in enumerate(self.shape):
-                    dq = dq.narrow(dim, 0, dim_size)
-            return dq
+        if isinstance(self._layout, TensorCoreTiledLayout):
+            # need to return to original shape if tensor was padded
+            # in preprocessing
+            # TODO: we could add an API for this if there are more use cases
+            # (e.g. dequant_post_process) in TensorImpl or Layout
+            for dim, dim_size in enumerate(self.shape):
+                dq = dq.narrow(dim, 0, dim_size)
+        return dq
 
     def __tensor_flatten__(self):
         return ["tensor_impl"], [
@@ -272,7 +256,7 @@ def from_hp_to_intx(
         # Note: output will be uint8 tensor for sub byte tensors for now
 
         data = _layout.post_process(data)
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
         tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
         return cls(
             tensor_impl,
@@ -417,36 +401,6 @@ def from_hp_to_fpx(
         tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
         return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
 
-    @classmethod
-    def from_hp_to_float8(
-        cls,
-        input_float: torch.Tensor,
-        target_dtype: torch.dtype,
-        block_size: Tuple[int, ...],
-        _layout: Layout = PlainLayout(),
-    ):
-        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
-        original_shape = input_float.shape
-        scale = choose_qparams_affine_float8(
-            input_float,
-            target_dtype,
-            target_dtype,
-        )
-        fp8_data = quantize_affine_float8(
-            input_float,
-            scale,
-            target_dtype,
-        )
-        fp8_data = _layout.post_process(fp8_data)
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
-        tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
-        return cls(
-            tensor_impl,
-            block_size,
-            original_shape,
-            dtype=input_float.dtype,
-        )
-
     @property
     def _layout(self) -> Layout:
         return self.tensor_impl._layout
@@ -500,9 +454,7 @@ def _apply_fn_to_data(self, fn):
 
 to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
-to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
-to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index 76df949852..eab72929a5 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -3,10 +3,8 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
-from torchao.dtypes.affine_quantized_tensor import (
-    AffineQuantizedTensor,
-)
-from torchao.dtypes.floatx.float8_layout import (
+from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
+from torchao.dtypes.float8.float8_layout import (
     _linear_fp8_act_fp8_weight_check,
     _linear_fp8_act_fp8_weight_impl,
     _linear_fp_act_fp8_weight_check,
@@ -37,11 +35,11 @@
     _linear_fp_act_int4_weight_sparse_marlin_impl,
 )
 from torchao.dtypes.uintx.plain_layout import (
-    PlainAQTTensorImpl,
     _linear_fp_act_int8_weight_check,
     _linear_fp_act_int8_weight_impl,
     _linear_int8_act_int8_weight_check,
     _linear_int8_act_int8_weight_impl,
+    PlainAQTTensorImpl,
 )
 from torchao.dtypes.uintx.semi_sparse_layout import (
     _linear_int8_act_int8_weight_semi_structured_sparse_check,
@@ -52,9 +50,7 @@
     _linear_bf16_act_uint4_weight_impl,
 )
 from torchao.quantization.quant_primitives import dequantize_affine
-from torchao.utils import (
-    fill_defaults,
-)
+from torchao.utils import fill_defaults
 
 logger = logging.getLogger(__name__)
diff --git a/torchao/dtypes/float8/__init__.py b/torchao/dtypes/float8/__init__.py
new file mode 100644
index 0000000000..3a8ad64e92
--- /dev/null
+++ b/torchao/dtypes/float8/__init__.py
@@ -0,0 +1,20 @@
+from .float8_layout import (
+    _linear_fp8_act_fp8_weight_check,
+    _linear_fp8_act_fp8_weight_impl,
+    _linear_fp_act_fp8_weight_check,
+    _linear_fp_act_fp8_weight_impl,
+    Float8Layout,
+    Float8QuantizedTensor,
+    to_affine_quantized_float8,
+)
+
+
+__all__ = [
+    "Float8Layout",
+    "to_affine_quantized_float8",
+    "Float8QuantizedTensor",
+    "_linear_fp8_act_fp8_weight_check",
+    "_linear_fp8_act_fp8_weight_impl",
+    "_linear_fp_act_fp8_weight_check",
+    "_linear_fp_act_fp8_weight_impl",
+]
diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/float8/float8_layout.py
similarity index 81%
rename from torchao/dtypes/floatx/float8_layout.py
rename to torchao/dtypes/float8/float8_layout.py
index dd995fb157..4c4f26552c 100644
--- a/torchao/dtypes/floatx/float8_layout.py
+++ b/torchao/dtypes/float8/float8_layout.py
@@ -11,13 +11,19 @@
     AffineQuantizedTensor,
     register_layout,
 )
-from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape
+from torchao.dtypes.utils import AQTTensorImpl, get_out_shape, Layout
 from torchao.float8.inference import (
-    Float8MMConfig,
     _is_rowwise_scaled,
     addmm_float8_unwrapped_inference,
+    Float8MMConfig,
     preprocess_data,
 )
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine_float8,
+    dequantize_affine_float8,
+    FP8_TYPES,
+    quantize_affine_float8,
+)
 from torchao.utils import _is_float8_type, fill_defaults
 
 aten = torch.ops.aten
@@ -209,19 +215,64 @@ def __repr__(self):
         )
 
 
+class Float8QuantizedTensor(AffineQuantizedTensor):
+    """
+    Float8 quantized tensor subclass which inherits AffineQuantizedTensor class.
+    """
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+        int_data, scale = self.tensor_impl.get_plain()
+        return dequantize_affine_float8(
+            int_data,
+            scale,
+            output_dtype=output_dtype,
+        )
+
+    @classmethod
+    def from_hp_to_float8(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        block_size: Tuple[int, ...],
+        _layout: Layout = Float8Layout(),
+    ):
+        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
+        original_shape = input_float.shape
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+        )
+        fp8_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+        fp8_data = _layout.post_process(fp8_data)
+        tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            dtype=input_float.dtype,
+        )
+
+
 ##########################
 # Float8 Dispatch Kernels
 ##########################
 
 
 def _linear_fp8_act_fp8_weight_check(
-    input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
-    weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
+    input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
+    weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
     bias: Optional[torch.Tensor],
 ) -> bool:
-    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
+    def check_aqt(aqt: Union[torch.Tensor, Float8QuantizedTensor]) -> bool:
         return (
-            isinstance(aqt, AffineQuantizedTensor)
+            isinstance(aqt, Float8QuantizedTensor)
             and isinstance(aqt._layout, Float8Layout)
             and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
             and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
@@ -241,8 +292,8 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
 
 
 def _linear_fp8_act_fp8_weight_impl(
-    input_tensor: "AffineQuantizedTensor",
-    weight_tensor: "AffineQuantizedTensor",
+    input_tensor: "Float8QuantizedTensor",
+    weight_tensor: "Float8QuantizedTensor",
     bias: Optional[torch.Tensor],
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
@@ -285,8 +336,8 @@ def _linear_fp8_act_fp8_weight_impl(
 
 
 def _linear_fp_act_fp8_weight_check(
-    input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
-    weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
+    input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
+    weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
     bias: Optional[torch.Tensor],
 ) -> bool:
     return (
@@ -295,7 +346,7 @@ def _linear_fp_act_fp8_weight_check(
         and input_tensor.is_floating_point()
         and
         # weight is float8 quantized affine quantized tensor
-        isinstance(weight_tensor, AffineQuantizedTensor)
+        isinstance(weight_tensor, Float8QuantizedTensor)
         and isinstance(weight_tensor._layout, Float8Layout)
         and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
         and (
@@ -307,7 +358,10 @@ def _linear_fp_act_fp8_weight_check(
 
 def _linear_fp_act_fp8_weight_impl(
     input_tensor: torch.Tensor,
-    weight_tensor: "AffineQuantizedTensor",
+    weight_tensor: "Float8QuantizedTensor",
     bias: Optional[torch.Tensor],
 ):
     return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
+
+
+to_affine_quantized_float8 = Float8QuantizedTensor.from_hp_to_float8
diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py
index 3f0a1ccd5c..a1d0e72234 100644
--- a/torchao/dtypes/floatx/__init__.py
+++ b/torchao/dtypes/floatx/__init__.py
@@ -1,4 +1,3 @@
-from .float8_layout import Float8Layout
 from .floatx_tensor_core_layout import (
     FloatxTensorCoreLayout,
     from_scaled_tc_floatx,
@@ -9,5 +8,4 @@
     "FloatxTensorCoreLayout",
     "to_scaled_tc_floatx",
     "from_scaled_tc_floatx",
-    "Float8Layout",
 ]
diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py
index 0780eb3a84..4db8227492 100644
--- a/torchao/prototype/quantization/autoquant_v2.py
+++ b/torchao/prototype/quantization/autoquant_v2.py
@@ -27,14 +27,8 @@
 from torchao.quantization.autoquant import (
     AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV1,
 )
-from torchao.quantization.granularity import (
-    PerRow,
-    PerTensor,
-)
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from torchao.quantization.granularity import PerRow, PerTensor
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.quantization.subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -991,7 +985,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1015,12 +1009,11 @@ def get_per_token_block_size(x):
             activation_dtype=input_target_dtype,
         )
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls
@@ -1040,7 +1033,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
        from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1058,12 +1051,11 @@ def get_weight_block_size(x):
             activation_dtype=input_target_dtype,
         )
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index d49e84e066..f9b0439de3 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -18,10 +18,7 @@
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.quantization.utils import (
     compute_error,
     quantize_activation_per_token_absmax,
@@ -34,10 +31,7 @@
     is_sm_at_least_90,
 )
 
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -969,7 +963,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, BFloat16Ten
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -995,12 +989,11 @@ def get_per_token_block_size(x):
         }
 
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = to_linear_activation_quantized(
             weight, input_quant_func, quant_kwargs=input_quant_kwargs
@@ -1025,7 +1018,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8
 
         # weight settings
@@ -1043,12 +1036,11 @@ def get_weight_block_size(x):
             "activation_dtype": input_target_dtype,
         }
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index b2eff196fd..32e91a2cc6 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -36,7 +36,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
-    to_affine_quantized_floatx,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
@@ -66,21 +66,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +907,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
@@ -934,15 +928,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
         The actual matmul will be computed in original precision of the weight tensor.
 
     """
-    from torchao.dtypes import to_affine_quantized_floatx
 
     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )
 
@@ -1016,11 +1008,10 @@ def _input_activation_quant_func_fp8(
 
     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
             input_float=x,
             block_size=block_size,
             target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
         )
     else:
@@ -1102,11 +1093,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         ), "PerRow quantization only works for bfloat16 precision input weight"
 
         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )
 
@@ -1157,11 +1147,10 @@ def apply_float8_static_activation_quant(weight: torch.Tensor):
         if not _fp8_mm_compat(weight):
             return weight
         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )