From ba317c8eae2a3cd883bd12a49079054a2a4efb95 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Thu, 23 Feb 2023 19:46:30 +0800
Subject: [PATCH] fix the op fusion issue to keep TensorMeta (#25)

* fix the op fusion issue
* add UT
* lint
* add linear cases
* lint
---
 test/quantization/fx/test_quantize_pt2e.py |  63 +++++++++
 torch/_inductor/overrides.py               | 150 +++++++++++++++++++++
 2 files changed, 213 insertions(+)

diff --git a/test/quantization/fx/test_quantize_pt2e.py b/test/quantization/fx/test_quantize_pt2e.py
index df05cb9b03106..24ed2308189c4 100644
--- a/test/quantization/fx/test_quantize_pt2e.py
+++ b/test/quantization/fx/test_quantize_pt2e.py
@@ -405,6 +405,40 @@ def forward(self, x):
         for with_bias, use_relu in cases:
             self._test_inductor_backend_helper(Mod(with_bias, use_relu), input_shape)
+    def test_conv2d_relu_conv2d_inductor_backend(self):
+        '''
+        Test to ensure the TensorMeta is kept after fusion
+        '''
+        class Mod(torch.nn.Module):
+            def __init__(self, with_bias: bool) -> None:
+                super().__init__()
+                self.conv = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=16,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=with_bias
+                )
+                self.relu = torch.nn.ReLU()
+                self.conv2 = torch.nn.Conv2d(
+                    in_channels=16,
+                    out_channels=16,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=with_bias
+                )
+
+            def forward(self, x):
+                x = self.conv(x)
+                return self.conv2(self.relu(x))
+
+        input_shape = (1, 3, 16, 16)
+        with_bias_list = [True, False]
+        for with_bias in with_bias_list:
+            self._test_inductor_backend_helper(Mod(with_bias), input_shape)
+
     def test_conv3d_inductor_backend(self):
         '''
         Quantize and lower convolution 3d + relu with Inductor quantization backend.
         '''
@@ -461,3 +495,32 @@ def forward(self, x):
         cases = itertools.product(with_bias_list, use_relu_list)
         for with_bias, use_relu in cases:
             self._test_inductor_backend_helper(Mod(with_bias, use_relu), input_shape)
+
+    def test_linear_relu_linear_inductor_backend(self):
+        '''
+        Test to ensure the TensorMeta is kept after fusion for linear.
+        For experiment.
+        '''
+        class Mod(torch.nn.Module):
+            def __init__(self, with_bias: bool) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(
+                    in_features=16,
+                    out_features=8,
+                    bias=with_bias
+                )
+                self.relu = torch.nn.ReLU()
+                self.linear2 = torch.nn.Linear(
+                    in_features=8,
+                    out_features=8,
+                    bias=with_bias
+                )
+
+            def forward(self, x):
+                x = self.linear(x)
+                return self.linear2(self.relu(x))
+
+        input_shape = (1, 16)
+        with_bias_list = [True, False]
+        for with_bias in with_bias_list:
+            self._test_inductor_backend_helper(Mod(with_bias), input_shape)
diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py
index fb52bb8fcfdfa..d51148269341e 100644
--- a/torch/_inductor/overrides.py
+++ b/torch/_inductor/overrides.py
@@ -812,6 +812,79 @@ def pre_quantize_weights(gm: torch.fx.GraphModule):
 def fuse_reference_quantized_conv(gm: torch.fx.GraphModule):
     """
     For experiment
+    Replace pattern:
+    # dequantize_per_channel -
+    # dequantize_per_tensor - conv - post_op(or none) - quantize_per_tensor
+    into new pattern:
+    # torch.ops.quantized_decomposed.conv_unary_inductor
+    """
+    aten = torch.ops.aten
+    quantized_decomposed = torch.ops.quantized_decomposed
+    convolution = aten.convolution.default
+    relu = aten.relu.default
+    relu_ = aten.relu_.default
+    quantize_per_tensor = quantized_decomposed.quantize_per_tensor
+    dequantize_per_tensor = quantized_decomposed.dequantize_per_tensor
+    dequantize_per_channel = quantized_decomposed.dequantize_per_channel
+
+    unary_post_ops = {
+        'relu' : relu,
+        'relu_' : relu_,
+    }
+    for name, unary_post_op in unary_post_ops.items():
+        for node in gm.graph.nodes:
+            if node.target is convolution:
+                (x, w, bias, stride, padding, dilation, is_transposed, out_padding, groups) = node.args
+                assert x.target == dequantize_per_tensor, "input node should be dequantize_per_tensor"
+                assert w.target == dequantize_per_channel, "weight node should be dequantize_per_channel"
+                (qx, x_scale, x_zp, x_quant_min, x_quant_max, x_dtype) = x.args
+                (qw, w_scale, w_zp, w_axis, w_quant_min, w_quant_max, w_dtype) = w.args
+                if len(list(node.users)) != 1:
+                    # There is more than one user of this conv node; skip fusion
+                    continue
+
+                post_unary_op_is_not_none = False
+                if list(node.users)[0].target is unary_post_op:
+                    # conv relu fusion
+                    unary_op_to_be_fused = list(node.users)[0]
+                    post_unary_op_is_not_none = True
+                    if list(unary_op_to_be_fused.users)[0].target != quantize_per_tensor:
+                        # Fusion pattern not matched: the op after unary_op is not quantize_per_tensor
+                        continue
+                    quant_per_tensor_node = list(unary_op_to_be_fused.users)[0]
+                elif list(node.users)[0].target is quantize_per_tensor:
+                    # Single conv without post op
+                    quant_per_tensor_node = list(node.users)[0]
+                else:
+                    # Fusion pattern not matched: the op after conv is neither the unary_op to be fused nor quantize_per_tensor
+                    continue
+
+                (y, y_scale, y_zp, y_quant_min, y_quant_max, y_dtype) = quant_per_tensor_node.args
+                with gm.graph.inserting_after(quant_per_tensor_node):
+                    args = (qx, x_scale, x_zp, qw, w_scale, w_zp, w_axis,
+                            bias, stride, padding, dilation, groups, y_scale,
+                            y_zp, name if post_unary_op_is_not_none else "none")
+                    new_conv_node = gm.graph.call_function(quantized_decomposed.conv_unary_inductor, args=args)
+                    # Copy node meta
+                    new_conv_node.meta = copy.copy(quant_per_tensor_node.meta)
+                    quant_per_tensor_node.replace_all_uses_with(new_conv_node)
+
+                gm.graph.erase_node(quant_per_tensor_node)  # erase quantize_per_tensor
+                if post_unary_op_is_not_none:
+                    gm.graph.erase_node(unary_op_to_be_fused)  # erase unary_op
+                gm.graph.erase_node(node)  # erase conv
+                gm.graph.erase_node(w)  # erase dequantize_per_channel
+                gm.graph.erase_node(x)  # erase dequantize_per_tensor
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+def fuse_reference_quantized_conv_legacy(gm: torch.fx.GraphModule):
+    """
+    For experiment
+    The legacy method using the subgraph_rewriter, which loses the tensor meta information after fusion
     """
     aten = torch.ops.aten
     quantized_decomposed = torch.ops.quantized_decomposed
@@ -869,6 +942,83 @@ def fuse_reference_quantized_linear(gm: torch.fx.GraphModule):
     """
     For experiment
     Linear is decomposed to t + addmm (with bias) or t + mm (no bias)
+
+    Case 1: Replace pattern:
+    # dequantize_per_channel - t -
+    # dequantize_per_tensor - addmm - post_op(or none) - quantize_per_tensor
+    # bias -
+    into new pattern:
+    # torch.ops.quantized_decomposed.linear_unary_inductor
+
+    Case 2: Replace pattern:
+    # dequantize_per_channel - t -
+    # dequantize_per_tensor - mm - post_op(or none) - quantize_per_tensor
+    into new pattern:
+    # torch.ops.quantized_decomposed.linear_unary_inductor
+    """
+    aten = torch.ops.aten
+    quantized_decomposed = torch.ops.quantized_decomposed
+    t = aten.t.default
+    addmm = aten.addmm.default  # for linear with bias
+    mm = aten.mm.default  # for linear without bias
+    relu = aten.relu.default
+    quantize_per_tensor = quantized_decomposed.quantize_per_tensor
+    dequantize_per_tensor = quantized_decomposed.dequantize_per_tensor
+    dequantize_per_channel = quantized_decomposed.dequantize_per_channel
+
+    unary_post_ops = {
+        'relu' : relu,
+    }
+    for name, unary_post_op in unary_post_ops.items():
+        for node in gm.graph.nodes:
+            if node.target in [addmm, mm]:
+                if node.target is addmm:
+                    (bias, x, w_t) = node.args
+                elif node.target is mm:
+                    (x, w_t) = node.args
+                (w,) = w_t.args
+                (qw, w_scale, w_zp, w_axis, w_quant_min, w_quant_max, w_dtype) = w.args
+                (qx, x_scale, x_zp, x_quant_min, x_quant_max, x_dtype) = x.args
+                post_unary_op_is_not_none = False
+                if list(node.users)[0].target is unary_post_op:
+                    # linear relu fusion
+                    unary_op_to_be_fused = list(node.users)[0]
+                    post_unary_op_is_not_none = True
+                    if list(unary_op_to_be_fused.users)[0].target != quantize_per_tensor:
+                        # Fusion pattern not matched: the op after unary_op is not quantize_per_tensor
+                        continue
+                    quant_per_tensor_node = list(unary_op_to_be_fused.users)[0]
+                elif list(node.users)[0].target is quantize_per_tensor:
+                    quant_per_tensor_node = list(node.users)[0]
+                else:
+                    # Fusion pattern not matched: the op after linear is neither the unary_op to be fused nor quantize_per_tensor
+                    continue
+                (y, y_scale, y_zp, y_quant_min, y_quant_max, y_dtype) = quant_per_tensor_node.args
+                with gm.graph.inserting_after(quant_per_tensor_node):
+                    args = (qx, x_scale, x_zp, qw, w_scale, w_zp, w_axis,
+                            bias if node.target is addmm else None, y_scale, y_zp, name if post_unary_op_is_not_none else "none")
+                    new_linear_node = gm.graph.call_function(quantized_decomposed.linear_unary_inductor, args=args)
+                    # Copy node meta
+                    new_linear_node.meta = copy.copy(quant_per_tensor_node.meta)
+                    quant_per_tensor_node.replace_all_uses_with(new_linear_node)
+
+                gm.graph.erase_node(quant_per_tensor_node)  # erase quantize_per_tensor
+                if post_unary_op_is_not_none:
+                    gm.graph.erase_node(unary_op_to_be_fused)  # erase unary_op
+                gm.graph.erase_node(node)  # erase addmm/mm
+                gm.graph.erase_node(w_t)  # erase transpose node
+                gm.graph.erase_node(w)  # erase dequantize_per_channel
+                gm.graph.erase_node(x)  # erase dequantize_per_tensor
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+    return gm
+
+def fuse_reference_quantized_linear_legacy(gm: torch.fx.GraphModule):
+    """
+    For experiment, use subgraph_rewriter.replace_pattern
+    Linear is decomposed to t + addmm (with bias) or t + mm (no bias)
     """
     aten = torch.ops.aten
     quantized_decomposed = torch.ops.quantized_decomposed
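
The key behavioral difference between the new passes and their `_legacy` counterparts is the `new_conv_node.meta = copy.copy(quant_per_tensor_node.meta)` / `new_linear_node.meta = copy.copy(quant_per_tensor_node.meta)` step: the hand-written rewrite copies `node.meta` (which carries `tensor_meta`) onto the fused node, while the `subgraph_rewriter`-based variants drop it. Below is a minimal, self-contained sketch of that meta-preserving replacement using only public `torch.fx` APIs; the `Toy` module and the `torch.add`-to-`torch.mul` replacement are illustrative assumptions, not the quantized conv/linear patterns this patch actually matches.

import copy

import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, x)


gm = torch.fx.symbolic_trace(Toy())
# Populate node.meta["tensor_meta"] on every node, as the quantization flow
# does before the fusion passes run.
ShapeProp(gm).propagate(torch.randn(2, 3))

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.add:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.mul, args=node.args)
        # The step the new passes add: carry the meta (including tensor_meta)
        # over to the replacement node so shape/dtype info survives the rewrite.
        new_node.meta = copy.copy(node.meta)
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()

for node in gm.graph.nodes:
    if node.target is torch.mul:
        assert "tensor_meta" in node.meta  # meta preserved on the fused node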