Support sq auto tune for ort (#847)
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 authored May 30, 2023
1 parent ab20376 commit 1e1d706
Showing 7 changed files with 741 additions and 407 deletions.
96 changes: 17 additions & 79 deletions neural_compressor/adaptor/onnxrt.py
@@ -152,95 +152,33 @@ def __init__(self, framework_specific_info):
 
         self.optype_statistics = None
 
-    def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, percentile=99.999,
-                     op_types=['FusedConv', 'MatMul', 'Linear', 'Conv'], scales_per_op=True, **kwargs):
+    def smooth_quant(self, model, dataloader, iterations, tune_cfg, alpha=0.5, folding=True,
+                     percentile=99.999, op_types=['MatMul', 'Gemm', 'Conv', 'FusedConv'], scales_per_op=True):
         """Get augmented model with smooth quant.
         Args:
-            model_wrapper: origin_model
-            dataloader: dataloader
-            iterations: iterations
-            tune_cfg: quantization config
-            alpha: smooth alpha in SmoothQuant, 1.0 will fall back to SPIQ
-            percentile: percentile of calibration to remove outliers
-            op_types: The op types whose input tensor will be dumped
-            scales_per_op: True, each op will have an individual scale, mainly for accuracy
-                           False, ops with the same input will share a scale, mainly for performance
+            model_wrapper (object): origin_model
+            dataloader (object): dataloader
+            iterations (int): iterations
+            tune_cfg (dict): quantization config
+            alpha (float or str): smooth alpha in SmoothQuant, 1.0 will fall back to SPIQ
+            folding (bool): whether to fold those foldable Mul ops which are inserted for SmoothQuant
+            percentile (float): percentile of calibration to remove outliers
+            op_types (list): The op types whose input tensor will be dumped
+            scales_per_op (bool): True, each op will have an individual scale, mainly for accuracy
+                                  False, ops with the same input will share a scale, mainly for performance
         Returns:
             model: A modified onnx model
         """
         if self.smooth_quant_model is not None:
             return self.smooth_quant_model
 
-        from onnx import numpy_helper
-        from neural_compressor.adaptor.ox_utils.calibration import ONNXRTAugment
-        from neural_compressor.adaptor.ox_utils.util import fold_scale
-        if isinstance(alpha, str):
-            logger.warning("onnx backend only supports float alpha, reset alpha to 0.5")
-            alpha = 0.5
-        black_nodes = []
-        white_nodes = []
-        quantize_config = None
-        if tune_cfg is not None:
-            quantize_config = self._cfg_to_quantize_config(tune_cfg)
-            black_nodes = [node for node in quantize_config if quantize_config[node] == 'fp32']
-            white_nodes = [node for node in quantize_config if quantize_config[node] != 'fp32']
-
-        augment = ONNXRTAugment(self.pre_optimized_model,
-                                dataloader, self.quantizable_op_types,
-                                black_nodes=black_nodes, white_nodes=white_nodes,
-                                iterations=list(range(0, iterations)),
-                                backend=self.backend, reduce_range=self.reduce_range)
-
-        max_vals_per_channel, shape_infos = augment.calib_smooth(percentile, op_types, quantize_config)
-
-        input_tensors_2_weights = {}
-        input_tensors_2_weights_nodes = {}
-        for name in max_vals_per_channel.keys():
-            curr_tensor_to_weight = []
-            curr_tensor_to_weight_nodes = []
-            nodes = [i for i in self.pre_optimized_model.nodes() if name in i.input]
-            for node in nodes:
-                if node.op_type not in op_types:
-                    continue
-                if len(node.input) >= 2:
-                    input = node.input[1]  # TODO: always dump index 1 to get the weight
-                    if self.pre_optimized_model.get_initializer(input):
-                        weight = numpy_helper.to_array(self.pre_optimized_model.get_initializer(input),
-                            os.path.dirname(self.pre_optimized_model.model_path)) if \
-                            self.pre_optimized_model.model_path is not None else \
-                            numpy_helper.to_array(self.pre_optimized_model.get_initializer(input))
-                        curr_tensor_to_weight.append(weight)
-                        curr_tensor_to_weight_nodes.append(node)
-            input_tensors_2_weights[name] = curr_tensor_to_weight
-            input_tensors_2_weights_nodes[name] = curr_tensor_to_weight_nodes
-
-        if scales_per_op:
-            from neural_compressor.adaptor.ox_utils.util import get_smooth_scales_per_op, \
-                insert_smooth_mul_op_per_op, adjust_weights_per_op
-            scales = get_smooth_scales_per_op(max_vals_per_channel, input_tensors_2_weights,
-                                              input_tensors_2_weights_nodes, alpha)
-            new_added_mul_nodes, new_init_tensors, op_nodes = insert_smooth_mul_op_per_op(
-                scales, shape_infos, input_tensors_2_weights_nodes)
-            adjust_weights_per_op(self.pre_optimized_model, op_nodes, scales)
-        else:
-            from neural_compressor.adaptor.ox_utils.util import get_smooth_scales_per_input, \
-                insert_smooth_mul_op_per_input, adjust_weights_per_input
-            scales = get_smooth_scales_per_input(max_vals_per_channel, input_tensors_2_weights, alpha)
-            new_added_mul_nodes, new_init_tensors = insert_smooth_mul_op_per_input(
-                scales, shape_infos, input_tensors_2_weights_nodes)
-            adjust_weights_per_input(self.pre_optimized_model, input_tensors_2_weights_nodes, scales)
-
-        self.pre_optimized_model.add_nodes(new_added_mul_nodes)
-        self.pre_optimized_model.add_initializers(new_init_tensors)
-        self.pre_optimized_model.update()
-        self.pre_optimized_model.topological_sort()
-        self.pre_optimized_model.remove_unused_constant()
-
-        fold_scale(self.pre_optimized_model, scales)
-
-        self.smooth_quant_model = self.pre_optimized_model
+        from .ox_utils.smooth_quant import ORTSmoothQuant
+        quantize_config = self._cfg_to_quantize_config(tune_cfg) if tune_cfg is not None else None
+        sq = ORTSmoothQuant(self.pre_optimized_model, dataloader, self.reduce_range, self.backend)
+        self.smooth_quant_model = sq.transform(
+            alpha, folding, percentile, op_types, scales_per_op, iterations, quantize_config)
         return self.smooth_quant_model
 
     @dump_elapsed_time("Pass quantize model")
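The new implementation delegates the whole SmoothQuant flow (calibration, scale computation, Mul insertion, and optional folding) to the ORTSmoothQuant helper. A minimal sketch of the new call path, assuming a wrapped fp32 ONNX model `fp32_model`, a calibration dataloader `calib_dataloader`, 100 calibration iterations, reduce_range disabled, and the CPU execution provider (these names and values are illustrative, not part of this commit):

    from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant

    # Constructor arguments mirror the call in the diff above:
    # model, dataloader, reduce_range, backend
    sq = ORTSmoothQuant(fp32_model, calib_dataloader, False, 'CPUExecutionProvider')

    # transform() is called positionally in the diff:
    # alpha, folding, percentile, op_types, scales_per_op, iterations, quantize_config
    smoothed_model = sq.transform(0.5, True, 99.999,
                                  ['MatMul', 'Gemm', 'Conv', 'FusedConv'],
                                  True, 100, None)

Since transform() takes alpha as a plain argument, a tuning loop can re-invoke it with different alpha values, which is presumably what the "sq auto tune" in the commit title builds on.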
26 changes: 14 additions & 12 deletions neural_compressor/adaptor/ox_utils/calibration.py
@@ -607,12 +607,12 @@ def _check_is_group_conv(self, node, model):
         weight_name = node.input[1]
         weight_shape = numpy_helper.to_array(
             model.graph.initializer[name_to_indices[weight_name]]).shape
-        input_channel = weight_shape.shape[1]
+        input_channel = weight_shape[1]
         if input_channel != 1:  # TODO need to double check
             return True
         return False
 
-    def _get_input_tensor_of_ops(self, op_types=['MatMul', 'Linear', 'Conv']):
+    def _get_input_tensor_of_ops(self, op_types=['MatMul', 'Gemm', 'Conv', 'FusedConv']):
         """Traverse the graph and get all the data tensors flowing into layers of {op_types}.
         Group conv is excluded.
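The one-line fix above matters because numpy_helper.to_array(...).shape already yields a plain tuple; calling .shape on it again raised AttributeError. A self-contained sketch with a made-up Conv weight:

    import numpy as np
    from onnx import numpy_helper

    # A made-up 3x3 Conv weight initializer with 32 output and 16 input channels.
    init = numpy_helper.from_array(np.zeros((32, 16, 3, 3), dtype=np.float32), name="w")
    weight_shape = numpy_helper.to_array(init).shape  # (32, 16, 3, 3), a plain tuple
    input_channel = weight_shape[1]                   # 16 -- works
    # The removed line, weight_shape.shape[1], raised
    # AttributeError: 'tuple' object has no attribute 'shape'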
@@ -622,20 +622,20 @@ def _get_input_tensor_of_ops(self, op_types=['MatMul', 'Linear', 'Conv']):
             op_types: The op types whose input tensor will be dumped
         Returns:
-            A set of tensor names
+            A dict of dumped tensor: node info
         """
-        tensors_to_dump = set()
+        tensors_to_node = {}
         model = self.model
         initializers = {i.name: i for i in model.graph.initializer}
 
         for node in model.graph.node:
             if len(op_types) == 0 or node.op_type in op_types:
-                if node.op_type == "Conv" and self._check_is_group_conv(node, model):
+                if node.op_type in ["Conv", "FusedConv"] and self._check_is_group_conv(node, model):
                     continue
                 # also need to check whether the layer has weight
                 if len(node.input) >= 2 and node.input[1] in initializers.keys():
-                    tensors_to_dump.add(node.input[0])
-        return tensors_to_dump
+                    tensors_to_node.setdefault(node.input[0], []).append([node.name, node.input, node.output])
+        return tensors_to_node
 
     def _get_max_per_channel(self, datas: list, percentile):
         """Get the max values per input channel.
@@ -680,8 +680,8 @@ def calib_smooth(self, percentile, op_types, q_config):
             shape_infos: The shape information of input tensors
         """
         # add the input tensors of {op_types} to outputs of the model
-        tensors_to_dump = self._get_input_tensor_of_ops(op_types)
-        self.model_wrapper.add_tensors_to_outputs(tensors_to_dump)
+        tensors_to_node = self._get_input_tensor_of_ops(op_types)
+        self.model_wrapper.add_tensors_to_outputs(tensors_to_node.keys())
         self.augmented_model = self.model_wrapper.model
         if self.model_wrapper.is_large_model:  # pragma: no cover
             onnx.save_model(self.augmented_model,
@@ -693,11 +693,13 @@ def calib_smooth(self, percentile, op_types, q_config):
         _, output_dicts = self.get_intermediate_outputs()
 
         # remove the input tensors of {op_types} from outputs of the model
-        self.model_wrapper.remove_tensors_from_outputs(tensors_to_dump)
+        self.model_wrapper.remove_tensors_from_outputs(tensors_to_node.keys())
         max_vals_per_channel = {}
         shape_infos = {}
-        for key in tensors_to_dump:
+        for key, val in tensors_to_node.items():
             max_val_per_channel = self._get_max_per_channel(output_dicts[key], percentile=percentile)
             max_vals_per_channel[key] = max_val_per_channel
-            shape_infos[key] = output_dicts[key][0].shape
-        return max_vals_per_channel, shape_infos
+            for item in val:
+                shape_infos[item[1][1]] = numpy_helper.to_array(self.model_wrapper.get_initializer(item[1][1])).shape
+        return max_vals_per_channel, shape_infos, tensors_to_node
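calib_smooth accordingly returns three values now, and shape_infos is keyed by weight (initializer) name rather than activation name. A sketch of how a caller might unpack the result, where `augment` is a hypothetical ONNXRTAugment instance:

    # percentile, op_types, q_config are passed positionally, as in the adaptor code.
    max_vals_per_channel, shape_infos, tensors_to_node = augment.calib_smooth(
        99.999, ['MatMul', 'Gemm', 'Conv', 'FusedConv'], None)
    for tensor_name, nodes in tensors_to_node.items():
        act_max = max_vals_per_channel[tensor_name]     # per-channel activation maxima
        for node_name, node_inputs, node_outputs in nodes:
            w_shape = shape_infos[node_inputs[1]]       # keyed by the weight initializer name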