update quantization new format (#46529)
yghstill authored Oct 14, 2022
1 parent 8f1ac7c commit 84333cf
Showing 5 changed files with 204 additions and 88 deletions.
38 changes: 22 additions & 16 deletions python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
@@ -223,7 +223,8 @@ def forward(self, inputs):
 
         self._quantize_inputs = ImperativeQuantizeInputs(**kwargs)
 
-        self._quantize_outputs = ImperativeQuantizeOutputs(moving_rate)
+        self._quantize_outputs = ImperativeQuantizeOutputs(
+            moving_rate, activation_bits)
 
     def quantize(self, model):
         """
@@ -412,16 +413,18 @@ class ImperativeQuantizeOutputs(object):
     Calculate the output scales for target layers.
     """
 
-    def __init__(self, moving_rate=0.9):
+    def __init__(self, moving_rate=0.9, activation_bits=8):
         """
         The constructor for ImperativeQuantizeOutputs.
         Args:
             moving_rate(float): The decay coefficient of moving average.
                 The default value is 0.9.
+            activation_bits(int, optional): quantization bit number for activation. Default is 8.
         """
         super(ImperativeQuantizeOutputs, self).__init__()
         self._moving_rate = moving_rate
+        self._activation_bits = activation_bits
 
     def apply(self, model):
         """
@@ -478,7 +481,7 @@ def save_quantized_model(self,
                 the saved model. Default None.
             onnx_format (bool, optional): Whether to export the quantized model
                 with format of ONNX. Default is False.
-            **configs (dict, optional): Other save configuration options for
+            **config (dict, optional): Other save configuration options for
                 compatibility. We do not recommend using these configurations,
                 they may be removed in the future. If not necessary, DO NOT use
                 them. Default None.
@@ -518,27 +521,30 @@ def save_quantized_model(self,
                 model_filename=model_filename,
                 params_filename=params_filename))
 
-        self._gather_scales(infer_program, scope, fetch_targets)
+        if not onnx_format:
+            self._gather_scales(infer_program, scope, fetch_targets)
 
-        # Remove `moving_average_abs_max_scale` node in sub graphs.
-        graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
-        for sub_graph in graph.all_sub_graphs():
-            for _op in sub_graph.all_op_nodes():
-                if _op.name() == "moving_average_abs_max_scale":
-                    sub_graph.safe_remove_nodes(_op)
-            sub_graph.resolve_hazard()
-        infer_program = graph.to_program()
+            # Remove `moving_average_abs_max_scale` node in sub graphs.
+            graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
+            for sub_graph in graph.all_sub_graphs():
+                for _op in sub_graph.all_op_nodes():
+                    if _op.name() == "moving_average_abs_max_scale":
+                        sub_graph.safe_remove_nodes(_op)
+                sub_graph.resolve_hazard()
+            infer_program = graph.to_program()
 
-        self._set_skip_quant_attr(infer_program)
+            self._set_skip_quant_attr(infer_program)
 
-        clip_extra = False
-        if onnx_format:
+            clip_extra = False
+        else:
             graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
-            transform_pass = ReplaceFakeQuantDequantPass(scope, place)
+            transform_pass = ReplaceFakeQuantDequantPass(
+                scope, place, quant_bits=self._activation_bits)
             transform_pass.apply(graph)
 
             quant_weight_pass = QuantWeightPass(scope, place)
             quant_weight_pass.apply(graph)
 
             infer_program = graph.to_program()
 
             clip_extra = True
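Taken together, the qat.py changes make onnx_format=True skip the scale-gathering path entirely and instead run ReplaceFakeQuantDequantPass (now told the activation bit width) followed by QuantWeightPass, saving with clip_extra=True. A hedged end-to-end sketch of the dygraph flow this serves; the model, output path, and input shape are illustrative, not from the commit:

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

    quanter = ImperativeQuantAware()                 # 8-bit defaults
    model = paddle.vision.models.mobilenet_v1()      # illustrative model
    quanter.quantize(model)                          # insert fake quant/dequant ops

    # ... QAT finetuning would normally happen here ...

    quanter.save_quantized_model(
        model,
        path='./quant_model',                        # illustrative output path
        input_spec=[
            paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')
        ],
        onnx_format=True)                            # takes the new branch above

The remaining hunks, judging by the methods they touch (__init__, _sample_histogram, _sample_ptf, _update_program), belong to the post-training quantization module.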
@@ -344,7 +344,7 @@ def __init__(self,
         self._fetch_list = None
         self._data_loader = data_loader
 
-        self._out_scale_op_list = utils._out_scale_op_list
+        self._out_scale_op_list = utils.QUANT_SUPPORTED_OP_TYPE_LIST
         self._quantized_weight_var_name = set()
         self._quantized_act_var_name = set()
         self._weight_op_pairs = {}
@@ -843,9 +843,6 @@ def _sample_histogram(self):
                 hist, _ = np.histogram(var_tensor_abs, bins=bins)
                 self._sampling_act_histogram[var_name][0] += hist
 
-    def l2_loss(self, gt, pred):
-        return ((gt - pred)**2).mean()
-
     def _sample_ptf(self):
         """
         The following code are modified from:
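The instance-level l2_loss helper is deleted here, and the next hunk switches its call sites to utils.l2_loss; presumably the commit adds an equivalent module-level helper to utils.py along these lines (a sketch built from the removed body, not the verbatim utils code):

    import numpy as np

    def l2_loss(gt, pred):
        # Mean squared error between the float tensor and its
        # quantize-dequantize reconstruction.
        return ((gt - pred) ** 2).mean()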
@@ -885,10 +882,10 @@ def _sample_ptf(self):
                                                  q_max) * scale4
             quant_dequant_var_scale8 = np.clip(np.round(var_tensor / scale8), 0,
                                                q_max) * scale8
-            score1 = self.l2_loss(var_tensor, quant_dequant_var_scale1)
-            score2 = self.l2_loss(var_tensor, quant_dequant_var_scale2)
-            score4 = self.l2_loss(var_tensor, quant_dequant_var_scale4)
-            score8 = self.l2_loss(var_tensor, quant_dequant_var_scale8)
+            score1 = utils.l2_loss(var_tensor, quant_dequant_var_scale1)
+            score2 = utils.l2_loss(var_tensor, quant_dequant_var_scale2)
+            score4 = utils.l2_loss(var_tensor, quant_dequant_var_scale4)
+            score8 = utils.l2_loss(var_tensor, quant_dequant_var_scale8)
             score = [score1, score2, score4, score8]
             mask = 2**score.index(min(score))
             scale = scale1 * mask
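For context, this is the power-of-two-factor (PTF) search from FQ-ViT: score the abs-max scale divided by 1, 2, 4, and 8, then keep the candidate with the lowest quantize-dequantize L2 error. A self-contained sketch; the scale ladder and the q_max derivation are assumptions, since the hunk shows only the scoring and the mask = 2**argmin selection:

    import numpy as np

    def l2_loss(gt, pred):
        return ((gt - pred) ** 2).mean()

    def sample_ptf(var_tensor, activation_bits=8):
        # Assumed ladder: scale8 comes from the abs-max value and each
        # smaller candidate halves the previous one (scale1 = scale8 / 8).
        q_max = 2 ** (activation_bits - 1) - 1
        scale8 = np.abs(var_tensor).max() / q_max
        scales = [scale8 / 8, scale8 / 4, scale8 / 2, scale8]  # scale1..scale8
        scores = [
            l2_loss(var_tensor,
                    np.clip(np.round(var_tensor / s), 0, q_max) * s)
            for s in scales
        ]
        # Equivalent to mask = 2**score.index(min(score)); scale = scale1 * mask.
        return scales[int(np.argmin(scores))]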
@@ -1035,7 +1032,7 @@ def _update_program(self):
                 scope=self._scope,
                 place=self._place,
                 quantizable_op_type=minor_quantizable_op_types,
-                is_full_quantized=self._is_full_quantize)
+                is_full_quantized=True)
 
             for sub_graph in graph.all_sub_graphs():
                 sub_graph._for_test = True