Skip to content

Commit

Permalink
Add apply_squeeze (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiafatom authored Dec 26, 2019
1 parent 82f08ed commit 98ee169
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions onnxconverter_common/onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,21 @@ def apply_sqrt(scope, input_name, output_name, container, operator_name=None):
_apply_unary_operation(scope, 'Sqrt', input_name, output_name, container, operator_name=operator_name)


def _apply_squeeze_unsqueeze(scope, input_name, output_name, container, squeeze_str, operator_name=None, axis=0, rank=0):
name = _create_name_or_use_existing_one(scope, squeeze_str, operator_name)
if container.target_opset < 11:
op_version = 1
if axis < 0:
axis += rank + 1
else:
op_version = 11
container.add_node(squeeze_str, input_name, output_name, name=name, op_version=op_version, axes=[axis])


def apply_squeeze(scope, input_name, output_name, container, operator_name=None, axis=0, rank=0):
_apply_squeeze_unsqueeze(scope, input_name, output_name, container, 'Squeeze', operator_name, axis, rank)


def apply_sub(scope, input_names, output_name, container, operator_name=None, axis=None, broadcast=0):
_apply_basic_numerical_operation(scope, 'Sub', input_names, output_name, container, operator_name=operator_name,
axis=axis, broadcast=broadcast)
Expand Down Expand Up @@ -915,11 +930,4 @@ def apply_upsample(scope, input_name, output_name, container, operator_name=None


def apply_unsqueeze(scope, input_name, output_name, container, operator_name=None, axis=0, rank=0):
name = _create_name_or_use_existing_one(scope, 'Unsqueeze', operator_name)
if container.target_opset < 11:
op_version = 1
if axis < 0:
axis += rank + 1
else:
op_version = 11
container.add_node('Unsqueeze', input_name, output_name, name=name, op_version=op_version, axes=[axis])
_apply_squeeze_unsqueeze(scope, input_name, output_name, container, 'Unsqueeze', operator_name, axis, rank)

0 comments on commit 98ee169

Please sign in to comment.