Skip to content

Commit

Permalink
Adapt to PyTorch 2.0 version (#627)
Browse files Browse the repository at this point in the history
Signed-off-by: Cheng, Penghui <[email protected]>
  • Loading branch information
PenghuiCheng authored Mar 11, 2023
1 parent e049d0b commit dbf1381
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 37 deletions.
21 changes: 16 additions & 5 deletions neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,6 +1305,11 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
del op_cfgs['bf16_ops_list']
gc.collect()

if self.version.release < Version("2.0.0").release:
from torch.quantization.quantize import add_observer_
else:
from torch.quantization.quantize import _add_observer_ as add_observer_

if self.performance_only:
q_model = model
else:
Expand All @@ -1329,7 +1334,7 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
"by assigning the `.qconfig` attribute directly on submodules.")

if self.approach in ['post_training_static_quant', 'post_training_auto_quant']:
torch.quantization.add_observer_(q_model._model)
add_observer_(q_model._model)
if q_func is None:
iterations = tune_cfg.get('calib_iteration', 1)
self.model_calibration(q_model._model,
Expand All @@ -1346,10 +1351,10 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
inplace=True,
remove_qconfig=False)
_propagate_qconfig(q_model._model, op_cfgs)
torch.quantization.add_observer_(q_model._model, self.white_list,
add_observer_(q_model._model, self.white_list,
set(self.q_mapping.values()))
else: # pragma: no cover
torch.quantization.add_observer_(q_model._model)
add_observer_(q_model._model)
torch.quantization.convert(q_model._model, self.q_mapping, inplace=True)
# q_func can be created by neural_compressor internal or passed by user. It's critical to
# distinguish how q_func is passed since neural_compressor built-in functions accept neural_compressor
Expand Down Expand Up @@ -1926,7 +1931,10 @@ def _post_eval_hook(self, model, **args):
None
"""
from torch.utils.tensorboard import SummaryWriter
from torch.quantization import get_observer_dict
if self.version.release >= Version("2.0.0").release:
from torch.quantization.quantize import _get_observer_dict as get_observer_dict
else:
from torch.quantization import get_observer_dict

model = model._model

Expand Down Expand Up @@ -2052,7 +2060,10 @@ def inspect_tensor(self,
observer_dict = {}
ret = {}
if inspect_type == 'activation' or inspect_type == 'all':
from torch.quantization import get_observer_dict
if self.version.release >= Version("2.0.0").release:
from torch.quantization.quantize import _get_observer_dict as get_observer_dict
else:
from torch.quantization import get_observer_dict
ret['activation'] = []
get_observer_dict(new_model._model, observer_dict)
if iteration_list is None:
Expand Down
37 changes: 22 additions & 15 deletions neural_compressor/model/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,17 @@ def actual_forward_pre_hook(module, input):
# intersection update kw arguments
self.input_args.update(values['kwargs'])
# update arguments
for (single_input, single_arg) in zip(values['input'],
list(self.input_args.keys())[:len(values['input'])]):
self.input_args[single_arg] = single_input
if "input" in values:
for (single_input, single_arg) in \
zip(values['input'], list(self.input_args.keys())[:len(values['input'])]):
self.input_args[single_arg] = single_input
elif "args" in values:
for (single_input, single_arg) in \
zip(values['args'], list(self.input_args.keys())[:len(values['args'])]):
self.input_args[single_arg] = single_input
else:
assert False, "there is no input field was found!"

return actual_forward_pre_hook

def framework(self):
Expand Down Expand Up @@ -161,7 +169,7 @@ def update_weights(self, tensor_name, new_tensor):
new_tensor (ndarray): weight value.
"""
# TODO: copy tensor option to new tensor is better
device = next(self._model.parameters()).device
device = next(self._model.parameters()).device
new_tensor = torch.tensor(new_tensor).float().to(device)
module_index = '.'.join(tensor_name.split('.')[:-1])
module = dict(self._model.named_modules())[module_index]
Expand Down Expand Up @@ -421,10 +429,10 @@ def export_to_fp32_onnx(
logger.info(info)
logger.info("*"*len(info))

def export_to_bf16_onnx(self,
save_path='bf16-model.onnx',
def export_to_bf16_onnx(self,
save_path='bf16-model.onnx',
example_inputs = torch.rand([1, 1, 1, 1]),
opset_version=14,
opset_version=14,
dynamic_axes={"input": {0: "batch_size"},
"output": {0: "batch_size"}},
input_names=None,
Expand Down Expand Up @@ -507,7 +515,7 @@ def export_to_int8_onnx(
dtype='S8S8',
fp32_model=None,
calib_dataloader=None,
):
):
"""Export PyTorch int8 model to ONNX int8 model.
Args:
Expand Down Expand Up @@ -537,7 +545,7 @@ def export_to_int8_onnx(
elif 'U8S8' in dtype:
activation_type = ortq.QuantType.QUInt8
weight_type = ortq.QuantType.QInt8
else: # pragma: no cover
else: # pragma: no cover
# Gather requires weight type be the same as activation.
# So U8S8(acitvation|weight) option is not workable for best performance.
logger.error("Right now, we don't support dtype: {}, \
Expand All @@ -558,17 +566,17 @@ def export_to_int8_onnx(
pytorch_op_types_to_quantize=['Linear', 'Embedding', 'Conv1d', 'Conv2d']
addition_op_to_quantize = []

if quant_format == 'QDQ' and opset_version < 13: # pragma: no cover
if quant_format == 'QDQ' and opset_version < 13: # pragma: no cover
opset_version = 13
logger.warning("QDQ format requires opset_version >= 13, " +
"we reset opset_version={} here".format(opset_version))
logger.warning("QDQ format requires opset_version >= 13, " +
"we reset opset_version={} here".format(opset_version))
all_op_types_to_quantize = op_types_to_quantize + addition_op_to_quantize

# pylint: disable=E1101
fp32_path = save_path + '.tmp' if save_path else 'int8-model.onnx.tmp'
self.export_to_fp32_onnx(
save_path=fp32_path,
example_inputs = example_inputs,
example_inputs=example_inputs,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
input_names=input_names,
Expand Down Expand Up @@ -710,7 +718,7 @@ def export(
):
"""Export PyTorch model to ONNX model."""
from neural_compressor.experimental.export import (
torch_to_fp32_onnx,
torch_to_fp32_onnx,
torch_to_int8_onnx
)
if conf.dtype == 'int8':
Expand Down Expand Up @@ -796,4 +804,3 @@ def save(self, root=None):

if isinstance(self.model, torch.jit._script.RecursiveScriptModule):
self.model.save(os.path.join(root, "best_model.pt"))

5 changes: 3 additions & 2 deletions neural_compressor/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,15 @@ def __init__(self, callbacks_list):
self.callbacks_list = callbacks_list

def on_train_begin(self, dataloader=None):
"""Be called before the beginning of epochs."""
"""Be called before the beginning of training."""
for callbacks in self.callbacks_list:
callbacks.on_train_begin(dataloader)

def on_train_end(self):
"""Be called after the end of epochs."""
"""Be called after the end of training."""
for callbacks in self.callbacks_list:
callbacks.on_train_end()
logger.info("Training finished!")

def on_epoch_begin(self, epoch):
"""Be called on the beginning of epochs."""
Expand Down
26 changes: 15 additions & 11 deletions neural_compressor/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from ..adaptor.torch_utils import util
from . import logger
from packaging.version import Version
from torch.quantization import add_observer_, convert
from torch.quantization import convert
import torch
import torch.quantization as tq
import yaml
import os
import copy


yaml.SafeLoader.add_constructor('tag:yaml.org,2002:python/tuple',
lambda loader, node: tuple(loader.construct_sequence(node)))
Expand Down Expand Up @@ -108,7 +108,7 @@ def _load_int8_orchestration(model, tune_cfg, stat_dict, example_inputs, **kwarg
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
quantized_ops = {op[0]: q_cfgs for op in tune_cfg['quantizable_ops']}
version = get_torch_version()
if version < Version("1.11.0-rc1"):
if version.release < Version("1.11.0").release:
quantized_ops["default_qconfig"] = None
else:
from torch.ao.quantization import default_embedding_qat_qconfig
Expand Down Expand Up @@ -238,19 +238,19 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
op_cfg['activation']['quant_mode'] = approach_quant_mode

if tune_cfg['approach'] != "post_training_dynamic_quant":
if version < Version("1.7.0-rc1"): # pragma: no cover
if version.release < Version("1.7.0").release: # pragma: no cover
q_mapping = tq.default_mappings.DEFAULT_MODULE_MAPPING
elif version < Version("1.8.0-rc1"): # pragma: no cover
elif version.release < Version("1.8.0").release: # pragma: no cover
q_mapping = \
tq.quantization_mappings.get_static_quant_module_mappings()
else:
q_mapping = \
tq.quantization_mappings.get_default_static_quant_module_mappings()
else:
if version < Version("1.7.0-rc1"): # pragma: no cover
if version.release < Version("1.7.0").release: # pragma: no cover
q_mapping = \
tq.default_mappings.DEFAULT_DYNAMIC_MODULE_MAPPING
elif version < Version("1.8.0-rc1"): # pragma: no cover
elif version.release < Version("1.8.0").release: # pragma: no cover
q_mapping = \
tq.quantization_mappings.get_dynamic_quant_module_mappings()
else:
Expand All @@ -259,7 +259,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):

if tune_cfg['framework'] == "pytorch_fx": # pragma: no cover
# For torch.fx approach
assert version >= Version("1.8.0-rc1"), \
assert version.release >= Version("1.8.0").release, \
"Please use PyTroch 1.8 or higher version with pytorch_fx backend"
from torch.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx
if kwargs is None:
Expand All @@ -275,7 +275,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
tmp_model = model
if tune_cfg['approach'] == "quant_aware_training":
model.train()
if version > Version("1.12.1"): # pragma: no cover
if version.release > Version("1.12.1").release: # pragma: no cover
# pylint: disable=E1123
model = prepare_qat_fx(model,
fx_op_cfgs,
Expand All @@ -286,7 +286,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
fx_op_cfgs,
prepare_custom_config_dict=prepare_custom_config_dict)
else:
if version > Version("1.12.1"): # pragma: no cover
if version.release > Version("1.12.1").release: # pragma: no cover
# pylint: disable=E1123
model = prepare_fx(model,
fx_op_cfgs,
Expand All @@ -296,7 +296,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
model = prepare_fx(model,
fx_op_cfgs,
prepare_custom_config_dict=prepare_custom_config_dict)
if version > Version("1.12.1"): # pragma: no cover
if version.release > Version("1.12.1").release: # pragma: no cover
# pylint: disable=E1123
model = convert_fx(model,
convert_custom_config=convert_custom_config_dict)
Expand Down Expand Up @@ -335,6 +335,10 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")
if tune_cfg['approach'] != "post_training_dynamic_quant":
if version.release < Version("2.0.0").release:
from torch.quantization.quantize import add_observer_
else:
from torch.quantization.quantize import _add_observer_ as add_observer_
add_observer_(model)
model = convert(model, mapping=q_mapping, inplace=True)

Expand Down
2 changes: 1 addition & 1 deletion neural_compressor/utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class LazyImport(object):

def __init__(self, module_name):
"""Init LazyImport object.
Args:
module_name (string): The name of module imported later
"""
Expand Down
6 changes: 5 additions & 1 deletion test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,11 @@ def forward(self, x):
model.model.dequant.qconfig = torch.quantization.default_qconfig
nc_torch._fallback_quantizable_ops_recursively(
model.model, '', fallback_ops, op_qcfgs={})
torch.quantization.add_observer_(model.model)
if PT_VERSION >= Version("2.0.0").release:
from torch.quantization.quantize import _add_observer_ as add_observer_
else:
from torch.quantization.quantize import add_observer_
add_observer_(model.model)
model.model(x)
torch.quantization.convert(model.model, self.adaptor.q_mapping, inplace=True)
qy = model.model(x)
Expand Down
2 changes: 0 additions & 2 deletions test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2.x.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import torch
import torch.nn as nn
import unittest
import os
from neural_compressor import PostTrainingQuantConfig, QuantizationAwareTrainingConfig, set_workspace
from neural_compressor.data import Datasets, DATALOADERS, DataLoader
from neural_compressor.experimental.data.datasets.dataset import Datasets
from neural_compressor import quantization
from neural_compressor.training import prepare_compression, fit
from neural_compressor.utils.pytorch import load
Expand Down

0 comments on commit dbf1381

Please sign in to comment.