Skip to content

Commit

Permalink
Add apply_gather, apply_greater, apply_gru, etc (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiafatom authored Dec 20, 2019
1 parent 672c5f7 commit 82f08ed
Showing 1 changed file with 85 additions and 5 deletions.
90 changes: 85 additions & 5 deletions onnxconverter_common/onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,25 @@ def apply_floor(scope, input_name, output_name, container, operator_name=None):
_apply_unary_operation(scope, 'Floor', input_name, output_name, container, operator_name=operator_name)


def apply_flatten(scope, input_name, output_name, container, operator_name=None):
def apply_flatten(scope, input_name, output_name, container, operator_name=None, axis=1):
name = _create_name_or_use_existing_one(scope, 'Flatten', operator_name)
container.add_node('Flatten', input_name, output_name, name=name)
if container.target_opset < 9:
op_version = 1
elif container.target_opset < 11:
op_version = 9
else:
op_version = 11
container.add_node('Flatten', input_name, output_name, name=name, op_version=op_version, axis=axis)


def apply_gather(scope, input_names, output_name, container, operator_name=None, axis=0):
name = _create_name_or_use_existing_one(scope, 'Gather', operator_name)
if container.target_opset < 11:
op_version = 1
else:
op_version = 11

container.add_node('Gather', input_names, output_name, name=name, op_version=op_version, axis=axis)


def apply_gemm(scope, input_name, output_name, container, operator_name=None, alpha=1.0, beta=1.0,
Expand All @@ -362,6 +378,34 @@ def apply_gemm(scope, input_name, output_name, container, operator_name=None, al
container.add_node('Gemm', input_name, output_name, name=name, **attrs)


def apply_greater(scope, input_names, output_name, container, operator_name=None):
name = _create_name_or_use_existing_one(scope, 'Greater', operator_name)
if container.target_opset < 7:
op_version = 1
elif container.target_opset < 9:
op_version = 7
else:
op_version = 9

container.add_node('Greater', input_names, output_name, name=name, op_version=op_version)


def apply_gru(scope, input_names, output_names, container, operator_name=None, output_seq=0, reset_after=0, **attrs):
name = _create_name_or_use_existing_one(scope, 'GRU', operator_name)
if container.target_opset < 3:
op_version = 1
attrs['output_sequence'] = 1 if output_seq else 0
else:
attrs['linear_before_reset'] = 1 if reset_after else 0
if container.target_opset <= 5:
attrs['output_sequence'] = 1 if output_seq else 0
op_version = 3
else:
op_version = 7

container.add_node('GRU', input_names, output_names, name=name, op_version=op_version, **attrs)


def apply_hard_sigmoid(scope, input_name, output_name, container, operator_name=None, alpha=None, beta=None):
_apply_unary_operation(scope, 'HardSigmoid', input_name, output_name, container, operator_name,
alpha=alpha, beta=beta)
Expand Down Expand Up @@ -393,10 +437,24 @@ def apply_log(scope, input_name, output_name, container, operator_name=None):
_apply_unary_operation(scope, 'Log', input_name, output_name, container, operator_name=operator_name)


def apply_lstm(scope, input_names, output_names, container, operator_name=None, output_seq=0, **attrs):
name = _create_name_or_use_existing_one(scope, 'LSTM', operator_name)
if container.target_opset <= 6:
attrs['output_sequence'] = 1 if output_seq else 0
op_version = 1
else:
op_version = 7
container.add_node('LSTM', input_names, output_names, name=name, op_version=op_version, **attrs)


def apply_matmul(scope, input_names, output_name, container, operator_name=None):
op_type = 'MatMul'
name = _create_name_or_use_existing_one(scope, op_type, operator_name)
container.add_node(op_type, input_names, output_name, op_version=9, name=name)
if container.target_opset <= 9:
op_version = 1
else:
op_version = 9
container.add_node(op_type, input_names, output_name, op_version=op_version, name=name)


def apply_max(scope, input_names, output_name, container, operator_name=None):
Expand Down Expand Up @@ -425,7 +483,8 @@ def apply_normalization(scope, input_name, output_name, container, operator_name
container.add_node('LpNormalization', input_name, output_name, name=name, p=p, axis=axis)


def apply_pad(scope, input_name, output_name, container, operator_name=None, mode=None, pads=None, value=None):
def apply_pad(scope, input_name, output_name, container, operator_name=None, mode=None, pads=None, value=None,
onnx_type=onnx_proto.TensorProto.FLOAT):
name = _create_name_or_use_existing_one(scope, 'Pad', operator_name)
attrs = {'name': name}
inputs = input_name if isinstance(input_name, list) else [input_name]
Expand All @@ -449,7 +508,7 @@ def apply_pad(scope, input_name, output_name, container, operator_name=None, mod
inputs.append(pads_name)
if value is not None:
value_name = scope.get_unique_variable_name(name + '_value')
container.add_initializer(value_name, onnx_proto.TensorProto.FLOAT, [], [value])
container.add_initializer(value_name, onnx_type, [], [value])
inputs.append(value_name)

container.add_node('Pad', inputs, output_name, op_version=op_version, **attrs)
Expand Down Expand Up @@ -602,6 +661,16 @@ def apply_resize(scope, input_name, output_name, container, operator_name=None,
container.add_node('Resize', inputs, output_name, op_version=op_version, **attrs)


def apply_rnn(scope, input_names, output_names, container, operator_name=None, output_seq=0, **attrs):
name = _create_name_or_use_existing_one(scope, 'RNN', operator_name)
if container.target_opset <= 6:
attrs['output_sequence'] = 1 if output_seq else 0
op_version = 1
else:
op_version = 7
container.add_node('RNN', input_names, output_names, name=name, op_version=op_version, **attrs)


def apply_sigmoid(scope, input_name, output_name, container, operator_name=None):
_apply_unary_operation(scope, 'Sigmoid', input_name, output_name, container, operator_name)

Expand Down Expand Up @@ -843,3 +912,14 @@ def apply_upsample(scope, input_name, output_name, container, operator_name=None
# We implement Upsample through Resize instead
apply_resize(scope, input_name, output_name, container, operator_name, mode, coordinate_transformation_mode,
scales)


def apply_unsqueeze(scope, input_name, output_name, container, operator_name=None, axis=0, rank=0):
name = _create_name_or_use_existing_one(scope, 'Unsqueeze', operator_name)
if container.target_opset < 11:
op_version = 1
if axis < 0:
axis += rank + 1
else:
op_version = 11
container.add_node('Unsqueeze', input_name, output_name, name=name, op_version=op_version, axes=[axis])

0 comments on commit 82f08ed

Please sign in to comment.