From dbf1381fb306bff017f90bea67af481d87a11877 Mon Sep 17 00:00:00 2001
From: "Cheng, Penghui"
Date: Sun, 12 Mar 2023 01:16:22 +0800
Subject: [PATCH] Adapt to PyTorch 2.0 version (#627)

Signed-off-by: Cheng, Penghui
---
 neural_compressor/adaptor/pytorch.py   | 21 ++++++++---
 neural_compressor/model/torch_model.py | 37 +++++++++++--------
 neural_compressor/training.py          |  5 ++-
 neural_compressor/utils/pytorch.py     | 26 +++++++------
 neural_compressor/utils/utility.py     |  2 +-
 .../test_adaptor_pytorch_1.x.py        |  6 ++-
 .../test_adaptor_pytorch_2.x.py        |  2 -
 7 files changed, 62 insertions(+), 37 deletions(-)

diff --git a/neural_compressor/adaptor/pytorch.py b/neural_compressor/adaptor/pytorch.py
index 59e7b067443..e195ea06748 100644
--- a/neural_compressor/adaptor/pytorch.py
+++ b/neural_compressor/adaptor/pytorch.py
@@ -1305,6 +1305,11 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
             del op_cfgs['bf16_ops_list']
         gc.collect()
 
+        if self.version.release < Version("2.0.0").release:
+            from torch.quantization.quantize import add_observer_
+        else:
+            from torch.quantization.quantize import _add_observer_ as add_observer_
+
         if self.performance_only:
             q_model = model
         else:
@@ -1329,7 +1334,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                 "by assigning the `.qconfig` attribute directly on submodules.")
 
         if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
-            torch.quantization.add_observer_(q_model._model)
+            add_observer_(q_model._model)
             if q_func is None:
                 iterations = tune_cfg.get('calib_iteration', 1)
                 self.model_calibration(q_model._model,
@@ -1346,10 +1351,10 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
                                        inplace=True,
                                        remove_qconfig=False)
                 _propagate_qconfig(q_model._model, op_cfgs)
-                torch.quantization.add_observer_(q_model._model, self.white_list,
+                add_observer_(q_model._model, self.white_list,
                                                  set(self.q_mapping.values()))
             else:  # pragma: no cover
-                torch.quantization.add_observer_(q_model._model)
+                add_observer_(q_model._model)
             torch.quantization.convert(q_model._model, self.q_mapping, inplace=True)
         # q_func can be created by neural_compressor internal or passed by user. It's critical to
         # distinguish how q_func is passed since neural_compressor built-in functions accept neural_compressor
@@ -1926,7 +1931,10 @@ def _post_eval_hook(self, model, **args):
             None
         """
         from torch.utils.tensorboard import SummaryWriter
-        from torch.quantization import get_observer_dict
+        if self.version.release >= Version("2.0.0").release:
+            from torch.quantization.quantize import _get_observer_dict as get_observer_dict
+        else:
+            from torch.quantization import get_observer_dict
 
         model = model._model
 
@@ -2052,7 +2060,10 @@ def inspect_tensor(self,
         observer_dict = {}
         ret = {}
         if inspect_type == 'activation' or inspect_type == 'all':
-            from torch.quantization import get_observer_dict
+            if self.version.release >= Version("2.0.0").release:
+                from torch.quantization.quantize import _get_observer_dict as get_observer_dict
+            else:
+                from torch.quantization import get_observer_dict
             ret['activation'] = []
             get_observer_dict(new_model._model, observer_dict)
             if iteration_list is None:
diff --git a/neural_compressor/model/torch_model.py b/neural_compressor/model/torch_model.py
index f1cbabfa832..9380e787832 100644
--- a/neural_compressor/model/torch_model.py
+++ b/neural_compressor/model/torch_model.py
@@ -130,9 +130,17 @@ def actual_forward_pre_hook(module, input):
             # intersection update kw arguments
             self.input_args.update(values['kwargs'])
             # update arguments
-            for (single_input, single_arg) in zip(values['input'],
-                    list(self.input_args.keys())[:len(values['input'])]):
-                self.input_args[single_arg] = single_input
+            if "input" in values:
+                for (single_input, single_arg) in \
+                        zip(values['input'], list(self.input_args.keys())[:len(values['input'])]):
+                    self.input_args[single_arg] = single_input
+            elif "args" in values:
+                for (single_input, single_arg) in \
+                        zip(values['args'], list(self.input_args.keys())[:len(values['args'])]):
+                    self.input_args[single_arg] = single_input
+            else:
+                assert False, "there is no input field was found!"
+
         return actual_forward_pre_hook
 
     def framework(self):
@@ -161,7 +169,7 @@ def update_weights(self, tensor_name, new_tensor):
             new_tensor (ndarray): weight value.
         """
         # TODO: copy tensor option to new tensor is better
-        device = next(self._model.parameters()).device 
+        device = next(self._model.parameters()).device
         new_tensor = torch.tensor(new_tensor).float().to(device)
         module_index = '.'.join(tensor_name.split('.')[:-1])
         module = dict(self._model.named_modules())[module_index]
@@ -421,10 +429,10 @@ def export_to_fp32_onnx(
         logger.info(info)
         logger.info("*"*len(info))
 
-    def export_to_bf16_onnx(self, 
-                            save_path='bf16-model.onnx', 
+    def export_to_bf16_onnx(self,
+                            save_path='bf16-model.onnx',
                             example_inputs = torch.rand([1, 1, 1, 1]),
-                            opset_version=14, 
+                            opset_version=14,
                             dynamic_axes={"input": {0: "batch_size"},
                                           "output": {0: "batch_size"}},
                             input_names=None,
@@ -507,7 +515,7 @@ def export_to_int8_onnx(
         dtype='S8S8',
         fp32_model=None,
         calib_dataloader=None,
-        ): 
+        ):
         """Export PyTorch int8 model to ONNX int8 model.
 
         Args:
@@ -537,7 +545,7 @@ def export_to_int8_onnx(
         elif 'U8S8' in dtype:
             activation_type = ortq.QuantType.QUInt8
             weight_type = ortq.QuantType.QInt8
-        else:   # pragma: no cover 
+        else:   # pragma: no cover
             # Gather requires weight type be the same as activation.
             # So U8S8(acitvation|weight) option is not workable for best performance.
             logger.error("Right now, we don't support dtype: {}, \
@@ -558,17 +566,17 @@ def export_to_int8_onnx(
             pytorch_op_types_to_quantize=['Linear', 'Embedding', 'Conv1d', 'Conv2d']
             addition_op_to_quantize = []
 
-        if quant_format == 'QDQ' and opset_version < 13:   # pragma: no cover 
+        if quant_format == 'QDQ' and opset_version < 13:   # pragma: no cover
             opset_version = 13
-            logger.warning("QDQ format requires opset_version >= 13, " + 
-                            "we reset opset_version={} here".format(opset_version))
+            logger.warning("QDQ format requires opset_version >= 13, " +
+                           "we reset opset_version={} here".format(opset_version))
         all_op_types_to_quantize = op_types_to_quantize + addition_op_to_quantize
 
         # pylint: disable=E1101
         fp32_path = save_path + '.tmp' if save_path else 'int8-model.onnx.tmp'
         self.export_to_fp32_onnx(
             save_path=fp32_path,
-            example_inputs = example_inputs,
+            example_inputs=example_inputs,
             opset_version=opset_version,
             dynamic_axes=dynamic_axes,
             input_names=input_names,
@@ -710,7 +718,7 @@ def export(
     ):
         """Export PyTorch model to ONNX model."""
         from neural_compressor.experimental.export import (
-            torch_to_fp32_onnx, 
+            torch_to_fp32_onnx,
             torch_to_int8_onnx
         )
         if conf.dtype == 'int8':
@@ -796,4 +804,3 @@ def save(self, root=None):
 
         if isinstance(self.model, torch.jit._script.RecursiveScriptModule):
             self.model.save(os.path.join(root, "best_model.pt"))
-
diff --git a/neural_compressor/training.py b/neural_compressor/training.py
index b79ab822c4e..76ed30bb797 100644
--- a/neural_compressor/training.py
+++ b/neural_compressor/training.py
@@ -345,14 +345,15 @@ def __init__(self, callbacks_list):
         self.callbacks_list = callbacks_list
 
     def on_train_begin(self, dataloader=None):
-        """Be called before the beginning of epochs."""
+        """Be called before the beginning of training."""
        for callbacks in self.callbacks_list:
             callbacks.on_train_begin(dataloader)
 
     def on_train_end(self):
-        """Be called after the end of epochs."""
+        """Be called after the end of training."""
         for callbacks in self.callbacks_list:
             callbacks.on_train_end()
+        logger.info("Training finished!")
 
     def on_epoch_begin(self, epoch):
         """Be called on the beginning of epochs."""
diff --git a/neural_compressor/utils/pytorch.py b/neural_compressor/utils/pytorch.py
index cd09dadc30b..8bbe0a5dae6 100644
--- a/neural_compressor/utils/pytorch.py
+++ b/neural_compressor/utils/pytorch.py
@@ -23,12 +23,12 @@
 from ..adaptor.torch_utils import util
 from . import logger
 from packaging.version import Version
-from torch.quantization import add_observer_, convert
+from torch.quantization import convert
 import torch
 import torch.quantization as tq
 import yaml
 import os
-import copy
+
 
 yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple',
                                 lambda loader, node: tuple(loader.construct_sequence(node)))
@@ -108,7 +108,7 @@ def _load_int8_orchestration(model, tune_cfg, stat_dict, example_inputs, **kwarg
     from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
     quantized_ops = {op[0]: q_cfgs for op in tune_cfg['quantizable_ops']}
     version = get_torch_version()
-    if version < Version("1.11.0-rc1"):
+    if version.release < Version("1.11.0").release:
         quantized_ops["default_qconfig"] = None
     else:
         from torch.ao.quantization import default_embedding_qat_qconfig
@@ -238,19 +238,19 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
                 op_cfg['activation']['quant_mode'] = approach_quant_mode
 
     if tune_cfg['approach'] != "post_training_dynamic_quant":
-        if version < Version("1.7.0-rc1"):   # pragma: no cover
+        if version.release < Version("1.7.0").release:   # pragma: no cover
             q_mapping = tq.default_mappings.DEFAULT_MODULE_MAPPING
-        elif version < Version("1.8.0-rc1"):   # pragma: no cover
+        elif version.release < Version("1.8.0").release:   # pragma: no cover
             q_mapping = \
                 tq.quantization_mappings.get_static_quant_module_mappings()
         else:
             q_mapping = \
                 tq.quantization_mappings.get_default_static_quant_module_mappings()
     else:
-        if version < Version("1.7.0-rc1"):   # pragma: no cover
+        if version.release < Version("1.7.0").release:   # pragma: no cover
             q_mapping = \
                 tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING
-        elif version < Version("1.8.0-rc1"):   # pragma: no cover
+        elif version.release < Version("1.8.0").release:   # pragma: no cover
             q_mapping = \
                 tq.quantization_mappings.get_dynamic_quant_module_mappings()
         else:
@@ -259,7 +259,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
 
     if tune_cfg['framework'] == "pytorch_fx":   # pragma: no cover
         # For torch.fx approach
-        assert version >= Version("1.8.0-rc1"), \
+        assert version.release >= Version("1.8.0").release, \
               "Please use PyTroch 1.8 or higher version with pytorch_fx backend"
         from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
         if kwargs is None:
@@ -275,7 +275,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
             tmp_model = model
         if tune_cfg['approach'] == "quant_aware_training":
             model.train()
-            if version > Version("1.12.1"):  # pragma: no cover
+            if version.release > Version("1.12.1").release:  # pragma: no cover
                 # pylint: disable=E1123
                 model = prepare_qat_fx(model,
                                        fx_op_cfgs,
@@ -286,7 +286,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
                     fx_op_cfgs,
                     prepare_custom_config_dict=prepare_custom_config_dict)
         else:
-            if version > Version("1.12.1"):  # pragma: no cover
+            if version.release > Version("1.12.1").release:  # pragma: no cover
                 # pylint: disable=E1123
                 model = prepare_fx(model,
                                    fx_op_cfgs,
@@ -296,7 +296,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
                 model = prepare_fx(model,
                                    fx_op_cfgs,
                                    prepare_custom_config_dict=prepare_custom_config_dict)
-        if version > Version("1.12.1"):  # pragma: no cover
+        if version.release > Version("1.12.1").release:  # pragma: no cover
             # pylint: disable=E1123
             model = convert_fx(model,
                                convert_custom_config=convert_custom_config_dict)
@@ -335,6 +335,10 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
                     "passed correct configuration through `qconfig_dict` or "
                     "by assigning the `.qconfig` attribute directly on submodules")
         if tune_cfg['approach'] != "post_training_dynamic_quant":
+            if version.release < Version("2.0.0").release:
+                from torch.quantization.quantize import add_observer_
+            else:
+                from torch.quantization.quantize import _add_observer_ as add_observer_
             add_observer_(model)
         model = convert(model, mapping=q_mapping, inplace=True)
 
diff --git a/neural_compressor/utils/utility.py b/neural_compressor/utils/utility.py
index 996756d90eb..4370d826045 100644
--- a/neural_compressor/utils/utility.py
+++ b/neural_compressor/utils/utility.py
@@ -80,7 +80,7 @@ class LazyImport(object):
 
     def __init__(self, module_name):
         """Init LazyImport object.
-        
+
         Args:
             module_name (string): The name of module imported later
         """
diff --git a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
index da7ac4a42b0..71e411d44cf 100644
--- a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
+++ b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
@@ -796,7 +796,11 @@ def forward(self, x):
         model.model.dequant.qconfig = torch.quantization.default_qconfig
         nc_torch._fallback_quantizable_ops_recursively(
             model.model, '', fallback_ops, op_qcfgs={})
-        torch.quantization.add_observer_(model.model)
+        if PT_VERSION >= Version("2.0.0").release:
+            from torch.quantization.quantize import _add_observer_ as add_observer_
+        else:
+            from torch.quantization.quantize import add_observer_
+        add_observer_(model.model)
         model.model(x)
         torch.quantization.convert(model.model, self.adaptor.q_mapping, inplace=True)
         qy = model.model(x)
diff --git a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2.x.py b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2.x.py
index 6db13c8f237..5a78bb78636 100644
--- a/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2.x.py
+++ b/test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2.x.py
@@ -5,10 +5,8 @@
 import torch
 import torch.nn as nn
 import unittest
-import os
 from neural_compressor import PostTrainingQuantConfig, QuantizationAwareTrainingConfig, set_workspace
 from neural_compressor.data import Datasets, DATALOADERS, DataLoader
-from neural_compressor.experimental.data.datasets.dataset import Datasets
 from neural_compressor import quantization
 from neural_compressor.training import prepare_compression, fit
 from neural_compressor.utils.pytorch import load
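
The change repeated throughout this patch is a version guard around PyTorch's internal quantization helpers: PyTorch 2.0 renamed add_observer_ and get_observer_dict to the underscore-prefixed _add_observer_ and _get_observer_dict in torch.quantization.quantize, so each call site now imports whichever symbol the installed version provides and aliases it back to the old name. The snippet below is a minimal standalone sketch of that pattern, not code from the repository; the TORCH_VERSION helper is invented here for illustration (the patch uses self.version or get_torch_version() instead), while the import paths are the ones the patch itself relies on.

    # Version-gated import of PyTorch's observer helpers (illustrative sketch).
    # PyTorch 2.0 moved these helpers behind underscore-prefixed names, so older
    # and newer installs need different import statements.
    import torch
    from packaging.version import Version

    # Hypothetical helper; strips any local build suffix such as "+cu117".
    TORCH_VERSION = Version(torch.__version__.split("+")[0])

    if TORCH_VERSION.release < Version("2.0.0").release:
        # Names available up to PyTorch 1.13.
        from torch.quantization.quantize import add_observer_
        from torch.quantization import get_observer_dict
    else:
        # PyTorch 2.0 renamed them; alias back so call sites stay unchanged.
        from torch.quantization.quantize import _add_observer_ as add_observer_
        from torch.quantization.quantize import _get_observer_dict as get_observer_dict

Aliasing the new names back to the old identifiers is what lets the call sites, such as add_observer_(q_model._model) and get_observer_dict(new_model._model, observer_dict), remain identical across PyTorch versions and keeps the rest of the adaptor code free of version checks.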