Fix onnxrt smooth quant (#951)
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 authored Jun 13, 2023
1 parent 8f5f5de commit 1b26c0d
Showing 2 changed files with 45 additions and 52 deletions.
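For context on what is being fixed: SmoothQuant shifts quantization difficulty from activations to weights by rescaling each input channel, deriving a per-channel scale from calibrated activation maxima and weight maxima. A minimal numpy sketch of the standard scale formula (names and data are illustrative, not part of this commit):

    import numpy as np

    def smooth_scale(act_max, weight_max, alpha=0.5, eps=1e-9):
        # s_j = max|X_j|^alpha / max|W_j|^(1 - alpha), per input channel j
        s = np.power(act_max, alpha) / np.power(np.clip(weight_max, eps, None), 1.0 - alpha)
        return np.clip(s, eps, None)

    act_max = np.array([10.0, 0.5, 3.0])    # calibrated per-channel activation maxima
    weight_max = np.array([0.2, 0.8, 0.4])  # per-input-channel weight maxima
    print(smooth_scale(act_max, weight_max))

The activation side is then multiplied by 1/s (here via an inserted Mul node) and the weight side by s; the per-layout bookkeeping for that multiplication is what this commit repairs for transposed Gemm weights.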
91 changes: 42 additions & 49 deletions neural_compressor/adaptor/ox_utils/smooth_quant.py
@@ -17,13 +17,15 @@
"""SmoothQuant for onnxrt adaptor."""

import os
import copy
import onnx
import logging
import numpy as np
from onnx import onnx_pb as onnx_proto
from neural_compressor.model.model import BaseModel
from neural_compressor.model.onnx_model import ONNXModel
from neural_compressor.adaptor.ox_utils.util import find_by_name, quantize_data, _get_qrange_for_qType
from neural_compressor.adaptor.ox_utils.util import find_by_name, \
quantize_data, _get_qrange_for_qType, is_B_transposed
from onnx import numpy_helper, helper

logger = logging.getLogger("neural_compressor")
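
The newly imported is_B_transposed helper drives the Gemm fixes below: it reports whether a Gemm node stores its weight input B transposed. A sketch of what such a check amounts to (the real helper lives in neural_compressor.adaptor.ox_utils.util and may differ in detail):

    def is_b_transposed(node):
        # Gemm computes Y = alpha * A' * B' + beta * C; transB == 1 means the
        # weight is stored as (out_features, in_features), i.e. B' = B^T.
        return any(attr.name == 'transB' and attr.i == 1 for attr in node.attribute)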
@@ -68,17 +70,6 @@ def make_sub_graph(node, inits, input_data, output_data, reduce_range, opset, ir_version):
from onnx import helper, TensorProto, numpy_helper
input = helper.make_tensor_value_info(node.input[0], dtype_map[input_data.dtype], input_data.shape)
output = helper.make_tensor_value_info(node.output[0], dtype_map[output_data.dtype], output_data.shape)

for init in inits:
q_dq_val = quant_dequant_data(numpy_helper.to_array(init), reduce_range)
new_tensor = helper.make_tensor(
name=init.name,
data_type=dtype_map[numpy_helper.to_array(init).dtype],
dims=numpy_helper.to_array(init).shape if \
len(numpy_helper.to_array(init).shape) != 0 else [],
vals=q_dq_val if \
len(numpy_helper.to_array(init)) != 0 else [numpy_helper.to_array(init)])
init.CopyFrom(new_tensor)
graph = helper.make_graph([node], 'sub_graph', [input], [output], inits)
model = helper.make_model(graph, opset_imports=opset)
model.ir_version = ir_version
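
With the quant-dequant loop over the initializers removed, make_sub_graph now wraps the node together with its original initializers, and the caller quant-dequantizes the weight once instead (see _get_output_loss below). The construction pattern, as a self-contained sketch around an illustrative 2x2 MatMul:

    import numpy as np
    from onnx import helper, numpy_helper, TensorProto

    w = numpy_helper.from_array(np.eye(2, dtype=np.float32), name='W')
    node = helper.make_node('MatMul', ['X', 'W'], ['Y'])
    inp = helper.make_tensor_value_info('X', TensorProto.FLOAT, [2, 2])
    out = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [2, 2])
    graph = helper.make_graph([node], 'sub_graph', [inp], [out], [w])
    model = helper.make_model(graph)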
@@ -110,11 +101,15 @@ class ORTSmoothQuant:
def __init__(self, model, dataloader, reduce_range=False, backend='CPUExecutionProvider'):
"""Initialize the attributes of class."""
self.model = model if isinstance(model, BaseModel) else ONNXModel(model)
self.value_infos = {vi.name: vi for vi in self.model.model.graph.value_info}
self.value_infos.update({ot.name: ot for ot in self.model.model.graph.output})
self.value_infos.update({it.name: it for it in self.model.model.graph.input})
self.dataloader = dataloader
self.reduce_range = reduce_range
self.backend = backend
self.tensor_scales_info = {}
self.new_added_mul_nodes = []
self.new_added_value_info = []
self.new_init_tensors = [] # scales_tensor
self.alpha = None
self.percentile = None
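
The new value_infos map gives _insert_smooth_mul_op a type/shape record to copy for each inserted Mul output. For models whose graphs lack intermediate value_info entries, ONNX shape inference can populate them first; a sketch (the load path is hypothetical):

    import onnx
    from onnx import shape_inference

    m = shape_inference.infer_shapes(onnx.load('model.onnx'))
    value_infos = {vi.name: vi for vi in m.graph.value_info}
    value_infos.update({o.name: o for o in m.graph.output})
    value_infos.update({i.name: i for i in m.graph.input})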
@@ -129,7 +124,7 @@ def __init__(self, model, dataloader, reduce_range=False, backend='CPUExecutionProvider'):
self._build_absorb_function()

def transform(self, alpha=0.5, folding=True, percentile=99.999, op_types=['Gemm', 'Conv', 'MatMul', 'FusedConv'],
scales_per_op=False, calib_iter=100, quantize_config=None,
scales_per_op=True, calib_iter=100, quantize_config=None,
auto_alpha_args={'alpha_min': 0.3, 'alpha_max': 0.7, 'alpha_step': 0.05, 'attn_method': 'min'}):
"""The main entry of smooth quant.
@@ -167,6 +162,7 @@
self._insert_smooth_mul_op(scales)
self._adjust_weights(scales)
self.model.add_nodes(self.new_added_mul_nodes)
self.model.model.graph.value_info.extend(self.new_added_value_info)
self.model.add_initializers(self.new_init_tensors)
for node, old_input_name, new_input_name in self.replace_input:
self.model.replace_node_input(node, old_input_name, new_input_name)
@@ -194,9 +190,15 @@ def recover(self):
for node, old_input_name, new_input_name in self.replace_input:
self.model.replace_node_input(node, new_input_name, old_input_name)

for value_info in self.new_added_value_info:
self.model.model.graph.value_info.remove(value_info)

self.model.remove_nodes(self.new_added_mul_nodes)
self.model.remove_initializers(self.new_init_tensors)
self.tensor_scales_info = {}
self.new_added_mul_nodes = []
self.new_init_tensors = []
self.new_added_value_info = []

def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter):
"""Check need calibration or not.
@@ -340,7 +342,7 @@ def _get_output_loss(self, node_name, scale, calib_iter):
loss = 0
if len(node) > 0:
node = node[0]

orig_outputs = self.model.output()
added_tensors = [node.input[0], node.output[0]]
self.model.add_tensors_to_outputs(added_tensors)

@@ -350,20 +352,8 @@
ort.InferenceSession(self.model.model.SerializeToString(),
providers=[self.backend])
base_dir = '' if not self.model.is_large_model else os.path.dirname(self.model.model_path)
if node.op_type in ['Conv', 'FusedConv']:
weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
weight_q = quant_dequant_data(weight)
elif node.op_type in ['MatMul', 'Gemm']:
weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
weight_q = quant_dequant_data(weight)

base_dir = '' if not self.model.is_large_model else os.path.dirname(self.model.model_path)
if node.op_type in ['Conv', 'FusedConv']:
weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
weight_q = quant_dequant_data(weight)
elif node.op_type in ['MatMul', 'Gemm']:
weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
weight_q = quant_dequant_data(weight)
weight = onnx.numpy_helper.to_array(self.model.get_initializer(node.input[1]), base_dir)
weight_q = quant_dequant_data(weight)

self.model.set_initializer(node.input[1], weight_q)
inits = [self.model.get_initializer(i) for i in node.input if self.model.get_initializer(i) is not None]
@@ -383,9 +373,9 @@ def _get_output_loss(self, node_name, scale, calib_iter):
if model is None:
model = make_sub_graph(node, inits, outputs[0], outputs[1],
self.reduce_range, self.model.model.opset_import, self.model.model.ir_version)
loss += get_quant_dequant_output(model, outputs[0], outputs[1], self.reduce_range, self.backend)
loss += get_quant_dequant_output(model, outputs[0] * scale, outputs[1], self.reduce_range, self.backend)

self.model.remove_tensors_from_outputs(added_tensors)
self.model.remove_tensors_from_outputs([i for i in added_tensors if i not in orig_outputs])
self.model.set_initializer(node.input[1], weight)
return loss
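
Two fixes land in _get_output_loss: the duplicated weight-loading block collapses into a single load, and the quant-dequant subgraph is now fed the input multiplied by the candidate smoothing scale, so the measured loss reflects the smoothed model rather than the raw one; the orig_outputs guard additionally ensures only tensors added for calibration are removed from the model outputs afterwards. The shape of that loss computation, in an illustrative numpy sketch (fake_quant stands in for running the single-node QDQ subgraph under onnxruntime):

    import numpy as np

    def fake_quant(x, num_bits=8):
        # symmetric per-tensor quant-dequant round trip (illustrative)
        qmax = 2 ** (num_bits - 1) - 1
        step = max(np.abs(x).max() / qmax, 1e-12)
        return np.round(x / step).clip(-qmax, qmax) * step

    def output_loss(float_in, float_out, weight, scale):
        # mirrors: loss += get_quant_dequant_output(model, outputs[0] * scale, outputs[1], ...)
        q_out = fake_quant(float_in * scale) @ fake_quant(weight)
        return float(np.sum((float_out - q_out) ** 2))

    x = np.random.randn(4, 8).astype(np.float32)
    w = np.random.randn(8, 8).astype(np.float32)
    print(output_loss(x, x @ w, w, scale=np.ones(8, dtype=np.float32)))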

@@ -430,15 +420,16 @@ def _auto_tune_alpha(self, calib_iter, alpha_min=0.3, alpha_max=0.7, alpha_step=

## Searching optimal alphas
for tensor_name, node_infos in self.tensors_to_node.items():
loss_all_ops = {}
for node_info in node_infos:
loss_alpha = {}
key = node_info[0] if self.scales_per_op else tensor_name

node = self.model.get_node(node_info[0])
for alpha in alpha_space:
scale = self._get_smooth_scales(alpha, [key])
self._adjust_weights(scale)
input_scale = self._reshape_scale_for_input(tensor_name, key)
input_scale = self._reshape_scale_for_input(tensor_name, key) if \
not (node.op_type == 'Gemm' and is_B_transposed(node)) else \
self.tensor_scales_info[key]
loss = self._get_output_loss(node_info[0], input_scale, calib_iter)
loss_alpha[alpha] = loss
if key not in optimal_alphas: # Update alpha results
@@ -447,7 +438,6 @@ def _auto_tune_alpha(self, calib_iter, alpha_min=0.3, alpha_max=0.7, alpha_step=
optimal_alphas[key] = alpha if optimal_alphas[key] in loss_alpha and \
loss < loss_alpha[optimal_alphas[key]] else optimal_alphas[key]
self.recover()
loss_all_ops[key] = loss_alpha
logger.info("auto tuning alpha done")
if self.model.is_large_model:
from onnx.external_data_helper import load_external_data_for_model
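
The alpha search itself is a plain grid search per key that keeps the alpha with the lowest quant-dequant loss (the unused loss_all_ops accumulator is dropped), and a transposed Gemm now takes its per-channel scale straight from tensor_scales_info instead of the input-layout reshape. Schematically (loss_fn stands in for the _get_output_loss round trip):

    import numpy as np

    def auto_tune_alpha(keys, loss_fn, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05):
        alpha_space = np.round(np.arange(alpha_min, alpha_max + alpha_step, alpha_step), 2)
        best = {}
        for key in keys:
            losses = {alpha: loss_fn(key, alpha) for alpha in alpha_space}
            best[key] = min(losses, key=losses.get)  # keep the lowest-loss alpha
        return best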
@@ -474,28 +464,25 @@ def _get_smooth_scales(self, alpha, target_list=[]):
# if scales_per_op the key of scales is the node name, otherwise the activation of node
if self.scales_per_op:
for node_info in nodes:
node = self.model.input_name_to_nodes[node_info[1][1]][0]
if len(target_list) > 0 and node_info[0] not in target_list:
continue
weight = numpy_helper.to_array(self.model.get_initializer(node_info[1][1]))
if len(weight.shape) == 4: # conv
if weight.shape[1] == 1: # depthwise conv
pass
else:
weight = np.moveaxis(weight, 0, 1)
if (len(weight.shape) == 4 and weight.shape[1] != 1) or \
(node.op_type == 'Gemm' and is_B_transposed(node)):
weight = np.moveaxis(weight, 0, 1)
specific_alpha = alpha[node_info[0]] if isinstance(alpha, dict) else alpha
scales[node_info[0]] = self._get_smooth_scale(weight, specific_alpha, tensor)
else:
if len(target_list) > 0 and tensor not in target_list:
continue
weights = [numpy_helper.to_array(self.model.get_initializer(node_info[1][1])) for \
node_info in nodes]
weights_in_channel_max = []
for weight in weights: # mamul ic*oc, conv oc*ic*k*k
if len(weight.shape) == 4: # conv
if weight.shape[1] == 1: # depthwise conv
pass
else:
weight = np.moveaxis(weight, 0, 1)
for node_info in nodes:
node = self.model.input_name_to_nodes[node_info[1][1]][0]
weight = numpy_helper.to_array(self.model.get_initializer(node_info[1][1]))
if (len(weight.shape) == 4 and weight.shape[1] != 1) or \
(node.op_type == 'Gemm' and is_B_transposed(node)):
weight = np.moveaxis(weight, 0, 1)
weight = weight.reshape(weight.shape[0], -1)
cur_max = np.amax(weight, axis=-1)
weights_in_channel_max.append(cur_max)
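
Both branches now normalize every weight layout to input-channels-first before reducing: MatMul weights arrive as ic x oc, Conv weights as oc x ic x kH x kW (axes swapped unless depthwise), and transposed Gemm weights as oc x ic (also swapped). After the reshape, one np.amax yields the per-input-channel maximum. Illustratively:

    import numpy as np

    conv_w = np.random.randn(16, 8, 3, 3)   # Conv weight: oc, ic, kH, kW
    w = np.moveaxis(conv_w, 0, 1)           # -> ic, oc, kH, kW (skipped for depthwise)
    w = w.reshape(w.shape[0], -1)           # -> ic, oc*kH*kW
    cur_max = np.amax(w, axis=-1)           # one maximum per input channel
    print(cur_max.shape)                    # (8,)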
@@ -555,6 +542,10 @@ def _insert_smooth_mul_op(self, scales):
name=key + "_smooth_mul"
)
self.new_added_mul_nodes.append(mul_node)
if input_name in self.value_infos:
value_info = copy.deepcopy(self.value_infos[input_name])
value_info.name = mul_node.output[0]
self.new_added_value_info.append(value_info)
if self.scales_per_op:
self.replace_input.append([self.model.get_node(key), input_name, mul_output_name])
else:
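
When the producing tensor has a known value_info, the new Mul output inherits a renamed copy of it, so downstream shape-dependent passes keep working. The pattern, self-contained with illustrative names:

    import copy
    from onnx import helper, TensorProto

    value_infos = {'act': helper.make_tensor_value_info('act', TensorProto.FLOAT, [1, 8])}
    mul_node = helper.make_node('Mul', ['act', 'act_smooth_scale'], ['act_smoothed'],
                                name='act_smooth_mul')
    vi = copy.deepcopy(value_infos['act'])
    vi.name = mul_node.output[0]  # the renamed copy now describes the Mul output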
@@ -573,10 +564,12 @@ def _adjust_weights(self, scales):
if key not in scales:
continue
input = node_info[1][1]
node = self.model.input_name_to_nodes[input][0]
weight = numpy_helper.to_array(self.model.get_initializer(input))
if len(weight.shape) == 2:
scale = np.expand_dims(scales[key],
axis=-1) # TODO, to support conv
scale = np.expand_dims(scales[key], axis=0) if \
node.op_type == 'Gemm' and is_B_transposed(node) else\
np.expand_dims(scales[key], axis=-1)
new_weight = weight * scale
elif len(weight.shape) == 4: # TODO need to check conv
node = self.model.input_name_to_nodes[input][0]
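
_adjust_weights now picks the broadcast axis by layout instead of assuming ic x oc: a transposed Gemm weight (oc x ic) takes the scale along axis 0, everything else along the last axis. In numpy terms:

    import numpy as np

    scales = np.random.rand(8)          # one smoothing scale per input channel
    w_ic_oc = np.random.randn(8, 16)    # MatMul, or Gemm with transB == 0
    w_oc_ic = np.random.randn(16, 8)    # Gemm with transB == 1

    adj_a = w_ic_oc * np.expand_dims(scales, axis=-1)  # scale the rows (input channels)
    adj_b = w_oc_ic * np.expand_dims(scales, axis=0)   # scale the columns (input channels)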
6 changes: 3 additions & 3 deletions test/algorithm/test_smooth_quant.py
@@ -58,7 +58,7 @@ def tearDownClass(self):

def test_sq(self):
sq = ORTSmoothQuant(copy.deepcopy(self.model), self.dataloader)
model = sq.transform(calib_iter=5)
model = sq.transform(calib_iter=5, scales_per_op=False)
self.assertEqual(len([i for i in model.model.graph.node if i.op_type == 'Mul']), 1)
sq.recover()
self.assertEqual(len(sq.model.nodes()), len(self.model.graph.node))
@@ -68,7 +68,7 @@ def test_sq(self):
self.assertAlmostEqual(tensor[0][0], sq_tensor[0][0], 4)

sq = ORTSmoothQuant(copy.deepcopy(self.model), self.dataloader)
model = sq.transform(calib_iter=5, folding=False)
model = sq.transform(calib_iter=5, folding=False, scales_per_op=False)
self.assertEqual(len([i for i in model.model.graph.node if i.op_type == 'Mul']), 2)
sq.recover()
self.assertEqual(len(sq.model.nodes()), len(self.model.graph.node))
@@ -109,7 +109,7 @@ def test_sq(self):


sq = ORTSmoothQuant(copy.deepcopy(self.model), self.dataloader)
model = sq.transform(calib_iter=5, alpha='auto')
model = sq.transform(calib_iter=5, alpha='auto', scales_per_op=False)
self.assertEqual(len([i for i in model.model.graph.node if i.op_type == 'Mul']), 1)
sq.recover()
self.assertEqual(len(sq.model.nodes()), len(self.model.graph.node))
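
The test updates follow from flipping the scales_per_op default to True: these existing cases now pin the old per-tensor behavior explicitly. Typical usage after this commit looks roughly like this (model and dataloader are stand-ins for the fixture objects in the test):

    import copy
    from neural_compressor.adaptor.ox_utils.smooth_quant import ORTSmoothQuant

    sq = ORTSmoothQuant(copy.deepcopy(model), dataloader)
    smoothed = sq.transform(calib_iter=5, scales_per_op=False)  # one Mul per activation tensor
    sq.recover()  # removes the added Mul nodes, initializers, and value_info entries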
