Skip to content

Commit

Permalink
Fixed _infer_reduce_common arguments (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ai-Albert authored and cjm715 committed Nov 21, 2024
1 parent a11878d commit 559b470
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,7 +1664,7 @@ def _infer_ReduceSum(self, node): # noqa: N802
keep_dims = get_attribute(node, "keepdims", 1)
if get_opset(self.out_mp_) >= 13 and len(node.input) > 1:
# ReduceSum changes axes to input[1] in opset 13
self._infer_reduce_common(node)
self._infer_reduce_common(node, keep_dims)

def _infer_ReduceProd(self, node): # noqa: N802
axes = get_attribute(node, "axes")
Expand All @@ -1678,9 +1678,9 @@ def _infer_ReduceMean(self, node):
keep_dims = get_attribute(node, "keepdims", 1)
if get_opset(self.out_mp_) >= 18 and len(node.input) > 1:
# ReduceMean changes axes to input[1] in opset 18
self._infer_reduce_common(node)
self._infer_reduce_common(node, keep_dims)

def _infer_reduce_common(self, node):
def _infer_reduce_common(self, node, keep_dims):
axes = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
if axes is None:
Expand Down

0 comments on commit 559b470

Please sign in to comment.