Handle dynamic shape in apply_reshape (#36)
* Handle dynamic shape in apply_reshape
jiafatom authored Dec 5, 2019
1 parent b9700d1 commit b5f216b
Showing 1 changed file with 6 additions and 4 deletions.
onnxconverter_common/onnx_ops.py (10 changes: 6 additions & 4 deletions)
@@ -548,7 +548,7 @@ def apply_relu(scope, input_name, output_name, container, operator_name=None):
 
 
 def apply_reshape(scope, input_name, output_name, container, operator_name=None, desired_shape=None):
-    if len(list(i for i in desired_shape if i is not None and i < 0)) > 1:
+    if not isinstance(desired_shape, str) and len(list(i for i in desired_shape if i is not None and i < 0)) > 1:
         raise ValueError('There can only be one -1 in the targeted shape of a Reshape but got %s' % desired_shape)
 
     name = _create_name_or_use_existing_one(scope, 'Reshape', operator_name)
@@ -557,9 +557,11 @@ def apply_reshape(scope, input_name, output_name, container, operator_name=None,
         container.add_node('Reshape', input_name, output_name, op_version=1, name=name, shape=desired_shape,
                            consumed_inputs=[0])
     else:
-        # The shape attribute of Reshape becomes a tensor input, so we create one tensor to store that attribute.
-        desired_shape_name = scope.get_unique_variable_name('shape_tensor')
-        container.add_initializer(desired_shape_name, onnx_proto.TensorProto.INT64, [len(desired_shape)], desired_shape)
+        if isinstance(desired_shape, str):
+            desired_shape_name = desired_shape
+        else:
+            desired_shape_name = scope.get_unique_variable_name('shape_tensor')
+            container.add_initializer(desired_shape_name, onnx_proto.TensorProto.INT64, [len(desired_shape)], desired_shape)
 
     # Create ONNX Reshape operator
     if isinstance(input_name, list):
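
For context, a minimal usage sketch (not part of the commit; the scope and container objects come from whatever converter is calling into onnxconverter_common, and the output and tensor names below are invented for illustration). With this change, desired_shape may be either a concrete list of ints or the name of a shape tensor that is only known at runtime; in the latter case apply_reshape uses that name directly as the Reshape shape input instead of trying to build an initializer from it.

# Sketch only, assuming the caller already has a scope and container;
# names such as 'runtime_shape_tensor' are hypothetical.
from onnxconverter_common.onnx_ops import apply_reshape


def emit_reshapes(scope, container, input_name):
    # Static target shape: apply_reshape creates an INT64 initializer holding
    # [-1, 64] and (for opset >= 5) passes it as the second input of Reshape.
    apply_reshape(scope, input_name, 'static_reshaped', container,
                  desired_shape=[-1, 64])

    # Dynamic target shape: the string names a tensor that already exists in
    # the graph (e.g. assembled from Shape/Concat nodes during conversion),
    # so the -1 check is skipped and the name is wired in as the shape input.
    apply_reshape(scope, input_name, 'dynamic_reshaped', container,
                  desired_shape='runtime_shape_tensor')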
