fix the op fusion issue to keep TensorMeta (pytorch#25)
* fix the op fusion issue

* add UT

* lint

* add linear cases

* lint
leslie-fang-intel authored Feb 23, 2023
1 parent 9a2d0c7 commit ba317c8
Showing 2 changed files with 213 additions and 0 deletions.
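
The gist of the change: instead of subgraph_rewriter-based replacement (the *_legacy passes), the conv and linear fusion passes below rewrite the graph node by node and copy node.meta from the replaced quantize_per_tensor node onto the new fused node, so the FX TensorMeta survives fusion. A minimal sketch of that mechanism, assuming a generic old node and replacement target (the helper name and arguments are illustrative, not part of the patch):

import copy
import torch

def replace_node_keeping_meta(gm: torch.fx.GraphModule, old_node: torch.fx.Node,
                              new_target, new_args):
    # Insert the replacement next to the node being replaced and copy its meta
    # dict, so entries such as "tensor_meta" are preserved, mirroring what the
    # fusion passes in torch/_inductor/overrides.py do for the fused node.
    with gm.graph.inserting_after(old_node):
        new_node = gm.graph.call_function(new_target, args=new_args)
        new_node.meta = copy.copy(old_node.meta)
    old_node.replace_all_uses_with(new_node)
    gm.graph.erase_node(old_node)
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm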
63 changes: 63 additions & 0 deletions test/quantization/fx/test_quantize_pt2e.py
@@ -405,6 +405,40 @@ def forward(self, x):
for with_bias, use_relu in cases:
self._test_inductor_backend_helper(Mod(with_bias, use_relu), input_shape)

def test_conv2d_relu_conv2d_inductor_backend(self):
'''
Test to ensure the TensorMeta is kept after fusion
'''
class Mod(torch.nn.Module):
def __init__(self, with_bias: bool) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3,
out_channels=16,
kernel_size=3,
stride=1,
padding=1,
bias=with_bias
)
self.relu = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(
in_channels=16,
out_channels=16,
kernel_size=3,
stride=1,
padding=1,
bias=with_bias
)

def forward(self, x):
x = self.conv(x)
return self.conv2(self.relu(x))

input_shape = (1, 3, 16, 16)
with_bias_list = [True, False]
for with_bias in with_bias_list:
self._test_inductor_backend_helper(Mod(with_bias), input_shape)

def test_conv3d_inductor_backend(self):
'''
Quantize and lower convolution 3d + relu with Inductor quantization backend.
@@ -461,3 +495,32 @@ def forward(self, x):
cases = itertools.product(with_bias_list, use_relu_list)
for with_bias, use_relu in cases:
self._test_inductor_backend_helper(Mod(with_bias, use_relu), input_shape)

def test_linear_relu_linear_inductor_backend(self):
'''
Test to ensure the TensorMeta is kept after linear fusion.
For experiment.
'''
class Mod(torch.nn.Module):
def __init__(self, with_bias: bool) -> None:
super().__init__()
self.linear = torch.nn.Linear(
in_features=16,
out_features=8,
bias=with_bias
)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(
in_features=8,
out_features=8,
bias=with_bias
)

def forward(self, x):
x = self.linear(x)
return self.linear2(self.relu(x))

input_shape = (1, 16)
with_bias_list = [True, False]
for with_bias in with_bias_list:
self._test_inductor_backend_helper(Mod(with_bias), input_shape)
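
The actual assertions live in _test_inductor_backend_helper, which is outside this diff. As a rough standalone illustration of the property these two tests guard, assuming gm is the lowered GraphModule produced for the model above (the metadata keys checked here are an assumption, not taken from the patch):

# Hypothetical post-fusion check: every fused op should still carry FX tensor metadata.
for node in gm.graph.nodes:
    if node.target in (torch.ops.quantized_decomposed.conv_unary_inductor,
                       torch.ops.quantized_decomposed.linear_unary_inductor):
        assert "tensor_meta" in node.meta or "val" in node.meta, \
            f"{node} lost its TensorMeta after fusion"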
150 changes: 150 additions & 0 deletions torch/_inductor/overrides.py
@@ -812,6 +812,79 @@ def pre_quantize_weights(gm: torch.fx.GraphModule):
def fuse_reference_quantized_conv(gm: torch.fx.GraphModule):
"""
For experiment
Replace pattern:
# dequantize_per_channel -
# dequantize_per_tensor - conv - post_op(or none) - quantize_per_tensor
into new pattern:
# torch.ops.quantized_decomposed.conv_unary_inductor
"""
aten = torch.ops.aten
quantized_decomposed = torch.ops.quantized_decomposed
convolution = aten.convolution.default
relu = aten.relu.default
relu_ = aten.relu_.default
quantize_per_tensor = quantized_decomposed.quantize_per_tensor
dequantize_per_tensor = quantized_decomposed.dequantize_per_tensor
dequantize_per_channel = quantized_decomposed.dequantize_per_channel

unary_post_ops = {
'relu' : relu,
'relu_' : relu_,
}
for name, unary_post_op in unary_post_ops.items():
for node in gm.graph.nodes:
if node.target is convolution:
(x, w, bias, stride, padding, dilation, is_transposed, out_padding, groups) = node.args
assert x.target == dequantize_per_tensor, "input node should be dequantize_per_tensor"
assert w.target == dequantize_per_channel, "weight node should be dequantize_per_channel"
(qx, x_scale, x_zp, x_quant_min, x_quant_max, x_dtype) = x.args
(qw, w_scale, w_zp, w_axis, w_quant_min, w_quant_max, w_dtype) = w.args
if len(list(node.users)) != 1:
# More than one user of this conv node; skip fusion
continue

post_unary_op_is_not_none = False
if list(node.users)[0].target is unary_post_op:
# conv relu fusion
unary_op_to_be_fused = list(node.users)[0]
post_unary_op_is_not_none = True
if list(unary_op_to_be_fused.users)[0].target != quantize_per_tensor:
# Fusion pattern not matched: the op after the unary op is not quantize_per_tensor
continue
quant_per_tensor_node = list(unary_op_to_be_fused.users)[0]
elif list(node.users)[0].target is quantize_per_tensor:
# Single conv without post op
quant_per_tensor_node = list(node.users)[0]
else:
# Fusion pattern not matched: the op after conv is neither a fusable unary op nor quantize_per_tensor
continue

(y, y_scale, y_zp, y_quant_min, y_quant_max, y_dtype) = quant_per_tensor_node.args
with gm.graph.inserting_after(quant_per_tensor_node):
args = (qx, x_scale, x_zp, qw, w_scale, w_zp, w_axis,
bias, stride, padding, dilation, groups, y_scale,
y_zp, name if post_unary_op_is_not_none else "none")
new_conv_node = gm.graph.call_function(quantized_decomposed.conv_unary_inductor, args=args)
# Copy node meta
new_conv_node.meta = copy.copy(quant_per_tensor_node.meta)
quant_per_tensor_node.replace_all_uses_with(new_conv_node)

gm.graph.erase_node(quant_per_tensor_node) # erase quantize_per_tensor
if post_unary_op_is_not_none:
gm.graph.erase_node(unary_op_to_be_fused) # erase unary_op
gm.graph.erase_node(node) # erase conv
gm.graph.erase_node(w) # erase dequantize_per_channel
gm.graph.erase_node(x) # erase dequantize_per_tensor

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
return gm
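
A usage sketch for the pass above; how the reference-quantized GraphModule gm is produced (the surrounding PT2E/Inductor lowering) is not shown in this diff and is assumed here:

# gm: FX GraphModule containing dequantize -> aten.convolution -> (relu) -> quantize chains
gm = fuse_reference_quantized_conv(gm)
for node in gm.graph.nodes:
    if node.target is torch.ops.quantized_decomposed.conv_unary_inductor:
        # meta was copied from the replaced quantize_per_tensor node, so shape and
        # dtype information is still available to later Inductor passes
        print(node.name, node.meta.get("tensor_meta"))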

def fuse_reference_quantized_conv_legacy(gm: torch.fx.GraphModule):
"""
For experiment
The legacy method using subgraph_rewriter, which loses the tensor meta information after fusion
"""
aten = torch.ops.aten
quantized_decomposed = torch.ops.quantized_decomposed
@@ -869,6 +942,83 @@ def fuse_reference_quantized_linear(gm: torch.fx.GraphModule):
"""
For experiment
Linear is decomposed to t + addmm (with bias) or t + mm (no bias)
Case1: Replace pattern:
# dequantize_per_channel - t -
# dequantize_per_tensor - addmm - post_op(or none) - quantize_per_tensor
# bias -
into new pattern:
# torch.ops.quantized_decomposed.linear_unary_inductor
Case2: Replace pattern:
# dequantize_per_channel - t -
# dequantize_per_tensor - mm - post_op(or none) - quantize_per_tensor
into new pattern:
# torch.ops.quantized_decomposed.linear_unary_inductor
"""
aten = torch.ops.aten
quantized_decomposed = torch.ops.quantized_decomposed
t = aten.t.default
addmm = aten.addmm.default # for linear with bias
mm = aten.mm.default # for linear without bias
relu = aten.relu.default
quantize_per_tensor = quantized_decomposed.quantize_per_tensor
dequantize_per_tensor = quantized_decomposed.dequantize_per_tensor
dequantize_per_channel = quantized_decomposed.dequantize_per_channel

unary_post_ops = {
'relu' : relu,
}
for name, unary_post_op in unary_post_ops.items():
for node in gm.graph.nodes:
if node.target in [addmm, mm]:
if node.target is addmm:
(bias, x, w_t) = node.args
elif node.target is mm:
(x, w_t) = node.args
(w,) = w_t.args
(qw, w_scale, w_zp, w_axis, w_quant_min, w_quant_max, w_dtype) = w.args
(qx, x_scale, x_zp, x_quant_min, x_quant_max, x_dtype) = x.args
post_unary_op_is_not_none = False
if list(node.users)[0].target is unary_post_op:
# linear relu fusion
unary_op_to_be_fused = list(node.users)[0]
post_unary_op_is_not_none = True
if list(unary_op_to_be_fused.users)[0].target != quantize_per_tensor:
# Fusion pattern not matched: the op after the unary op is not quantize_per_tensor
continue
quant_per_tensor_node = list(unary_op_to_be_fused.users)[0]
elif list(node.users)[0].target is quantize_per_tensor:
quant_per_tensor_node = list(node.users)[0]
else:
# Fusion pattern not matched: the op after linear is neither a fusable unary op nor quantize_per_tensor
continue
(y, y_scale, y_zp, y_quant_min, y_quant_max, y_dtype) = quant_per_tensor_node.args
with gm.graph.inserting_after(quant_per_tensor_node):
args = (qx, x_scale, x_zp, qw, w_scale, w_zp, w_axis,
bias if node.target is addmm else None, y_scale, y_zp, name if post_unary_op_is_not_none else "none")
new_linear_node = gm.graph.call_function(quantized_decomposed.linear_unary_inductor, args=args)
# Copy node meta
new_linear_node.meta = copy.copy(quant_per_tensor_node.meta)
quant_per_tensor_node.replace_all_uses_with(new_linear_node)

gm.graph.erase_node(quant_per_tensor_node) # erase quantize_per_tensor
if post_unary_op_is_not_none:
gm.graph.erase_node(unary_op_to_be_fused) # erase unary_op
gm.graph.erase_node(node) # erase addmm/mm
gm.graph.erase_node(w_t) # erase transposed node
gm.graph.erase_node(w) # erase dequantize_per_channel
gm.graph.erase_node(x) # erase dequantize_per_tensor

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
return gm
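
As the docstring notes, the matcher relies on nn.Linear being decomposed into aten.t plus aten.addmm (with bias) or aten.mm (without bias). A quick way to observe that decomposition outside the quantization flow, as a hedged illustration only (make_fx tracing here stands in for however the PT2E pipeline actually produces the graph):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

lin = torch.nn.Linear(16, 8, bias=True)
g = make_fx(lin)(torch.randn(1, 16))
# The printed graph shows aten.t.default on the weight followed by
# aten.addmm.default(bias, x, t); with bias=False it is aten.mm.default instead.
print(g.graph)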

def fuse_reference_quantized_linear_legacy(gm: torch.fx.GraphModule):
"""
For experiment, use subgraph_rewriter.replace_pattern
Linear is decomposed to t + addmm (with bias) or t + mm (no bias)
"""
aten = torch.ops.aten
quantized_decomposed = torch.ops.quantized_decomposed
