Skip to content

Commit

Permalink
fix IPEX prepare-model unsupported deepcopy issue and add unit test (#1174)
Browse files Browse the repository at this point in the history
  • Loading branch information
changwangss authored Aug 23, 2022
1 parent 135e52f commit cc368a8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
13 changes: 9 additions & 4 deletions neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2362,8 +2362,8 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
assert isinstance(model, torch.nn.Module), \
"The model passed in is not the instance of torch.nn.Module"

model_ = copy.deepcopy(model)
if not IPEX_110 and not IPEX_112:
model_ = copy.deepcopy(model)
model_.eval().to(ipex.DEVICE)
try:
init_model = torch.jit.script(model_)
Expand All @@ -2377,10 +2377,15 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
"Fail to convert this model to PyTorch Script model"
)
init_model = model_
elif IPEX_110:
init_model = copy.deepcopy(model)
init_model.eval()
else:
model_.eval()
init_model = model_

if hasattr(model,'save_qconf_summary'):
init_model = ipex.quantization._quantize_utils.copy_prepared_model(model)
else:
init_model = copy.deepcopy(model)
init_model.eval()
# create a quantization config file for intel pytorch extension model
os.makedirs(os.path.dirname(self.ipex_config_path), exist_ok=True)
if not IPEX_110 and not IPEX_112:
Expand Down
32 changes: 25 additions & 7 deletions test/ipex/test_adaptor_ipex.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import torch
import unittest
import os
from neural_compressor.adaptor import FRAMEWORKS
from neural_compressor.model import MODELS
from neural_compressor.adaptor.pytorch import PyTorchVersionMode
import neural_compressor.adaptor.pytorch as nc_torch
from neural_compressor.experimental import Quantization, common
from neural_compressor.conf.config import QuantConf
from neural_compressor.utils.pytorch import load
from neural_compressor.utils.utility import recover
import shutil
import copy
import numpy as np
Expand All @@ -20,6 +15,7 @@
except:
TEST_IPEX = False

torch.manual_seed(9527)
assert TEST_IPEX, "Please install intel extension for pytorch"
# get torch and IPEX version
PT_VERSION = nc_torch.get_torch_version()
Expand Down Expand Up @@ -89,7 +85,7 @@ def test_tuning_ipex(self):
ipex_conf = ipex.quantization.QuantConf(
configure_file="./saved/best_configure.json",
)
q_model = ipex.quantization.convert(model, ipex_conf, torch.randn(1, 3, 224, 224))
q_model = ipex.quantization.convert(model, ipex_conf, torch.ones(1, 3, 224, 224))
from neural_compressor.experimental import Benchmark
evaluator = Benchmark('ipex_yaml.yaml')
evaluator.model = q_model
Expand Down Expand Up @@ -121,7 +117,7 @@ def test_tuning_ipex(self):
nc_model = quantizer.fit()
nc_model.save('./saved')
qconfig = ipex.quantization.default_static_qconfig
prepared_model = ipex.quantization.prepare(model, qconfig, example_inputs=torch.randn(1, 3, 224, 224), inplace=False)
prepared_model = ipex.quantization.prepare(model, qconfig, example_inputs=torch.ones(1, 3, 224, 224), inplace=False)
prepared_model.load_qconf_summary(qconf_summary = "./saved/best_configure.json")
convert_model = ipex.quantization.convert(prepared_model)
from neural_compressor.experimental import Benchmark
Expand All @@ -130,5 +126,27 @@ def test_tuning_ipex(self):
evaluator.b_dataloader = common.DataLoader(dataset)
evaluator.fit('accuracy')

# NOTE(review): leading indentation in this hunk was stripped by the page
# extraction; in the repository these lines are indented as a class method.
def test_tuning_ipex_for_ipex_autotune_func(self):
"""Tune with a model that was already prepared via ipex.quantization.prepare.

Regression test for this commit's fix: a prepared IPEX model exposes
'save_qconf_summary' (see the pytorch.py hunk's hasattr check) and, per
the commit title, cannot be deepcopied, so the adaptor copies it with
ipex.quantization._quantize_utils.copy_prepared_model instead.
"""
from neural_compressor.experimental import Quantization
model = M()
# Prepare the model with IPEX first, then hand the *prepared* model to the
# quantizer — this is the code path that previously hit copy.deepcopy.
qconfig = ipex.quantization.default_static_qconfig
prepared_model = ipex.quantization.prepare(model, qconfig, example_inputs=torch.ones(1, 3, 224, 224), inplace=False)
quantizer = Quantization('ipex_yaml.yaml')
# NOTE(review): 'performance_only' exit policy appears to make tuning exit
# without accuracy-driven retries — confirm against QuantConf semantics.
quantizer.conf.usr_cfg.tuning.exit_policy['performance_only'] = True
dataset = quantizer.dataset('dummy', (100, 3, 224, 224), label=True)
quantizer.model = prepared_model
quantizer.calib_dataloader = common.DataLoader(dataset)
quantizer.eval_dataloader = common.DataLoader(dataset)
nc_model = quantizer.fit()
nc_model.save('./saved')
# Rebuild a fresh prepared model, load the tuned qconf summary back into
# it, then convert and benchmark the resulting quantized model.
qconfig = ipex.quantization.default_static_qconfig
prepared_model = ipex.quantization.prepare(model, qconfig, example_inputs=torch.ones(1, 3, 224, 224), inplace=False)
prepared_model.load_qconf_summary(qconf_summary = "./saved/best_configure.json")
convert_model = ipex.quantization.convert(prepared_model)
from neural_compressor.experimental import Benchmark
evaluator = Benchmark('ipex_yaml.yaml')
evaluator.model = convert_model
evaluator.b_dataloader = common.DataLoader(dataset)
evaluator.fit('accuracy')
if __name__ == "__main__":
unittest.main()

0 comments on commit cc368a8

Please sign in to comment.