replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs

ghstack-source-id: 059b6978da29d45ed55481b0c510231f2ad93303
ghstack-comment-id: 2608105249
Pull Request resolved: #1599
danielvegamyhre committed Jan 22, 2025
1 parent 26a0a50 commit 1b35144
Showing 1 changed file with 11 additions and 20 deletions.
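
A minimal sketch of the new call pattern this commit introduces, before the diff itself: the hunks below swap the generic floatx constructor for the float8-specific one and drop the scale_dtype argument. The tensor shape, dtype, and variable names here are invented for illustration and are not taken from the diff.

import torch
from torchao.dtypes import Float8Layout, to_affine_quantized_float8

weight = torch.randn(128, 256, dtype=torch.bfloat16)  # hypothetical weight tensor
block_size = (1, weight.shape[1])  # per-row blocks, as used by float8_weight_only below
quantized_weight = to_affine_quantized_float8(
    input_float=weight,
    block_size=block_size,
    target_dtype=torch.float8_e4m3fn,
    _layout=Float8Layout(mm_config=None),
)
# The replaced to_affine_quantized_floatx call also took a scale_dtype argument;
# to_affine_quantized_float8 does not, which is why it disappears in every hunk below.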
31 changes: 11 additions & 20 deletions torchao/quantization/quant_api.py
@@ -36,6 +36,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
@@ -66,21 +67,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +908,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())

@@ -934,15 +929,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     The actual matmul will be computed in original precision of the weight tensor.
     """
-    from torchao.dtypes import to_affine_quantized_floatx
 
     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )

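For context, float8_weight_only is normally applied through torchao's quantize_ entry point rather than called directly. A minimal sketch, assuming quantize_ and float8_weight_only are exported from torchao.quantization as in current releases; the toy model is illustrative:

import torch
from torchao.quantization import float8_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to(torch.bfloat16)  # toy model
# Each Linear weight is converted by apply_float8wo_quant, which now calls
# to_affine_quantized_float8 as shown in the hunk above.
quantize_(model, float8_weight_only(weight_dtype=torch.float8_e4m3fn))
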
@@ -1016,11 +1009,10 @@ def _input_activation_quant_func_fp8(
 
     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
            input_float=x,
            block_size=block_size,
            target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
            _layout=Float8Layout(mm_config=None),  # Config is stored on weight
        )
     else:
@@ -1102,11 +1094,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         ), "PerRow quantization only works for bfloat16 precision input weight"
 
         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )

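The last two hunks sit on the float8 dynamic-activation path. A minimal sketch of how that path is typically reached, assuming PerRow, quantize_, and float8_dynamic_activation_float8_weight are exported from torchao.quantization (the toy model and device are illustrative):

import torch
from torchao.quantization import (
    PerRow,
    float8_dynamic_activation_float8_weight,
    quantize_,
)

# bfloat16 weights are required for PerRow granularity, matching the assert in the hunk above.
model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to(torch.bfloat16).cuda()  # toy model
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))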