Skip to content

Commit

Permalink
Tuning recipe (#570)
Browse files Browse the repository at this point in the history
Signed-off-by: yiliu30 <[email protected]>
  • Loading branch information
yiliu30 authored Feb 26, 2023
1 parent 108c245 commit 44d1761
Show file tree
Hide file tree
Showing 13 changed files with 435 additions and 105 deletions.
52 changes: 34 additions & 18 deletions neural_compressor/adaptor/onnxrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,10 @@ def query_fw_capability(self, model):
"""
# optype_wise and op_wise capability
self._pre_optimize(model)
recipes_ops = {}
recipes_ops['first_conv_or_matmul_quantization'] = []
recipes_ops['last_conv_or_matmul_quantization'] = []
recipes_ops['pre_post_process_quantization'] = []
exclude_first_quantizable_op = True if 'first_conv_or_matmul_quantization' in \
self.recipes and not self.recipes['first_conv_or_matmul_quantization'] \
else False
Expand Down Expand Up @@ -982,17 +986,24 @@ def query_fw_capability(self, model):
all_conv_matmul = []
for _, node in enumerate(self.pre_optimized_model.nodes()):
if node.op_type in ['Conv', 'MatMul']:
if len(first_quantizable_node) == 0:
recipes_ops['first_conv_or_matmul_quantization'] = [(node.name, node.op_type)]

# get first Conv or MatMul node
if exclude_first_quantizable_op:
if len(first_quantizable_node) == 0:
first_quantizable_node.append(node.name)

# get last Conv or MatMul node
if exclude_last_quantizable_op:
if len(last_quantizable_node) != 0:
last_quantizable_node.pop()
last_quantizable_node.append(node.name)

if len(recipes_ops['last_conv_or_matmul_quantization']):
recipes_ops['last_conv_or_matmul_quantization'].pop()
recipes_ops['last_conv_or_matmul_quantization'].append((node.name, node.op_type))

# get first and last Conv or MatMul node
if exclude_pre_post_process:
if len(first_quantizable_node) == 0:
Expand Down Expand Up @@ -1021,22 +1032,27 @@ def query_fw_capability(self, model):
op_wise.update(
{(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])})

# get backbone nodes
from collections import deque

# get nodes between first quantizable node and last quantizable node
backbone_queue = deque(last_quantizable_node)
backbone_nodes = self.pre_optimized_model.get_nodes_chain(backbone_queue, first_quantizable_node)

# get extra Conv or MatMul nodes not between first quantizable node and last quantizable node
backbone_queue_extra = deque()
for conv_or_matmul in all_conv_matmul:
if conv_or_matmul.name not in backbone_nodes:
backbone_queue_extra.append(conv_or_matmul.name)
backbone_nodes = self.pre_optimized_model.get_nodes_chain(backbone_queue_extra,
first_quantizable_node, backbone_nodes)
backbone_nodes += [i for i in first_quantizable_node]

for _, node in enumerate(self.pre_optimized_model.nodes()):
if node.name not in backbone_nodes:
recipes_ops['pre_post_process_quantization'].append((node.name, node.op_type))

if exclude_pre_post_process:
from collections import deque

# get nodes between first quantizable node and last quantizable node
backbone_queue = deque(last_quantizable_node)
backbone_nodes = self.pre_optimized_model.get_nodes_chain(backbone_queue, first_quantizable_node)

# get extra Conv or MatMul nodes not between first quantizable node and last quantizable node
backbone_queue_extra = deque()
for conv_or_matmul in all_conv_matmul:
if conv_or_matmul.name not in backbone_nodes:
backbone_queue_extra.append(conv_or_matmul.name)
backbone_nodes = self.pre_optimized_model.get_nodes_chain(backbone_queue_extra,
first_quantizable_node, backbone_nodes)
backbone_nodes += [i for i in first_quantizable_node]

for _, node in enumerate(self.pre_optimized_model.nodes()):
if node.op_type in optype_wise:
# nodes not in backbone are not quantized
Expand All @@ -1051,8 +1067,8 @@ def query_fw_capability(self, model):
else: # pragma: no cover
op_wise.update(
{(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])})

return {'optypewise': optype_wise, 'opwise': op_wise}
return {'optypewise': optype_wise, 'opwise': op_wise, 'recipes_ops': recipes_ops}

def _optypewise_filter_for_qdq(self, optype_wise):
"""Filter optypes that don't support per_channel in QDQ format.
Expand Down
10 changes: 10 additions & 0 deletions neural_compressor/adaptor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ def _query_quantizable_ops(self, matched_nodes):
other_config = copy.deepcopy(op_capability['default'])

self.quantizable_op_details = OrderedDict()
self.recipes_ops = {}

self._init_op_stat = {i: [] for i in tf_quantizable_op_type}

Expand All @@ -747,8 +748,16 @@ def _query_quantizable_ops(self, matched_nodes):
'sequence': [[','.join(patterns[:pat_length - i]) for i in range(pat_length)][0]],
'precision': ['int8']
}
first_conv_or_matmul_node = []
if node_op in tf_quantizable_op_type and node_name not in self.exclude_node_names and (
node_name, self.unify_op_type_mapping[node_op]) not in self.quantizable_op_details:
if (self.unify_op_type_mapping[node_op].find("conv2d") != -1 or \
self.unify_op_type_mapping[node_op].find("matmul") != -1) and \
len(first_conv_or_matmul_node) == 0:
first_conv_or_matmul_node.append((node_name, \
self.unify_op_type_mapping[node_op]))
self.recipes_ops['first_conv_or_matmul_quantization'] = \
first_conv_or_matmul_node
if exclude_first_quantizable_op and \
(self.unify_op_type_mapping[node_op].find("conv2d") != -1 or \
self.unify_op_type_mapping[node_op].find("matmul") != -1):
Expand Down Expand Up @@ -903,6 +912,7 @@ def check_match(patterns, input_pattern):
self._query_bf16_ops(matched_bf16_nodes)
capability = {
'optypewise': self.get_optype_wise_ability(),
'recipes_ops': self.recipes_ops
}
capability['opwise'] = copy.deepcopy(self.quantizable_op_details)
capability['opwise'].update(self.bf16_op_details)
Expand Down
54 changes: 37 additions & 17 deletions neural_compressor/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from abc import abstractmethod
from neural_compressor.utils.create_obj_from_config import get_algorithm

