Fixed the objective initialization issue (#1062)
Signed-off-by: yiliu30 <[email protected]>
yiliu30 authored Jul 17, 2023
1 parent 1044d8d commit 9d7546f
Showing 5 changed files with 151 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/source/objective.md
@@ -19,7 +19,7 @@ Objective

## Introduction

- In terms of evaluating the status of a specific model during tuning, we should have general objectives. Intel® Neural Compressor Objective supports code-free configuration through a yaml file. With built-in objectives, users can compress models with different objectives easily. In special cases, users can also register their own objective classes.
+ In terms of evaluating the status of a specific model during tuning, we should have general objectives. Intel® Neural Compressor Objective supports code-free configuration through `neural_compressor.config.TuningCriterion`. With built-in objectives, users can compress models with different objectives easily. In special cases, users can also register their own objective classes.

### Single Objective

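As a quick illustration of the updated sentence above, here is a minimal sketch of selecting a built-in objective through `neural_compressor.config.TuningCriterion`; it mirrors the test added later in this commit, and `model`, `calib_dataloader`, and `eval_func` are placeholders supplied by the caller.

# Minimal sketch: pick a built-in objective via TuningCriterion (placeholders assumed).
from neural_compressor.quantization import fit
from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion

tuning_criterion = TuningCriterion(objective="performance")   # or "accuracy", "modelsize", "footprint"
conf = PostTrainingQuantConfig(tuning_criterion=tuning_criterion)

q_model = fit(model=model,                        # placeholder: the FP32 model to quantize
              conf=conf,
              calib_dataloader=calib_dataloader,  # placeholder: calibration dataloader
              eval_func=eval_func)                # placeholder: evaluation function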
22 changes: 21 additions & 1 deletion neural_compressor/config.py
@@ -626,9 +626,29 @@ def objective(self):

@objective.setter
def objective(self, objective):
"""Set objective.
Args:
objective: objective name or list of objective names
Examples:
objective = "performance"
objective = ["performance"]
objective = ["performance", "modelsize"]
objective = {
"objective": ["performance", "modelsize"]
"weight": [0.1, 0.9]
}
"""
if isinstance(objective, list):
for val in objective:
assert _check_value('objective', val, str, ['performance', 'accuracy', 'modelsize', 'footprint'])
self._objective = objective
return

if _check_value('objective', objective, str,
['performance', 'accuracy', 'modelsize', 'footprint']):
-            self._objective = objective
+            self._objective = [objective]
return

if _check_value('objective', objective, dict):
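To make the setter change concrete, a hedged sketch of the input forms its docstring lists, all passed through `TuningCriterion`; after the fix, a bare string is stored internally as a one-element list, and the dict form follows the docstring (its exact validation is not shown in this hunk).

from neural_compressor.config import TuningCriterion

TuningCriterion(objective="performance")                  # single objective as a string
TuningCriterion(objective=["performance"])                # single objective as a list
TuningCriterion(objective=["performance", "modelsize"])   # multiple objectives
TuningCriterion(objective={                               # dict form from the docstring above
    "objective": ["performance", "modelsize"],
    "weight": [0.1, 0.9],
})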
2 changes: 2 additions & 0 deletions neural_compressor/strategy/basic.py
@@ -353,6 +353,8 @@ def next_tune_cfg(self):
for op_tuning_cfg in fallback_sampler:
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
yield op_tuning_cfg
logger.warning(f"[Strategy] All tuning options for the current strategy have been tried.\
If the quantized model does not seem to work well, it might be worth considering other strategies.")

def _initial_dynamic_cfg_based_on_static_cfg(self, op_static_cfg:OpTuningConfig):
op_state = op_static_cfg.get_state()
43 changes: 35 additions & 8 deletions neural_compressor/strategy/strategy.py
@@ -44,7 +44,7 @@
from ..utils import logger
from ..utils.create_obj_from_config import create_eval_func
from ..utils.utility import Statistics, fault_tolerant_file, GLOBAL_STATE, MODE, LazyImport, \
-    DotDict, print_table, get_weights_details, dump_table, print_op_list
+    DotDict, print_table, get_weights_details, dump_table, print_op_list, equal_dicts
from ..utils.weights_details import WeightsDetails
from ..version import __version__

@@ -412,8 +412,8 @@ def traverse(self):
traverse_start_time = time()
for op_tuning_cfg in self.next_tune_cfg():
tuning_start_time = time()
-            tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
            self.trials_count += 1
+            tune_cfg = self._tune_cfg_converter(op_tuning_cfg)
tuning_history = self._find_tuning_history(tune_cfg)
if tuning_history and self.trials_count < self.config.tuning_criterion.max_trials: # pragma: no cover
self.last_tune_result = tuning_history['last_tune_result']
@@ -919,6 +919,8 @@ def _eval_baseline(self):
def _recover_best_qmodel_from_tuning_cfg(self):
"""Recover the best quantized model from tuning config."""
if self.best_tuning_cfg and not self.best_qmodel:
logger.info(f"[Strategy] Recover the {self.best_tuning_cfg.get('trial_number', 'N/A')}-trial\
as the tuning result.")
self.best_qmodel = self.adaptor.quantize(copy.deepcopy(self.best_tuning_cfg), self.model,
self.calib_dataloader, self.q_func)

@@ -1137,6 +1139,7 @@ def _tune_cfg_converter(self, op_tuning_cfg):
tune_cfg['recipe_cfgs'] = tune_cfg.get('recipe_cfgs', {})
# For not tuning recipe, tune cfg use it directly
tune_cfg['recipe_cfgs'].update(self._not_tuning_recipes_values)
tune_cfg['trial_number'] = deepcopy(self.trials_count)
# WA for get the smooth quant args
if 'smooth_quant_args' in self.config.recipes:
tune_cfg['recipe_cfgs']['smooth_quant_args'] = self.config.recipes['smooth_quant_args']
@@ -1275,11 +1278,17 @@ def _set_framework_info(self, q_dataloader, q_func=None):

def _set_objectives(self):
# set objectives
def _use_multi_obj_check(obj):
if isinstance(obj, list):
return len(obj) > 1
elif isinstance(obj, dict):
return len(obj.get('objective', [])) > 1

self.higher_is_better = bool(self.config.accuracy_criterion.higher_is_better)
obj_higher_is_better = None
obj_weight = None
obj = self.config.tuning_criterion.objective
-        use_multi_objs = isinstance(obj, dict)
+        use_multi_objs = _use_multi_obj_check(obj)
self.use_multi_objective = False
if use_multi_objs:
obj_higher_is_better = obj.get('higher_is_better', None)
@@ -1288,7 +1297,7 @@ def _set_objectives(self):
objectives = [i.lower() for i in obj_lst]
self.use_multi_objective = True
else:
-            objectives = [obj.lower()]
+            objectives = [val.lower() for val in obj]

# set metric
self.metric_name = ['Accuracy']
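A small usage sketch of `_use_multi_obj_check` against the objective forms the setter now produces; the helper body is copied from the hunk above, the results follow directly from it, and after the setter fix the `obj` reaching this point is always a list or a dict.

def _use_multi_obj_check(obj):
    if isinstance(obj, list):
        return len(obj) > 1
    elif isinstance(obj, dict):
        return len(obj.get('objective', [])) > 1

_use_multi_obj_check(["performance"])                        # False -> single objective
_use_multi_obj_check(["performance", "modelsize"])           # True  -> multi-objective
_use_multi_obj_check({"objective": ["performance", "modelsize"],
                      "weight": [0.1, 0.9]})                 # True  -> multi-objective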
@@ -1326,7 +1335,7 @@ def _set_objectives(self):
def _same_conf(self, src_conf, dst_conf):
"""Check if the two configs are the same."""
from ..utils.utility import compare_objects
-        return compare_objects(src_conf, dst_conf, {'_options', '_tuning', '_accuracy'})
+        return compare_objects(src_conf, dst_conf, {'_options', '_tuning', '_accuracy', 'trial_number'})

def update_best_op_tuning_cfg(self, op_tuning_cfg):
"""Track and update the best tuning config with correspondence accuracy result.
@@ -1595,13 +1604,31 @@ def stop(self, timeout, trials_count):
header='Tune Result Statistics',
field_names=['Info Type', 'Baseline', 'Tune {} result'.format(self.trials_count), \
'Best tune result']).print_stat()


        # Exit policy
        # 1. not_tuning (performance_only): only quantize the model, without tuning or evaluation.
        # 2. timeout = 0: exit the tuning process once a model that meets the accuracy requirement is found.
        # 3. max_trials: the number of actual trials is less than or equal to max_trials.
        # There are two ways to use max_trials to control the exit policy:
        # 1) timeout = 0: the tuning process exits when actual_trials_count >= max_trials or
        #    a quantized model meets the accuracy requirements.
        # 2) timeout = inf: the tuning process exits when trials_count >= max_trials.
        # Some use cases:
        # 1) End the tuning process once a quantized model meets the accuracy requirements:
        #    timeout = 0 (by default), max_trials = 100 (by default); tuning stops at the first
        #    model that meets the accuracy goal, and max_trials caps the number of trials.
        # 2) Even after finding a model that meets the accuracy goal, we may want to continue the
        #    tuning process for better performance or other objectives:
        #    timeout = 100000, max_trials = 10  # specify a fairly large timeout and use max_trials
        #                                       # to control the exit policy.
        # 3) Only want to try a certain number of trials:
        #    timeout = 100000, max_trials = 3   # only try the first 3 trials
if self._not_tuning:
need_stop = True
elif timeout == 0 and self.best_tune_result:
logger.info("[Strategy] Found a model that meets the accuracy requirements.")
need_stop = True
elif self.trials_count >= self.config.tuning_criterion.max_trials:
logger.info("[Strategy] The number of trials is equal to the maximum trials, ending the tuning process.")
need_stop = True
else:
need_stop = False
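A hedged sketch of the three use cases from the comment block above, expressed as `TuningCriterion` settings; the parameter names match the tests added in this commit, while the concrete numbers are illustrative only.

from neural_compressor.config import TuningCriterion

# 1) Stop as soon as a quantized model meets the accuracy goal
#    (timeout=0 is the default; max_trials still caps the search).
stop_on_first_hit = TuningCriterion(timeout=0, max_trials=100)

# 2) Keep tuning after the accuracy goal is met, e.g. to chase a better
#    performance objective; the large timeout leaves max_trials in charge.
keep_tuning = TuningCriterion(timeout=100000, max_trials=10, objective="performance")

# 3) Only try a fixed number of trials.
first_three_only = TuningCriterion(timeout=100000, max_trials=3)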
@@ -1628,7 +1655,7 @@ def _find_tuning_history(self, tune_cfg):
# some fields in tuning section of config, such as tensorboard, snapshot, resume.
if self._same_conf(tuning_history['cfg'], self.conf):
for history in tuning_history['history']:
-                if history and history['tune_cfg'] == tune_cfg:
+                if history and equal_dicts(history['tune_cfg'], tune_cfg, ignore_keys=['trial_number']):
return tuning_history

return None
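The history lookup above now ignores the bookkeeping key that `_tune_cfg_converter` adds; the snippet below is only an illustrative stand-in for `equal_dicts` as it is called here (dict equality that skips the listed keys), not the library implementation.

def equal_dicts_sketch(d1, d2, ignore_keys=None):
    """Compare two dicts while skipping the given keys (illustrative stand-in)."""
    ignore = set(ignore_keys or [])
    keys1, keys2 = set(d1) - ignore, set(d2) - ignore
    return keys1 == keys2 and all(d1[k] == d2[k] for k in keys1)

cfg_a = {'trial_number': 1, 'calib_sampling_size': 100}
cfg_b = {'trial_number': 2, 'calib_sampling_size': 100}
assert equal_dicts_sketch(cfg_a, cfg_b, ignore_keys=['trial_number'])  # differ only in trial_number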
92 changes: 92 additions & 0 deletions test/strategy/test_quant_level.py
@@ -106,6 +106,26 @@ def build_fake_model():
return graph


def get_torch_demo_model():
import torch
class DemoModel(torch.nn.Module):
def __init__(self):
super(DemoModel, self).__init__()
self.fc1 = torch.nn.Linear(3, 3)
self.fc2 = torch.nn.Linear(3, 3)
self.fc3 = torch.nn.Linear(3, 3)
self.fc4 = torch.nn.Linear(3, 3)
self.fc5 = torch.nn.Linear(3, 3)

def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.fc4(x)
x = self.fc5(x)
return x
return DemoModel()

class TestQuantLevel(unittest.TestCase):

@classmethod
@@ -324,6 +344,78 @@ def _fake_eval(model):
eval_func=_fake_eval)
self.assertIsNone(q_model)

def test_pt_quant_level_1_with_perf_obj(self):
logger.info("*** Test: quantization level 1 with perf obj [pytorch model].")
from neural_compressor.quantization import fit
from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
from neural_compressor.data import Datasets, DATALOADERS
import time

# model
model = get_torch_demo_model()

# fake evaluation function
acc_lst = [2.0, 1.0, 2.1, 2.2, 2.3, 2.1, 2.1, 2.2]
perf_lst = [2.0, 1.5, 1.0, 0.5, 0.1, 1.0, 1.0, 1.0]
self._internal_index = -1
def _fake_eval(model):
self._internal_index += 1
perf = perf_lst[self._internal_index]
time.sleep(perf)
return acc_lst[self._internal_index]

# dataset and dataloader
dataset = Datasets("pytorch")["dummy"](((16, 2, 3)))
dataloader = DATALOADERS["pytorch"](dataset)

tuning_criterion = TuningCriterion(timeout=10000, max_trials=6, objective='performance')
conf = PostTrainingQuantConfig(quant_level=1, tuning_criterion=tuning_criterion)

# fit
q_model = fit(model=model,
conf=conf,
calib_dataloader= dataloader,
eval_dataloader=dataloader,
eval_func=_fake_eval)
self.assertIsNotNone(q_model)
self.assertEqual(q_model.q_config.get('trial_number', -1), 4)

def test_pt_quant_level_1_with_perf_obj2(self):
logger.info("*** Test: quantization level 1 with perf obj [pytorch model].")
from neural_compressor.quantization import fit
from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion
from neural_compressor.data import Datasets, DATALOADERS
import time

# model
model = get_torch_demo_model()

# fake evaluation function
acc_lst = [2.0, 1.0, 2.1, 2.2, 2.3, 2.1, 2.1, 2.2]
perf_lst = [2.0, 1.5, 1.0, 0.5, 0.1, 1.0, 1.0, 1.0]
self._internal_index = -1
def _fake_eval(model):
self._internal_index += 1
perf = perf_lst[self._internal_index]
time.sleep(perf)
return acc_lst[self._internal_index]

# dataset and dataloader
dataset = Datasets("pytorch")["dummy"](((16, 2, 3)))
dataloader = DATALOADERS["pytorch"](dataset)

tuning_criterion = TuningCriterion(timeout=10000, max_trials=6, objective=['performance'])
conf = PostTrainingQuantConfig(quant_level=1, tuning_criterion=tuning_criterion)

# fit
q_model = fit(model=model,
conf=conf,
calib_dataloader= dataloader,
eval_dataloader=dataloader,
eval_func=_fake_eval)
self.assertIsNotNone(q_model)
self.assertEqual(q_model.q_config.get('trial_number', -1), 4)

def test_pt_quant_level_0(self):
logger.info("*** Test: quantization level 0 with pytorch model.")
from neural_compressor.quantization import fit
