
Commit

replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs

ghstack-source-id: 75b010de13f2a6627d542965d6a9fa6f60b86bbb
ghstack-comment-id: 2608105249
Pull Request resolved: #1599
danielvegamyhre committed Jan 23, 2025
1 parent 9fecad1 commit e1627f3
Showing 10 changed files with 141 additions and 154 deletions.
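
For orientation, a minimal usage sketch of the renamed constructor. It is based only on the signature visible in this diff (input_float, target_dtype, block_size, _layout) and assumes a torchao build that includes this commit; the tensor shapes and dtypes are illustrative, not taken from the repository.

import torch
from torchao.dtypes import Float8Layout, to_affine_quantized_float8

weight = torch.randn(128, 256, dtype=torch.bfloat16)

# Per-tensor scaling: a block_size spanning the whole tensor satisfies the
# aqt.shape == aqt.block_size condition used by the fp8 dispatch checks below.
fp8_weight = to_affine_quantized_float8(
    weight,
    torch.float8_e4m3fn,      # target_dtype, must be one of FP8_TYPES
    block_size=weight.shape,
    _layout=Float8Layout(),
)

# Round-trip back to high precision via Float8QuantizedTensor.dequantize
restored = fp8_weight.dequantize(torch.bfloat16)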
2 changes: 1 addition & 1 deletion docs/source/api_ref_dtypes.rst
@@ -13,7 +13,7 @@ torchao.dtypes
to_nf4
to_affine_quantized_intx
to_affine_quantized_intx_static
to_affine_quantized_floatx
to_affine_quantized_float8
to_affine_quantized_floatx_static
to_affine_quantized_fpx
NF4Tensor
14 changes: 4 additions & 10 deletions torchao/dtypes/__init__.py
@@ -1,16 +1,13 @@
from . import affine_quantized_tensor_ops
from .affine_quantized_tensor import (
AffineQuantizedTensor,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
# experimental, will be merged into floatx in the future
to_affine_quantized_fpx,
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .floatx import (
Float8Layout,
)
from .float8 import Float8Layout, to_affine_quantized_float8
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
BlockSparseLayout,
@@ -21,13 +18,10 @@
MarlinSparseLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
UintxLayout,
to_marlinqqq_quantized_intx,
UintxLayout,
)
from .utils import (
Layout,
PlainLayout,
)
from .utils import Layout, PlainLayout

__all__ = [
"NF4Tensor",
Expand All @@ -36,8 +30,8 @@
"to_affine_quantized_intx",
"to_affine_quantized_intx_static",
"to_affine_quantized_fpx",
"to_affine_quantized_floatx",
"to_affine_quantized_floatx_static",
"to_affine_quantized_float8",
"to_marlinqqq_quantized_intx",
"Layout",
"PlainLayout",
92 changes: 22 additions & 70 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -10,13 +10,10 @@
MappingType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_float8,
choose_qparams_affine_floatx,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
quantize_affine,
quantize_affine_float8,
quantize_affine_floatx,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
@@ -28,7 +25,6 @@
"AffineQuantizedTensor",
"register_layout",
"to_affine_quantized_intx",
"to_affine_quantized_floatx",
"to_affine_quantized_intx_static",
"to_affine_quantized_floatx_static",
"to_affine_quantized_fpx",
@@ -121,40 +117,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
if output_dtype is None:
output_dtype = self.dtype

from torchao.dtypes.floatx import FloatxTensorCoreLayout

if isinstance(self._layout, FloatxTensorCoreLayout):
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
int_data,
scale,
self._layout.ebits,
self._layout.mbits,
output_dtype=output_dtype,
)
else:
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout

if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq
if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq

def __tensor_flatten__(self):
return ["tensor_impl"], [
@@ -272,7 +256,7 @@ def from_hp_to_intx(
# Note: output will be uint8 tensor for sub byte tensors for now

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
return cls(
tensor_impl,
@@ -417,36 +401,6 @@ def from_hp_to_fpx(
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

@classmethod
def from_hp_to_float8(
cls,
input_float: torch.Tensor,
target_dtype: torch.dtype,
block_size: Tuple[int, ...],
_layout: Layout = PlainLayout(),
):
assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
original_shape = input_float.shape
scale = choose_qparams_affine_float8(
input_float,
target_dtype,
target_dtype,
)
fp8_data = quantize_affine_float8(
input_float,
scale,
target_dtype,
)
fp8_data = _layout.post_process(fp8_data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
return cls(
tensor_impl,
block_size,
original_shape,
dtype=input_float.dtype,
)

@property
def _layout(self) -> Layout:
return self.tensor_impl._layout
@@ -500,9 +454,7 @@ def _apply_fn_to_data(self, fn):

to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
# experimental will be merged in to floatx
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

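The from_hp_to_float8 classmethod removed above reappears on Float8QuantizedTensor further down in this diff; in both versions the pipeline is choose a scale, quantize to fp8, then wrap the result in the layout's tensor impl. A hedged sketch of the equivalent primitive-level calls, using only the functions this commit imports from torchao.quantization.quant_primitives (the input tensor is illustrative):

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(64, 64)

# per-tensor scale for the target fp8 dtype
scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)
# cast to float8 using that scale
x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)
# and back to high precision
x_roundtrip = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)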
12 changes: 4 additions & 8 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -3,10 +3,8 @@
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
)
from torchao.dtypes.floatx.float8_layout import (
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.dtypes.float8.float8_layout import (
_linear_fp8_act_fp8_weight_check,
_linear_fp8_act_fp8_weight_impl,
_linear_fp_act_fp8_weight_check,
@@ -37,11 +35,11 @@
_linear_fp_act_int4_weight_sparse_marlin_impl,
)
from torchao.dtypes.uintx.plain_layout import (
PlainAQTTensorImpl,
_linear_fp_act_int8_weight_check,
_linear_fp_act_int8_weight_impl,
_linear_int8_act_int8_weight_check,
_linear_int8_act_int8_weight_impl,
PlainAQTTensorImpl,
)
from torchao.dtypes.uintx.semi_sparse_layout import (
_linear_int8_act_int8_weight_semi_structured_sparse_check,
@@ -52,9 +50,7 @@
_linear_bf16_act_uint4_weight_impl,
)
from torchao.quantization.quant_primitives import dequantize_affine
from torchao.utils import (
fill_defaults,
)
from torchao.utils import fill_defaults

logger = logging.getLogger(__name__)

20 changes: 20 additions & 0 deletions torchao/dtypes/float8/__init__.py
@@ -0,0 +1,20 @@
from .float8_layout import (
_linear_fp8_act_fp8_weight_check,
_linear_fp8_act_fp8_weight_impl,
_linear_fp_act_fp8_weight_check,
_linear_fp_act_fp8_weight_impl,
Float8Layout,
Float8QuantizedTensor,
to_affine_quantized_float8,
)


__all__ = [
"Float8Layout",
"to_affine_quantized_float8",
"Float8QuantizedTensor",
"_linear_fp8_act_fp8_weight_check",
"_linear_fp8_act_fp8_weight_impl",
"_linear_fp_act_fp8_weight_check",
"_linear_fp_act_fp8_weight_impl",
]
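
The float8 layout and constructor therefore gain a new import path. A short before/after sketch implied by this commit (the old path is shown only for contrast):

# Before this commit, the layout lived under the floatx package:
#   from torchao.dtypes.floatx import Float8Layout
# After this commit:
from torchao.dtypes.float8 import Float8Layout, to_affine_quantized_float8
# The top-level re-exports in torchao/dtypes/__init__.py continue to work:
from torchao.dtypes import Float8Layout, to_affine_quantized_float8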
torchao/dtypes/float8/float8_layout.py (moved from torchao/dtypes/floatx/float8_layout.py)
@@ -11,13 +11,19 @@
AffineQuantizedTensor,
register_layout,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape
from torchao.dtypes.utils import AQTTensorImpl, get_out_shape, Layout
from torchao.float8.inference import (
Float8MMConfig,
_is_rowwise_scaled,
addmm_float8_unwrapped_inference,
Float8MMConfig,
preprocess_data,
)
from torchao.quantization.quant_primitives import (
choose_qparams_affine_float8,
dequantize_affine_float8,
FP8_TYPES,
quantize_affine_float8,
)
from torchao.utils import _is_float8_type, fill_defaults

aten = torch.ops.aten
@@ -209,19 +215,64 @@ def __repr__(self):
)


class Float8QuantizedTensor(AffineQuantizedTensor):
"""
Float8 quantized tensor subclass which inherits AffineQuantizedTensor class.
"""

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
if output_dtype is None:
output_dtype = self.dtype
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_float8(
int_data,
scale,
output_dtype=output_dtype,
)

@classmethod
def from_hp_to_float8(
cls,
input_float: torch.Tensor,
target_dtype: torch.dtype,
block_size: Tuple[int, ...],
_layout: Layout = Float8Layout(),
):
assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
original_shape = input_float.shape
scale = choose_qparams_affine_float8(
input_float,
target_dtype,
)
fp8_data = quantize_affine_float8(
input_float,
scale,
target_dtype,
)
fp8_data = _layout.post_process(fp8_data)
tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
return cls(
tensor_impl,
block_size,
original_shape,
dtype=input_float.dtype,
)


##########################
# Float8 Dispatch Kernels
##########################


def _linear_fp8_act_fp8_weight_check(
input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
bias: Optional[torch.Tensor],
) -> bool:
def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
def check_aqt(aqt: Union[torch.Tensor, Float8QuantizedTensor]) -> bool:
return (
isinstance(aqt, AffineQuantizedTensor)
isinstance(aqt, Float8QuantizedTensor)
and isinstance(aqt._layout, Float8Layout)
and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
@@ -241,8 +292,8 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):


def _linear_fp8_act_fp8_weight_impl(
input_tensor: "AffineQuantizedTensor",
weight_tensor: "AffineQuantizedTensor",
input_tensor: "Float8QuantizedTensor",
weight_tensor: "Float8QuantizedTensor",
bias: Optional[torch.Tensor],
):
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
@@ -285,8 +336,8 @@ def _linear_fp8_act_fp8_weight_impl(


def _linear_fp_act_fp8_weight_check(
input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
bias: Optional[torch.Tensor],
) -> bool:
return (
@@ -295,7 +346,7 @@ def _linear_fp_act_fp8_weight_check(
and input_tensor.is_floating_point()
and
# weight is float8 quantized affine quantized tensor
isinstance(weight_tensor, AffineQuantizedTensor)
isinstance(weight_tensor, Float8QuantizedTensor)
and isinstance(weight_tensor._layout, Float8Layout)
and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
and (
@@ -307,7 +358,10 @@

def _linear_fp_act_fp8_weight_impl(
input_tensor: torch.Tensor,
weight_tensor: "AffineQuantizedTensor",
weight_tensor: "Float8QuantizedTensor",
bias: Optional[torch.Tensor],
):
return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)


to_affine_quantized_float8 = Float8QuantizedTensor.from_hp_to_float8
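
A hedged end-to-end sketch of the weight-only path registered above, assuming a torchao build with this commit: a high-precision activation times a Float8QuantizedTensor weight is handled by _linear_fp_act_fp8_weight_impl, which dequantizes the weight and defers to the regular linear kernel, while the fp8-activation/fp8-weight path instead routes through _scaled_mm. The shapes below are illustrative.

import torch
import torch.nn.functional as F
from torchao.dtypes import Float8Layout, to_affine_quantized_float8

x = torch.randn(16, 256, dtype=torch.bfloat16)
w = torch.randn(128, 256, dtype=torch.bfloat16)

# quantize the weight to float8 with a per-tensor scale
w_fp8 = to_affine_quantized_float8(w, torch.float8_e4m3fn, w.shape, Float8Layout())

# fp activation x fp8 weight: equivalent to what the dispatch does, i.e.
# dequantize the weight and call the regular linear
y = F.linear(x, w_fp8.dequantize(), None)  # shape (16, 128)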
2 changes: 0 additions & 2 deletions torchao/dtypes/floatx/__init__.py
@@ -1,4 +1,3 @@
from .float8_layout import Float8Layout
from .floatx_tensor_core_layout import (
FloatxTensorCoreLayout,
from_scaled_tc_floatx,
@@ -9,5 +8,4 @@
"FloatxTensorCoreLayout",
"to_scaled_tc_floatx",
"from_scaled_tc_floatx",
"Float8Layout",
]
