replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs

ghstack-source-id: 3028fc5f84252f60353df9144ce3fda62b26fe8c
ghstack-comment-id: 2608105249
Pull Request resolved: #1599
danielvegamyhre committed Jan 23, 2025
1 parent 1b64306 commit 96d6734
Showing 6 changed files with 35 additions and 68 deletions.
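
For quick reference, this is the shape of the change at every call site below: the old constructor name and its scale_dtype argument are dropped in favor of to_affine_quantized_float8. A minimal sketch (keyword names are taken from the call sites in this diff; the weight shape, target dtype, and mm_config=None are illustrative assumptions):

```python
import torch

from torchao.dtypes import Float8Layout, to_affine_quantized_float8

weight = torch.randn(128, 256, dtype=torch.bfloat16)
block_size = (1, weight.shape[1])  # per-row scaling, as in float8_weight_only below

# Before this commit:
# qweight = to_affine_quantized_floatx(
#     input_float=weight,
#     block_size=block_size,
#     target_dtype=torch.float8_e4m3fn,
#     scale_dtype=torch.float32,
#     _layout=Float8Layout(mm_config=None),
# )

# After this commit:
qweight = to_affine_quantized_float8(
    input_float=weight,
    block_size=block_size,
    target_dtype=torch.float8_e4m3fn,
    _layout=Float8Layout(mm_config=None),
)
```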
2 changes: 1 addition & 1 deletion docs/source/api_ref_dtypes.rst
@@ -13,7 +13,7 @@ torchao.dtypes
to_nf4
to_affine_quantized_intx
to_affine_quantized_intx_static
to_affine_quantized_floatx
to_affine_quantized_float8
to_affine_quantized_floatx_static
to_affine_quantized_fpx
NF4Tensor
13 changes: 4 additions & 9 deletions torchao/dtypes/__init__.py
@@ -1,16 +1,14 @@
from . import affine_quantized_tensor_ops
from .affine_quantized_tensor import (
AffineQuantizedTensor,
to_affine_quantized_floatx,
to_affine_quantized_float8,
to_affine_quantized_floatx_static,
# experimental, will be merged into floatx in the future
to_affine_quantized_fpx,
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .floatx import (
Float8Layout,
)
from .floatx import Float8Layout
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
BlockSparseLayout,
@@ -24,10 +22,7 @@
UintxLayout,
to_marlinqqq_quantized_intx,
)
from .utils import (
Layout,
PlainLayout,
)
from .utils import Layout, PlainLayout

__all__ = [
"NF4Tensor",
@@ -36,8 +31,8 @@
"to_affine_quantized_intx",
"to_affine_quantized_intx_static",
"to_affine_quantized_fpx",
"to_affine_quantized_floatx",
"to_affine_quantized_floatx_static",
"to_affine_quantized_float8",
"to_marlinqqq_quantized_intx",
"Layout",
"PlainLayout",
13 changes: 6 additions & 7 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -14,7 +14,7 @@
choose_qparams_affine_floatx,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
dequantize_affine_float8,
quantize_affine,
quantize_affine_float8,
quantize_affine_floatx,
@@ -28,8 +28,8 @@
"AffineQuantizedTensor",
"register_layout",
"to_affine_quantized_intx",
"to_affine_quantized_floatx",
"to_affine_quantized_intx_static",
"to_affine_quantized_float8",
"to_affine_quantized_floatx_static",
"to_affine_quantized_fpx",
]
@@ -124,12 +124,13 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
from torchao.dtypes.floatx import FloatxTensorCoreLayout

if isinstance(self._layout, FloatxTensorCoreLayout):
# TODO(danielvegamyhre): I think this dequantize method will be used for both float8 and fpx.
# If there's no way to distinguish which it is here, we need to implement float8 and fpx sublcasses
# of AQT so we have separate dequantization procedures for each.
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
return dequantize_affine_float8(
int_data,
scale,
self._layout.ebits,
self._layout.mbits,
output_dtype=output_dtype,
)
else:
@@ -430,7 +431,6 @@ def from_hp_to_float8(
scale = choose_qparams_affine_float8(
input_float,
target_dtype,
target_dtype,
)
fp8_data = quantize_affine_float8(
input_float,
@@ -500,7 +500,6 @@ def _apply_fn_to_data(self, fn):

to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
# experimental will be merged in to floatx
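
The renamed constructor is bound here to AffineQuantizedTensor.from_hp_to_float8, and dequantize() recovers a high-precision tensor. A rough round-trip sketch under the same assumptions as above (output dtype and the expected error magnitude are illustrative, not guaranteed):

```python
import torch

from torchao.dtypes import Float8Layout, to_affine_quantized_float8

x = torch.randn(32, 64, dtype=torch.bfloat16)
aqt = to_affine_quantized_float8(
    input_float=x,
    block_size=(1, x.shape[1]),          # per-row scaling
    target_dtype=torch.float8_e4m3fn,
    _layout=Float8Layout(mm_config=None),
)
x_dq = aqt.dequantize(output_dtype=torch.bfloat16)  # back to a plain bfloat16 tensor
print((x - x_dq).abs().max())  # small quantization error expected
```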
20 changes: 6 additions & 14 deletions torchao/prototype/quantization/autoquant_v2.py
@@ -27,14 +27,8 @@
from torchao.quantization.autoquant import (
AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV1,
)
from torchao.quantization.granularity import (
PerRow,
PerTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.quantization.subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
@@ -991,7 +985,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(
@classmethod
def from_float(cls, weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_api import _input_activation_quant_func_fp8

# weight settings
@@ -1015,12 +1009,11 @@ def get_per_token_block_size(x):
activation_dtype=input_target_dtype,
)
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
)
weight = super(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls
@@ -1040,7 +1033,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
@classmethod
def from_float(cls, weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_api import _input_activation_quant_func_fp8

# weight settings
@@ -1058,12 +1051,11 @@ def get_weight_block_size(x):
activation_dtype=input_target_dtype,
)
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
)
weight = super(
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
20 changes: 6 additions & 14 deletions torchao/quantization/autoquant.py
@@ -18,10 +18,7 @@
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.quantization.utils import (
compute_error,
quantize_activation_per_token_absmax,
@@ -34,10 +31,7 @@
is_sm_at_least_90,
)

from .granularity import (
PerRow,
PerTensor,
)
from .granularity import PerRow, PerTensor
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
@@ -969,7 +963,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, BFloat16Ten
@classmethod
def from_float(cls, weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_api import _input_activation_quant_func_fp8

# weight settings
@@ -995,12 +989,11 @@ def get_per_token_block_size(x):
}
block_size = get_weight_block_size(weight)

weight = to_affine_quantized_floatx(
weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
)
weight = to_linear_activation_quantized(
weight, input_quant_func, quant_kwargs=input_quant_kwargs
@@ -1025,7 +1018,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
@classmethod
def from_float(cls, weight):
# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
from torchao.dtypes import to_affine_quantized_float8
from torchao.quantization.quant_api import _input_activation_quant_func_fp8

# weight settings
@@ -1043,12 +1036,11 @@ def get_weight_block_size(x):
"activation_dtype": input_target_dtype,
}
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
_layout=_layout,
scale_dtype=torch.float32,
)
weight = super(
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls
35 changes: 12 additions & 23 deletions torchao/quantization/quant_api.py
@@ -36,7 +36,7 @@
SemiSparseLayout,
TensorCoreTiledLayout,
UintxLayout,
to_affine_quantized_floatx,
to_affine_quantized_float8,
to_affine_quantized_floatx_static,
to_affine_quantized_intx,
to_marlinqqq_quantized_intx,
@@ -66,21 +66,13 @@
Int8DynActInt4WeightGPTQQuantizer,
Int8DynActInt4WeightQuantizer,
)
from .granularity import (
PerRow,
PerTensor,
)
from .granularity import PerRow, PerTensor
from .linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
to_linear_activation_quantized,
)
from .qat import (
intx_quantization_aware_training,
)
from .quant_primitives import (
MappingType,
ZeroPointDomain,
)
from .qat import intx_quantization_aware_training
from .quant_primitives import MappingType, ZeroPointDomain
from .subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +907,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
quantization + 2:4 sparsity to linear layers.
"""
warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
warnings.warn(
"""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
from torchao.dtypes import SemiSparseLayout
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
)

return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())

Expand All @@ -934,15 +928,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
The actual matmul will be computed in original precision of the weight tensor.
"""
from torchao.dtypes import to_affine_quantized_floatx

def apply_float8wo_quant(weight):
block_size = (1, weight.shape[1])
return to_affine_quantized_floatx(
return to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=None,
_layout=Float8Layout(mm_config=None),
)

@@ -1016,11 +1008,10 @@ def _input_activation_quant_func_fp8(

block_size = get_block_size(x.shape, activation_granularity)
if scale is None:
activation = to_affine_quantized_floatx(
activation = to_affine_quantized_float8(
input_float=x,
block_size=block_size,
target_dtype=activation_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=None), # Config is stored on weight
)
else:
@@ -1102,11 +1093,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
), "PerRow quantization only works for bfloat16 precision input weight"

block_size = get_block_size(weight.shape, weight_granularity)
quantized_weight = to_affine_quantized_floatx(
quantized_weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
)

@@ -1157,11 +1147,10 @@ def apply_float8_static_activation_quant(weight: torch.Tensor):
if not _fp8_mm_compat(weight):
return weight
block_size = get_block_size(weight.shape, weight_granularity)
quantized_weight = to_affine_quantized_floatx(
quantized_weight = to_affine_quantized_float8(
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
)

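
User-facing configs are unchanged by this commit; only the constructor they call internally was renamed. A sketch of the typical weight-only float8 path, assuming the standard torchao quantize_ entry point and that float8_weight_only is re-exported from torchao.quantization (both outside this diff), on a build and device where float8 is supported:

```python
import torch

from torchao.quantization import float8_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 128)).to(torch.bfloat16)
quantize_(model, float8_weight_only(weight_dtype=torch.float8_e4m3fn))

# The linear weight is now an AffineQuantizedTensor built via
# to_affine_quantized_float8 with a Float8Layout.
out = model(torch.randn(4, 256, dtype=torch.bfloat16))
```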