# {location: {algorithm_type: cls}}
registry_algorithms = {}

def algorithm_registry(algorithm_type, location):
Expand All @@ -34,12 +35,12 @@ def algorithm_registry(algorithm_type, location):
cls: The class of register.
"""
def decorator_algorithm(cls):
if algorithm_type in registry_algorithms and location in registry_algorithms[algorithm_type]:
if location in registry_algorithms and algorithm_type in registry_algorithms[location]:
raise ValueError('Cannot have two algorithms with the same name')

if algorithm_type not in registry_algorithms:
registry_algorithms[algorithm_type] = {}
registry_algorithms[algorithm_type][location] = cls()
if location not in registry_algorithms:
registry_algorithms[location] = {}
registry_algorithms[location][algorithm_type] = cls()
return cls
return decorator_algorithm

Expand All @@ -57,9 +58,14 @@ def __getitem__(self, algorithm_type):
Returns:
cls (class): The class of algorithm.
"""
assert algorithm_type in self.algorithms, "algorithm type only support {}".\
format(self.algorithms.keys())
return self.algorithms[algorithm_type]
result = None
for location in self.algorithms:
for key in self.algorithms[location]:
if key == algorithm_type:
result = self.algorithms[location][key]
assert result, "algorithm type only support {}".format(self.support_algorithms())
return result


@classmethod
def support_algorithms(self):
Expand All @@ -68,7 +74,8 @@ def support_algorithms(self):
Returns:
Set: A set of all algorithms.
"""
return set(self.algorithms.keys())
supported_algos = set([self.algorithms[key] for key in self.algorithms])
return supported_algos

class AlgorithmScheduler(object):
"""control the Algorithm in different phase."""
Expand All @@ -79,12 +86,26 @@ def __init__(self, conf):
Args:
conf (dict): Configuration of algorithm.
"""
self.algorithms = get_algorithm(ALGORITHMS, conf)
self._exec_algorithms = {}
self._origin_model = None
self._q_model = None
self._dataloader = None
self._adaptor = None
self._calib_iter = None

def append_algorithm(self, location, algorithm):
"""Append algorithm to list of executed algorithms.
Args:
location: The location to call algorithm
algorithm: algorithm instance
"""
self._exec_algorithms[location] = self._exec_algorithms.get(location, [])
self._exec_algorithms[location].append(algorithm)

def reset_exec_algorithms(self):
"""Reset the list of executed algorithms."""
self._exec_algorithms = {}

