Support TF per-channel MatMul quantization (#928)
Spycsh authored Aug 23, 2023
1 parent ec9ae91 commit cf55895
Showing 7 changed files with 99 additions and 23 deletions.
@@ -128,6 +128,9 @@ After prepare step is done, we add tune and benchmark code to generate quantized
q_model.graph_def = strip_iterator(q_model.graph_def)
q_model.save(FLAGS.output_model)
```

You can also pass the optional parameter `op_type_dict={'matmul':{'weight':{'granularity':['per_channel']}}}` to `PostTrainingQuantConfig`. This enables per-channel quantization of MatMul weights and produces an int8 model with better accuracy, as sketched below.
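
A minimal sketch (assuming `model` and `calib_dataloader` are the objects built in the prepare step, and `FLAGS.output_model` is reused from the snippet above; exact variable names may differ in your script):

```python
from neural_compressor import quantization
from neural_compressor.config import PostTrainingQuantConfig

# Enable per-channel weight quantization for MatMul ops during tuning.
conf = PostTrainingQuantConfig(
    op_type_dict={'matmul': {'weight': {'granularity': ['per_channel']}}})
q_model = quantization.fit(model, conf, calib_dataloader=calib_dataloader)
q_model.save(FLAGS.output_model)
```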

#### Benchmark
```python
from neural_compressor.benchmark import fit
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/tensorflow.yaml
@@ -87,7 +87,7 @@
'weight': {
'dtype': ['int8'],
'scheme': ['sym'],
'granularity': ['per_tensor'],
'granularity': ['per_tensor', 'per_channel'],
'algorithm': ['minmax']
},
'activation': {
@@ -356,7 +356,7 @@ def __init__(self, model, device='cpu'):

def do_transformation(self):
"""Apply the fusion of QuantizedMatMul + Requantize + Dequantize."""
fuse_pattern = [["_QuantizedMatMul"], ['Requantize'], ['Dequantize'], ('Softmax',)]
fuse_pattern = [["_QuantizedMatMul"], ['Requantize', 'RequantizePerChannel'], ['Dequantize'], ('Softmax',)]

uint8_type = dtypes.quint8.as_datatype_enum
int8_type = dtypes.qint8.as_datatype_enum
@@ -460,8 +460,14 @@ def do_transformation(self):
and weight_node.op == 'Const' and not last_node.op == 'QuantizedConcatV2':
min_input_value = (min_input_node.attr['value'].tensor.float_val)[0]
max_input_value = (max_input_node.attr['value'].tensor.float_val)[0]
max_filter_value = (max_filter_node.attr['value'].tensor.float_val)[0]
min_filter_value = (min_filter_node.attr['value'].tensor.float_val)[0]
if requantize_node.op.find('PerChannel') != -1: # pragma: no cover
max_filter_tensor = tensor_util.MakeNdarray( # get tensor
max_filter_node.attr['value'].tensor)
min_filter_tensor = tensor_util.MakeNdarray( # get tensor
min_filter_node.attr['value'].tensor)
else:
max_filter_value = (max_filter_node.attr['value'].tensor.float_val)[0]
min_filter_value = (min_filter_node.attr['value'].tensor.float_val)[0]

weights_tensor = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)
bias_tensor = tensor_util.MakeNdarray(bias_node.attr['value'].tensor)
@@ -474,10 +480,14 @@ def do_transformation(self):

if -self.eps <= max_input_value - min_input_value <= self.eps:
max_input_value += self.eps
int32_bias = Helper.generate_int32_bias_for_matmul(bias_tensor, weights_tensor,
input_range, max_input_value,
min_input_value,
max_filter_value, min_filter_value)
if requantize_node.op.find('PerChannel') != -1: # pragma: no cover
int32_bias = Helper.generate_int32_bias_for_matmul_per_channel(
bias_tensor, weights_tensor, max_input_value, min_input_value,
max_filter_tensor, min_filter_tensor)
else:
int32_bias = Helper.generate_int32_bias_for_matmul(
bias_tensor, weights_tensor, input_range, max_input_value, min_input_value,
max_filter_value, min_filter_value)

bias_node.attr['dtype'].CopyFrom(
attr_value_pb2.AttrValue(
@@ -549,7 +559,7 @@ def do_transformation(self):
qint32_type = dtypes.qint32.as_datatype_enum

target_nodes = self.graph_analyzer.query_fusion_pattern_nodes(
[["_QuantizedMatMul"], ['Requantize']])
[["_QuantizedMatMul"], ['Requantize', 'RequantizePerChannel']])
for i in target_nodes:
quantized_node_name = i[0]
quantized_node = self.graph_info[quantized_node_name].node
@@ -566,7 +576,6 @@ def do_transformation(self):
# "BiasAdd", "Activation", "Requantize"
if "BiasAddAdd" in attr_fused_ops:
continue

new_node = node_def_pb2.NodeDef()

new_node.op = quantized_node_op
@@ -654,8 +663,14 @@ def do_transformation(self):
and max_filter_node and min_filter_node):
min_input_value = (min_input_node.attr['value'].tensor.float_val)[0]
max_input_value = (max_input_node.attr['value'].tensor.float_val)[0]
max_filter_value = (max_filter_node.attr['value'].tensor.float_val)[0]
min_filter_value = (min_filter_node.attr['value'].tensor.float_val)[0]
if requantize_node.op.find('PerChannel') != -1: # pragma: no cover
max_filter_tensor = tensor_util.MakeNdarray( # get tensor
max_filter_node.attr['value'].tensor)
min_filter_tensor = tensor_util.MakeNdarray( # get tensor
min_filter_node.attr['value'].tensor)
else:
max_filter_value = (max_filter_node.attr['value'].tensor.float_val)[0]
min_filter_value = (min_filter_node.attr['value'].tensor.float_val)[0]

weights_tensor = tensor_util.MakeNdarray(weight_node.attr['value'].tensor)
bias_tensor = tensor_util.MakeNdarray(bias_node.attr['value'].tensor)
@@ -667,12 +682,14 @@ def do_transformation(self):

if -self.eps <= max_input_value - min_input_value <= self.eps:
max_input_value += self.eps
int32_bias = Helper.generate_int32_bias_for_matmul(bias_tensor, weights_tensor,
input_range,
max_input_value,
min_input_value,
max_filter_value,
min_filter_value)
if requantize_node.op.find('PerChannel') != -1: # pragma: no cover
int32_bias = Helper.generate_int32_bias_for_matmul_per_channel(
bias_tensor, weights_tensor, max_input_value, min_input_value,
max_filter_tensor, min_filter_tensor)
else:
int32_bias = Helper.generate_int32_bias_for_matmul(
bias_tensor, weights_tensor, input_range, max_input_value, min_input_value,
max_filter_value, min_filter_value)
bias_node.attr['dtype'].CopyFrom(
attr_value_pb2.AttrValue(
type=float32_type if self.device == 'gpu' else qint32_type))
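
A standalone sketch (not part of the commit) of the dispatch used in both hunks above: the per-channel path is detected from the requantize op name, and the filter min/max are then read as full tensors rather than scalar `float_val`s. The `make_const` helper below is hypothetical, added only to make the snippet runnable:

```python
import numpy as np
from tensorflow.core.framework import attr_value_pb2, node_def_pb2
from tensorflow.python.framework import tensor_util

def make_const(name, value):
    """Build a minimal Const NodeDef holding `value` (hypothetical test helper)."""
    node = node_def_pb2.NodeDef(name=name, op='Const')
    node.attr['value'].CopyFrom(
        attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(value)))
    return node

max_filter_node = make_const('max_filter', np.array([0.9, 1.2, 0.7], dtype=np.float32))
requantize_op = 'RequantizePerChannel'  # vs. 'Requantize' for the per-tensor path

if 'PerChannel' in requantize_op:
    # Per-channel MatMul: one min/max value per output channel.
    max_filter = tensor_util.MakeNdarray(max_filter_node.attr['value'].tensor)
else:
    # Per-tensor MatMul: a single scalar value.
    max_filter = max_filter_node.attr['value'].tensor.float_val[0]
```
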
@@ -32,7 +32,7 @@ class GenerateGraphWithQDQPattern(GraphRewriterBase):
"""Insert Q/DQ pairs before quantizable ops."""
def __init__(self, model, calibration_data, op_wise_config, fake_quant, fp32_ops,
bf16_ops, quantized_nodes, device, performance_only, itex_mode):
"""Initilization."""
"""Initialization."""
super().__init__(model)
self.data = calibration_data
self.op_wise_config = op_wise_config
@@ -130,7 +130,6 @@ def do_transformation(self):
weight_node = parent_node
else:
continue

if computational_node_name in self.op_wise_config.keys():
op_wise_cfg = self.op_wise_config[computational_node_name]
per_channel = op_wise_cfg[0]
@@ -463,7 +462,6 @@ def _insert_qdq_pattern_for_weight_node(self,
insert_reshape = False
shape_convert = None
shape_revert = None

# The weight node of BatchMatMul may have no value
if 'value' in weight_node.attr and \
host_op_type in ("Conv2D", "MatMul", "BatchMatMul", "BatchMatMulV2", "Conv3D", \
@@ -474,6 +472,12 @@
ranges = np.abs(float_tensor).max(axis=(0, 1, 2, 3))
elif host_op_type in ('Conv2D', 'Conv2DBackpropInput'):
ranges = np.abs(float_tensor).max(axis=(0, 1, 2))
elif host_op_type in ('MatMul'): # pragma: no cover
if 'transpose_b' in weight_node.attr and weight_node.attr["transpose_b"].b: # pragma: no cover
ranges = np.abs(float_tensor).max(axis=(1))
else:
# itex qdq needs to transpose this range
ranges = np.abs(float_tensor).max(axis=(0))
else:
ranges = np.abs(float_tensor).max(axis=(0, 1))

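
A standalone sketch (not part of the commit) of the axis choice added above: a MatMul weight stored as `[in_features, out_features]` gets one range per output column (axis 0), while a `transpose_b` weight stored as `[out_features, in_features]` keeps its output channels on axis 0, so the reduction moves to axis 1:

```python
import numpy as np

w = np.random.randn(64, 16).astype(np.float32)  # MatMul weight, [in_features, out_features]
ranges = np.abs(w).max(axis=0)                  # shape (16,): one range per output channel

w_tb = w.T                                      # transpose_b=True layout, [out_features, in_features]
ranges_tb = np.abs(w_tb).max(axis=1)            # still one range per output channel
```
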
33 changes: 33 additions & 0 deletions neural_compressor/adaptor/tf_utils/graph_util.py
@@ -923,6 +923,39 @@ def generate_int32_bias_for_matmul(bias_tensor, weights_tensor,

return int32_bias

@staticmethod
def generate_int32_bias_for_matmul_per_channel(bias_tensor, weights_tensor, max_input, min_input,
max_filter_tensor, min_filter_tensor,
): # pragma: no cover
"""Static method that generate per-channel int32 bias for matmul op.
Args:
bias_tensor: bias node tensor.
weights_tensor: weights tensor.
max_input: max activation input value.
min_input: min activation input value.
max_filter_tensor: max weight input tensor.
min_filter_tensor: min weight input tensor.
Returns:
int32_bias: int32 bias
"""
channel_size = bias_tensor.shape[0]
activation_range = 255.0
weights_range = 127.0
scales = []
relative_scale = 255 * min_input / (max_input - min_input)
for i in range(channel_size):
scales.append(activation_range * weights_range /
((max_input - min_input) *
max(abs(max_filter_tensor[i]), abs(min_filter_tensor[i]))))
int32_bias = []
for i in range(channel_size):
value = np.sum(np.array(weights_tensor),axis=0,dtype=np.int32)[i]
int32_bias.append((int)(np.around(value * relative_scale + bias_tensor[i] * scales[i])))

return int32_bias

@staticmethod
def gen_valid_sampling_log(log_path):
"""Generate the valid sampling log.
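
A standalone sketch (not part of the commit) of the math implemented by the new `generate_int32_bias_for_matmul_per_channel` helper above: quint8 activations give a 255-step range, qint8 weights a 127-step range, each output channel gets its own scale from that channel's filter min/max, and the asymmetric activation zero point adds an offset through the column sums of the weights:

```python
import numpy as np

def per_channel_int32_bias(bias, weights, max_input, min_input, max_filter, min_filter):
    """Hypothetical NumPy restatement of generate_int32_bias_for_matmul_per_channel."""
    activation_range, weights_range = 255.0, 127.0
    abs_filter = np.maximum(np.abs(max_filter), np.abs(min_filter))  # per-channel |w| range
    scales = activation_range * weights_range / ((max_input - min_input) * abs_filter)
    relative_scale = 255.0 * min_input / (max_input - min_input)     # activation zero-point term
    col_sums = np.sum(weights, axis=0, dtype=np.int32)               # one sum per output channel
    return np.around(col_sums * relative_scale + bias * scales).astype(np.int32)
```
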
16 changes: 14 additions & 2 deletions neural_compressor/adaptor/tf_utils/quantize_graph_common.py
@@ -353,6 +353,13 @@ def generate_quantized_weight_node(host_op_type,
if per_channel:
if host_op_type in ('Conv3D', 'Conv3DBackpropInputV2'):
ranges = np.abs(float_tensor).max(axis=(0, 1, 2, 3))
elif host_op_type in ('Conv2D', 'Conv2DBackpropInput'):
ranges = np.abs(float_tensor).max(axis=(0, 1, 2))
elif host_op_type in ('MatMul'):
if 'transpose_b' in input_node.attr and input_node.attr["transpose_b"].b: # pragma: no cover
ranges = np.abs(float_tensor).max(axis=(1))
else:
ranges = np.abs(float_tensor).max(axis=(0))
else:
ranges = np.abs(float_tensor).max(axis=(0, 1, 2))

@@ -363,7 +370,13 @@
ranges[ranges < epsilon] = epsilon
min_value[np.abs(min_value) < epsilon] = -epsilon
max_value[np.abs(max_value) < epsilon] = epsilon
qint8_tensor = (np.around(float_tensor *127.0/ranges)).astype(np.int8)
if 'transpose_b' in input_node.attr and input_node.attr["transpose_b"].b: # pragma: no cover
# transpose for broadcasting
float_tensor = np.transpose(float_tensor, [1, 0])
qint8_tensor = (np.around(float_tensor *127.0/ranges)).astype(np.int8)
qint8_tensor = np.transpose(qint8_tensor, [1, 0])
else:
qint8_tensor = (np.around(float_tensor *127.0/ranges)).astype(np.int8)
else:
min_value = np.min(float_tensor)
max_value = np.max(float_tensor)
@@ -406,7 +419,6 @@ def generate_quantized_weight_node(host_op_type,
qint8_tensor,
dtypes.qint8,
shape=shape)

min_node = QuantizeGraphHelper.create_constant_node(min_name, min_value,
dtypes.float32, device="cpu")

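
A standalone sketch (not part of the commit) of the broadcasting detail handled above: with `transpose_b` the weight is stored as `[out_channels, in_channels]`, so it is transposed before dividing by the per-channel `ranges` (shape `(out_channels,)`) and transposed back afterwards:

```python
import numpy as np

def quantize_matmul_weight_per_channel(float_tensor, ranges, transpose_b=False):
    """Hypothetical helper mirroring the per-channel qint8 weight quantization above."""
    if transpose_b:
        float_tensor = np.transpose(float_tensor, [1, 0])  # -> [in, out], channels on the last axis
        qint8 = np.around(float_tensor * 127.0 / ranges).astype(np.int8)
        return np.transpose(qint8, [1, 0])                 # restore the [out, in] layout
    return np.around(float_tensor * 127.0 / ranges).astype(np.int8)
```
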
7 changes: 7 additions & 0 deletions test/tfnewapi/test_tensorflow_graph_qdq_matmul_fusion.py
@@ -22,6 +22,13 @@ def build_fake_yaml():
accuracy:
metric:
topk: 1
quantization:
model_wise:
weight:
granularity: per_tensor
scheme: sym
dtype: int8
algorithm: minmax
tuning:
strategy:
name: basic
