From 24114cebb3fd77737185b1e30bef050283c51478 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 08:49:11 -0800 Subject: [PATCH 1/6] Update [ghstack-poisoned] --- test/quantization/test_qat.py | 2 +- test/quantization/test_quant_api.py | 26 +++ torchao/core/__init__.py | 0 torchao/core/config.py | 13 ++ torchao/quantization/_transform_module.py | 17 ++ torchao/quantization/qat/api.py | 114 +++++++------ torchao/quantization/quant_api.py | 191 ++++++++++++++-------- 7 files changed, 249 insertions(+), 114 deletions(-) create mode 100644 torchao/core/__init__.py create mode 100644 torchao/core/config.py create mode 100644 torchao/quantization/_transform_module.py diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 8a78b8b387..82324394a8 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) - def test_quantize_api(self): + def test_quantize_api_standalone(self): """ Test that the following: diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 177c357047..ca2cbf08ec 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -40,6 +40,7 @@ Int4WeightOnlyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, ) +from torchao.quantization.utils import compute_error from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -761,6 +762,31 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) + def test_int4_weight_only_numerics(self): + """ + Simple test of e2e int4_weight_only workflow, comparing numerics + to a bfloat16 baseline. + """ + # TODO(before land) skip on cpu-only + # TODO(before land) support other inference techniques? + + # set up inputs + x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 + # is that expected? + m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() + m_int4_wo = copy.deepcopy(m_ref) + + # quantize + quantize_(m_int4_wo, int4_weight_only()) + + with torch.no_grad(): + y_ref = m_ref(x) + y_int4_wo = m_int4_wo(x) + + sqnr = compute_error(y_ref, y_int4_wo) + assert sqnr >= 20, f"SQNR {sqnr} is too low" + class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") diff --git a/torchao/core/__init__.py b/torchao/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/core/config.py b/torchao/core/config.py new file mode 100644 index 0000000000..fbc1216212 --- /dev/null +++ b/torchao/core/config.py @@ -0,0 +1,13 @@ +import abc + + +# directory location for this might need more polish +class AOBaseWorkflowConfig(abc.ABC): + """ + If a workflow config inherits from this then `quantize_` knows + what to do with it. + + TODO write a better docblock. 
+ """ + + pass diff --git a/torchao/quantization/_transform_module.py b/torchao/quantization/_transform_module.py new file mode 100644 index 0000000000..f14e79b5a9 --- /dev/null +++ b/torchao/quantization/_transform_module.py @@ -0,0 +1,17 @@ +from typing import Callable, Dict + +import torch + +from torchao.core.config import AOBaseWorkflowConfig + +_QUANTIZE_CONFIG_HANDLER: Dict[ + AOBaseWorkflowConfig, + Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module], +] = {} + + +def register_quantize_module_handler(config_type): + def decorator(func): + _QUANTIZE_CONFIG_HANDLER[config_type] = func + + return decorator diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index cd3813291f..6356ee1600 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,10 +5,14 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Union +from typing import Any, List, Optional, Union import torch +from torchao.core.config import AOBaseWorkflowConfig +from torchao.quantization._transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.granularity import ( Granularity, PerAxis, @@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any): super().__setattr__(name, value) -def intx_quantization_aware_training( - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, -) -> Callable: +@dataclass +class IntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig): + activation_config: Optional[FakeQuantizeConfig] = None + weight_config: Optional[FakeQuantizeConfig] = None + + +# for BC +intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig + + +@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig) +def _intx_quantization_aware_training_transform( + module: torch.nn.Module, + config: IntXQuantizationAwareTrainingConfig, +) -> torch.nn.Module: """ - Return a function that applies fake quantization to a `torch.nn.Module`. + THIS IS NOT A PUBLIC API - any usage of this outside of torchao + can break at any time. + + Apply fake quantization to a `torch.nn.Module`. to be used with :func:`~torchao.quantization.quant_api.quantize_`. Example usage:: @@ -267,37 +285,32 @@ def intx_quantization_aware_training( `torch.nn.Embedding` with an activation config, then we will raise ValueError as these are not supported. """ - - def _insert_fake_quantize(mod: torch.nn.Module): - """ - Swap the given module with its corresponding fake quantized version. 
- """ - from .embedding import FakeQuantizedEmbedding - from .linear import FakeQuantizedLinear - - if isinstance(mod, torch.nn.Linear): - return FakeQuantizedLinear.from_linear( - mod, - activation_config, - weight_config, - ) - elif isinstance(mod, torch.nn.Embedding): - if activation_config is not None: - raise ValueError( - "Activation fake quantization is not supported for embedding" - ) - return FakeQuantizedEmbedding.from_embedding(mod, weight_config) - else: + from .embedding import FakeQuantizedEmbedding + from .linear import FakeQuantizedLinear + + mod = module + activation_config = config.activation_config + weight_config = config.weight_config + + if isinstance(mod, torch.nn.Linear): + return FakeQuantizedLinear.from_linear( + mod, + activation_config, + weight_config, + ) + elif isinstance(mod, torch.nn.Embedding): + if activation_config is not None: raise ValueError( - "Module of type '%s' does not have QAT support" % type(mod) + "Activation fake quantization is not supported for embedding" ) + return FakeQuantizedEmbedding.from_embedding(mod, weight_config) + else: + raise ValueError("Module of type '%s' does not have QAT support" % type(mod)) - return _insert_fake_quantize - -def from_intx_quantization_aware_training() -> Callable: +class FromIntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig): """ - Return a function that converts a model with fake quantized modules, + Object that knows how to convert a model with fake quantized modules, such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear` and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`, back to model with the original, corresponding modules without @@ -313,22 +326,31 @@ def from_intx_quantization_aware_training() -> Callable: ) """ - def _remove_fake_quantize(mod: torch.nn.Module): - """ - If the given module is a fake quantized module, return the original - corresponding version of the module without fake quantization. - """ - from .embedding import FakeQuantizedEmbedding - from .linear import FakeQuantizedLinear + pass + + +# for BC +from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig - if isinstance(mod, FakeQuantizedLinear): - return mod.to_linear() - elif isinstance(mod, FakeQuantizedEmbedding): - return mod.to_embedding() - else: - return mod - return _remove_fake_quantize +@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig) +def _from_intx_quantization_aware_training_transform( + mod: torch.nn.Module, + config: FromIntXQuantizationAwareTrainingConfig, +) -> torch.nn.Module: + """ + If the given module is a fake quantized module, return the original + corresponding version of the module without fake quantization. 
+ """ + from .embedding import FakeQuantizedEmbedding + from .linear import FakeQuantizedLinear + + if isinstance(mod, FakeQuantizedLinear): + return mod.to_linear() + elif isinstance(mod, FakeQuantizedEmbedding): + return mod.to_embedding() + else: + return mod class ComposableQATQuantizer(TwoStepQuantizer): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b2eff196fd..450563be36 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -18,13 +18,15 @@ import logging import types import warnings -from typing import Callable, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.utils.parametrize as parametrize import torchao +from torchao.core.config import AOBaseWorkflowConfig from torchao.dtypes import ( AffineQuantizedTensor, CutlassInt4PackedLayout, @@ -43,6 +45,10 @@ ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import Float8MMConfig +from torchao.quantization._transform_module import ( + _QUANTIZE_CONFIG_HANDLER, + register_quantize_module_handler, +) from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, ) @@ -117,7 +123,6 @@ "Int8DynActInt4WeightGPTQQuantizer", ] -# update according to the support matrix LAYOUT_TO_ZERO_POINT_DOMAIN = { TensorCoreTiledLayout: [ZeroPointDomain.FLOAT], MarlinSparseLayout: [ZeroPointDomain.INT], @@ -228,6 +233,7 @@ def _replace_with_custom_fn_if_matches_filter( filter_fn, cur_fqn="", device=None, + extra_args: Optional[Tuple[Any, ...]] = (), ) -> None: """ Recursively replaces each child module in `model` with the result of `replacement_fn(child)` @@ -239,6 +245,7 @@ def _replace_with_custom_fn_if_matches_filter( filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace. cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "". device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None. + extra_args (Tuple[Any, ...], optional): optional extra args to pass to `replacement_fn`. 
Returns: None @@ -252,12 +259,17 @@ def _replace_with_custom_fn_if_matches_filter( if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization - model = replacement_fn(model) + model = replacement_fn(model, *extra_args) return model else: for name, child in model.named_children(): new_child = _replace_with_custom_fn_if_matches_filter( - child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device + child, + replacement_fn, + filter_fn, + f"{cur_fqn}{name}.", + device, + extra_args, ) if new_child is not child: setattr(model, name, new_child) @@ -468,7 +480,10 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, - apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], + # apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], + apply_tensor_subclass: Union[ + Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig + ], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[torch.types.Device] = None, @@ -530,12 +545,33 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - _replace_with_custom_fn_if_matches_filter( - model, - apply_tensor_subclass, - _is_linear if filter_fn is None else filter_fn, - device=device, - ) + if isinstance(apply_tensor_subclass, AOBaseWorkflowConfig): + # new behavior + + # make the variable name make sense + config = apply_tensor_subclass + handler = _QUANTIZE_CONFIG_HANDLER[type(config)] + + # for each linear in the model, apply the transform if filtering passes + # key difference from old is that `config_with_transform` is easily + # inspectable + _replace_with_custom_fn_if_matches_filter( + model, + handler, + _is_linear if filter_fn is None else filter_fn, + device=device, + extra_args=(config,), + ) + + else: + # old behavior, for now keep for BC purposes + # TODO(after discussion): flesh the BC story out more + _replace_with_custom_fn_if_matches_filter( + model, + apply_tensor_subclass, + _is_linear if filter_fn is None else filter_fn, + device=device, + ) def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: @@ -684,14 +720,10 @@ def gemlite_uintx_weight_only( return _get_linear_subclass_inserter(apply_fn) -def int4_weight_only( - group_size=128, - layout=TensorCoreTiledLayout(inner_k_tiles=8), - use_hqq=False, - zero_point_domain=None, -): +@dataclass +class Int4WeightOnlyConfig(AOBaseWorkflowConfig): """ - Applies uint4 weight-only asymmetric per-group quantization to linear layers, using + Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel Note: @@ -711,59 +743,84 @@ def int4_weight_only( `zero_point_domain`: data type of zeros points, choices are [None(then the value is determined by the layout), ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE] """ - def apply_int4_weight_only_quant(weight): - if weight.shape[-1] % group_size != 0: - logger.info( - f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" - ) - return weight + group_size: int = 128 + layout: Optional[TensorCoreTiledLayout] = TensorCoreTiledLayout(inner_k_tiles=8) + use_hqq: bool = False + zero_point_domain: Optional[ZeroPointDomain] = None - mapping_type = MappingType.ASYMMETRIC - 
block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] - zero_point_dtype = torch.bfloat16 - - nonlocal zero_point_domain - assert ( - type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() - ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" - if zero_point_domain is None: - # the first value is the default one - zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] - else: - assert ( - zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] - ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" - - # Sparse Marlin only supports symmetric quantization. - # NOTE: If we start having lots of layouts that require different configurations, - # we should consider moving this logic somewhere else. - if isinstance(layout, MarlinSparseLayout): - mapping_type = MappingType.SYMMETRIC - assert ( - group_size == 128 or group_size == weight.shape[-1] - ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" - return to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - _layout=layout, - use_hqq=use_hqq, +# for BC +# TODO maybe change other callsites +int4_weight_only = Int4WeightOnlyConfig + + +@register_quantize_module_handler(Int4WeightOnlyConfig) +def _int4_weight_only_transform( + module: torch.nn.Module, config: Int4WeightOnlyConfig +) -> torch.nn.Module: + # TODO(future PR): perhaps move this logic to a different file, to keep the API + # file clean of implementation details + + # for now, make these local variables to allow the rest of the function + # to be a direct copy-paste + weight = module.weight + group_size = config.group_size + layout = config.layout + use_hqq = config.use_hqq + zero_point_domain = config.zero_point_domain + + if weight.shape[-1] % group_size != 0: + logger.info( + f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" ) + return weight - return _get_linear_subclass_inserter(apply_int4_weight_only_quant) + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] + zero_point_dtype = torch.bfloat16 + + # nonlocal zero_point_domain + assert ( + type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() + ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" + if zero_point_domain is None: + # the first value is the default one + zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] + else: + assert ( + zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] + ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" + + # Sparse Marlin only supports symmetric quantization. + # NOTE: If we start having lots of layouts that require different configurations, + # we should consider moving this logic somewhere else. 
+ if isinstance(layout, MarlinSparseLayout): + mapping_type = MappingType.SYMMETRIC + assert ( + group_size == 128 or group_size == weight.shape[-1] + ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" + + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + _layout=layout, + use_hqq=use_hqq, + ) + module.weight = torch.nn.Parameter(new_weight) + return module def int8_weight_only(group_size=None): From 5b9d876d7ea41db7964278c6b59b27e6b79645fb Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 10:08:28 -0800 Subject: [PATCH 2/6] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 7 +++- test/quantization/test_quant_api.py | 7 ++-- torchao/quantization/quant_api.py | 58 ++++++++-------------------- 3 files changed, 25 insertions(+), 47 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index f08ba7aa72..1b4bf58cf9 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -8,6 +8,7 @@ run_tests, ) +from torchao.core.config import AOBaseWorkflowConfig from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, @@ -15,6 +16,7 @@ int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_weight_only, + quantize_, ) from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.utils import ( @@ -186,7 +188,10 @@ def test_flatten_unflatten(self, device, dtype): apply_quant_list = get_quantization_functions(False, True, device) for apply_quant in apply_quant_list: linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseWorkflowConfig): + quantize_(linear, apply_quant) + else: + ql = apply_quant(linear) lp_tensor = ql.weight tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() tensor_data_dict = { diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index ca2cbf08ec..80536bfac9 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -762,17 +762,16 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_int4_weight_only_numerics(self): """ Simple test of e2e int4_weight_only workflow, comparing numerics to a bfloat16 baseline. """ - # TODO(before land) skip on cpu-only - # TODO(before land) support other inference techniques? - # set up inputs x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) - # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 + # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 # is that expected? 
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() m_int4_wo = copy.deepcopy(m_ref) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 450563be36..e36bc7d8e3 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -262,7 +262,8 @@ def _replace_with_custom_fn_if_matches_filter( model = replacement_fn(model, *extra_args) return model else: - for name, child in model.named_children(): + named_children_list = list(model.named_children()) + for name, child in named_children_list: new_child = _replace_with_custom_fn_if_matches_filter( child, replacement_fn, @@ -480,20 +481,19 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, - # apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], - apply_tensor_subclass: Union[ - Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig + config: Union[ + AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module] ], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[torch.types.Device] = None, ): - """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace + """Convert the weight of linear modules in the model with `config`, model is modified inplace Args: model (torch.nn.Module): input model - apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor) - filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on + config (Union[AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release. + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on the weight of the module set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`. @@ -505,7 +505,7 @@ def quantize_( import torch.nn as nn from torchao import quantize_ - # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to + # quantize with some predefined `config` method that corresponds to # optimized execution paths or kernels (e.g. int4 tinygemm kernel) # also customizable with arguments # currently options are @@ -518,43 +518,13 @@ def quantize_( m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) quantize_(m, int4_weight_only(group_size=32)) - # 2. 
write your own new apply_tensor_subclass - # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor - # on weight - - from torchao.dtypes import to_affine_quantized_intx - - # weight only uint4 asymmetric groupwise quantization - groupsize = 32 - apply_weight_quant = lambda x: to_affine_quantized_intx( - x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6, - zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float") - - def apply_weight_quant_to_linear(linear): - linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False) - return linear - - # apply to modules under block0 submodule - def filter_fn(module: nn.Module, fqn: str) -> bool: - return isinstance(module, nn.Linear) - - m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - quantize_(m, apply_weight_quant_to_linear, filter_fn) - """ if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - if isinstance(apply_tensor_subclass, AOBaseWorkflowConfig): - # new behavior - - # make the variable name make sense - config = apply_tensor_subclass + if isinstance(config, AOBaseWorkflowConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] - # for each linear in the model, apply the transform if filtering passes - # key difference from old is that `config_with_transform` is easily - # inspectable _replace_with_custom_fn_if_matches_filter( model, handler, @@ -564,8 +534,12 @@ def filter_fn(module: nn.Module, fqn: str) -> bool: ) else: - # old behavior, for now keep for BC purposes - # TODO(after discussion): flesh the BC story out more + # old behavior, keep to avoid breaking BC + warnings.warn("""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. 
Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""") + + # make the variable name make sense + apply_tensor_subclass = config + _replace_with_custom_fn_if_matches_filter( model, apply_tensor_subclass, @@ -773,7 +747,7 @@ def _int4_weight_only_transform( logger.info( f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" ) - return weight + return module mapping_type = MappingType.ASYMMETRIC block_size = (1, group_size) From 1cea42fbd49f534c697471f9c35c424768607985 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 10:39:15 -0800 Subject: [PATCH 3/6] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 2 +- torchao/quantization/quant_api.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 1b4bf58cf9..9ef26026e2 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -189,7 +189,7 @@ def test_flatten_unflatten(self, device, dtype): for apply_quant in apply_quant_list: linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) if isinstance(apply_quant, AOBaseWorkflowConfig): - quantize_(linear, apply_quant) + quantize_(linear, apply_quant) else: ql = apply_quant(linear) lp_tensor = ql.weight diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e36bc7d8e3..efda1dbb23 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -481,9 +481,7 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, - config: Union[ - AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module] - ], + config: Union[AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[torch.types.Device] = None, @@ -535,7 +533,9 @@ def quantize_( else: # old behavior, keep to avoid breaking BC - warnings.warn("""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""") + warnings.warn( + """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. 
Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""" + ) # make the variable name make sense apply_tensor_subclass = config From 138883b4f40073517c1a5a71dd87c00d33c87d43 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 12:44:06 -0800 Subject: [PATCH 4/6] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 19 +++++++++++++++---- test/hqq/test_hqq_affine.py | 7 ++++--- torchao/quantization/quant_api.py | 1 + 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 9ef26026e2..671c676e76 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -60,7 +60,8 @@ def get_quantization_functions( ) ) - if do_sparse: + # TODO(before land): revert this back, added due to lack of cuSparseLt in my env + if do_sparse and False: base_functions.append( int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) @@ -78,7 +79,8 @@ def test_tensor_core_layout_transpose(self): t = linear.weight shape = t.shape apply_int4_weight_only_quant = int4_weight_only(group_size=32) - ql = apply_int4_weight_only_quant(linear) + quantize_(linear, apply_int4_weight_only_quant) + ql = linear aqt = ql.weight aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) @@ -97,7 +99,11 @@ def test_tensor_core_layout_transpose(self): ) def test_weights_only(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseWorkflowConfig): + quantize_(linear, apply_quant) + ql = linear + else: + ql = apply_quant(linear) with tempfile.NamedTemporaryFile() as f: torch.save(ql.state_dict(), f) f.seek(0) @@ -173,8 +179,13 @@ def apply_uint6_weight_only_quant(linear): @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): + print(apply_quant) linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseWorkflowConfig): + quantize_(linear, apply_quant) + ql = linear + else: + ql = apply_quant(linear) assert "AffineQuantizedTensor" in str(ql) diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 381886d594..096c9d26ba 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -6,6 +6,7 @@ MappingType, ZeroPointDomain, int4_weight_only, + quantize_, uintx_weight_only, ) from torchao.utils import ( @@ -51,9 +52,9 @@ def _eval_hqq(dtype): ) dummy_linear.weight.data = W if dtype == torch.uint4: - q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)( - dummy_linear - ).weight + config = int4_weight_only(group_size=max(block_size), use_hqq=True) + quantize_(dummy_linear, config) + q_tensor_hqq = dummy_linear.weight else: q_tensor_hqq = uintx_weight_only( dtype, group_size=max(block_size), use_hqq=True diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index efda1dbb23..1c7284a01d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -794,6 +794,7 @@ def _int4_weight_only_transform( use_hqq=use_hqq, ) module.weight = torch.nn.Parameter(new_weight) + module.extra_repr = types.MethodType(_linear_extra_repr, module) return module From ba045ea89316a7a14b92d4849f44e9ff1ad276f5 Mon Sep 
17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 12:56:28 -0800 Subject: [PATCH 5/6] Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 671c676e76..2cb87ab133 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -60,8 +60,7 @@ def get_quantization_functions( ) ) - # TODO(before land): revert this back, added due to lack of cuSparseLt in my env - if do_sparse and False: + if do_sparse: base_functions.append( int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) ) From 94d942606bcea5bad5c36b819d779deaa7c1572b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 22 Jan 2025 15:08:47 -0800 Subject: [PATCH 6/6] Update [ghstack-poisoned] --- torchao/quantization/quant_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 1c7284a01d..3401a42ab7 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -793,7 +793,7 @@ def _int4_weight_only_transform( _layout=layout, use_hqq=use_hqq, ) - module.weight = torch.nn.Parameter(new_weight) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module
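
A minimal usage sketch of the APIs added in this stack, assuming a CUDA environment with these patches applied. MyWorkflowConfig and _my_transform are hypothetical names used only for illustration; the built-in path mirrors the new test_int4_weight_only_numerics test, and the custom path mirrors the Int4WeightOnlyConfig + register_quantize_module_handler pattern introduced above.

    from dataclasses import dataclass

    import torch

    from torchao.core.config import AOBaseWorkflowConfig
    from torchao.quantization import int4_weight_only, quantize_
    from torchao.quantization._transform_module import register_quantize_module_handler

    # 1) built-in workflow: int4_weight_only is now an alias for Int4WeightOnlyConfig,
    #    so quantize_ routes it through the new config-based handler dispatch
    m = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
    quantize_(m, int4_weight_only())

    # 2) custom workflow: a config object plus a registered module transform
    @dataclass
    class MyWorkflowConfig(AOBaseWorkflowConfig):  # hypothetical example config
        scale: float = 1.0

    @register_quantize_module_handler(MyWorkflowConfig)
    def _my_transform(module: torch.nn.Module, config: MyWorkflowConfig) -> torch.nn.Module:
        # rewrap the weight in place; quantize_ looks the handler up from the
        # registry by type(config), so the decorated name is never called directly
        module.weight = torch.nn.Parameter(
            module.weight * config.scale, requires_grad=False
        )
        return module

    m2 = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
    quantize_(m2, MyWorkflowConfig(scale=0.5))

Passing a bare Callable to quantize_ still works for backwards compatibility, but emits the deprecation warning added in PATCH 2/6.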