def __call__(self, location):
"""Return the processed model via algorithm.
Expand All @@ -93,19 +114,18 @@ def __call__(self, location):
model: The framework model.
"""
assert self._q_model, 'set q_model for algorithm'
if len(self.algorithms) == 0:
if len(self._exec_algorithms.get(location, [])) == 0:
return self._q_model
assert self._origin_model, 'set origin model for algorithm'
assert self._dataloader, 'set dataloader for algorithm'
assert self._adaptor, 'set adaptor for algorithm'
assert self._calib_iter, 'set calibration iteration for algorithm'
for algo in self.algorithms:
if location in algo:
self._q_model = algo[location](self._origin_model,
self._q_model, \
self._adaptor, \
self._dataloader, \
self._calib_iter)
for algo in self._exec_algorithms.get(location, []):
self._q_model = algo(self._origin_model,
self._q_model, \
self._adaptor, \
self._dataloader, \
self._calib_iter)
return self._q_model

@property
Expand Down
20 changes: 18 additions & 2 deletions neural_compressor/strategy/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,25 @@ def next_tune_cfg(self):
stage1_max = 1e9 # TODO set a more appropriate value
op_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [],
op_item_dtype_dict, initial_op_tuning_cfg)
for op_tuning_cfg in op_wise_tuning_sampler:
for index, op_tuning_cfg in enumerate(op_wise_tuning_sampler):
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
# Apply all recipes, if not got the qmodel that meet the requirements, discard it.
if index == 1 and not self.applied_all_recipes_flag:
logger.info("Apply all recipes.")
self.applied_all_recipes_flag = True
yield self.apply_all_tuning_recipes(deepcopy(self.cur_best_tuning_cfg))
stage1_cnt += 1
if early_stop_tuning and stage1_cnt > stage1_max:
logger.info("Early stopping the stage 1.")
break
op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
yield op_tuning_cfg

# Apply all recipes, if not got the qmodel that meet the requirements, discard it.
if stage1_cnt == 1 and not self.applied_all_recipes_flag:
logger.info("Apply all recipes.")
self.applied_all_recipes_flag = True
yield self.apply_all_tuning_recipes(deepcopy(self.cur_best_tuning_cfg))

# Fallback the ops supported both static and dynamic from static to dynamic
# Tuning items: None
if self.cfg.quantization.approach == 'post_training_auto_quant':
Expand All @@ -213,6 +225,10 @@ def next_tune_cfg(self):
new_op_tuning_cfg[item.name])
new_op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
yield new_op_tuning_cfg

logger.info("Apply recipe one by one.")
for tune_cfg in self.apply_recipe_one_by_one(deepcopy(self.cur_best_tuning_cfg)):
yield tune_cfg
best_op_tuning_cfg_stage1 = deepcopy(self.cur_best_tuning_cfg)

# Fallback
Expand Down
29 changes: 13 additions & 16 deletions neural_compressor/strategy/conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .utils.tuning_space import TuningItem
from ..utils import logger
from ..utils.utility import Statistics
from ..algorithm import AlgorithmScheduler

@strategy_registry
class ConservativeTuneStrategy(TuneStrategy):
Expand Down Expand Up @@ -115,23 +116,19 @@ def traverse(self):
logger.debug("Dump current tuning configuration:")
logger.debug(tune_cfg)
self.tuning_times += 1
self.algo.calib_iter = tune_cfg['calib_iteration']
# TODO align the api to let strategy has access to pre_optimized model
assert self.adaptor.pre_optimized_model
self.algo.origin_model = self.adaptor.pre_optimized_model
if self.cfg.quantization.recipes.smooth_quant:
try:
self.algo.alpha = self.cfg.quantization.recipes.smooth_quant_args.get("alpha", 0.5)
except:
self.algo.alpha = 0.5
self.algo.tune_cfg = copy.deepcopy(tune_cfg)
self.algo.q_model = self.model
self.model = self.algo('pre_quantization')
# set the parameter for pre quantization algos and run
self.set_param_for_pre_quantization_algos(self.algo_scheduler, tune_cfg, self.model)
self.model = self.algo_scheduler('pre_quantization')
# quantize
q_model = self.adaptor.quantize(copy.deepcopy(tune_cfg), self.model, self.calib_dataloader, self.q_func)
self.algo.q_model = q_model
if self.cfg.quantization.recipes.fast_bias_correction:
self.algo.algorithms[0].quantization_cfg = tune_cfg
self.last_qmodel = self.algo('post_quantization')
assert self.adaptor.pre_optimized_model
# set the parameter for post quantization algos and run
self.set_param_for_post_quantization_algos(self.algo_scheduler, tune_cfg, self.adaptor.pre_optimized_model,
q_model)
self.last_qmodel = self.algo_scheduler('post_quantization')
self.last_tune_cfg = copy.deepcopy(tune_cfg)
# Remove the reference to model
self.algo_scheduler.reset_exec_algorithms()
assert self.last_qmodel
# Return the last quantized model as a result. if performance only.
if self.cfg.tuning.exit_policy.performance_only:
Expand Down
Loading

0 comments on commit 44d1761

Please sign in to comment.