Skip to content

Commit

Permalink
Reformat and clean up some codes.
Browse files Browse the repository at this point in the history
  • Loading branch information
wenbingl committed Dec 20, 2019
1 parent f684aeb commit 672c5f7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 88 deletions.
83 changes: 0 additions & 83 deletions onnxconverter_common/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,47 +39,6 @@ def output_names(self):
raise NotImplementedError()


class SparkmlModelContainer(RawModelContainer):
    """Container for a Spark ML model.

    Spark ML models do not declare their inputs and outputs, so this
    container records them explicitly as variables are added during the
    conversion process.
    """

    def __init__(self, sparkml_model):
        super(SparkmlModelContainer, self).__init__(sparkml_model)
        # Inputs/outputs are discovered during conversion and stored here.
        self._inputs = []
        self._outputs = []

    @property
    def input_names(self):
        # Order matters: it determines the final model's input ordering.
        return [var.raw_name for var in self._inputs]

    @property
    def output_names(self):
        # Order matters: it determines the final model's output ordering.
        return [var.raw_name for var in self._outputs]

    def add_input(self, variable):
        # Ignore duplicates; insertion order of first occurrence is kept.
        if variable in self._inputs:
            return
        self._inputs.append(variable)

    def add_output(self, variable):
        # Ignore duplicates; insertion order of first occurrence is kept.
        if variable in self._outputs:
            return
        self._outputs.append(variable)


class CoremlModelContainer(RawModelContainer):
    """Container for a Core ML model.

    Unlike other raw-model containers, input and output names are read
    directly from the wrapped model's description rather than tracked
    separately.
    """

    def __init__(self, coreml_model):
        super(CoremlModelContainer, self).__init__(coreml_model)

    @property
    def input_names(self):
        # Names come straight from the Core ML model description.
        return [str(inp.name) for inp in self.raw_model.description.input]

    @property
    def output_names(self):
        return [str(out.name) for out in self.raw_model.description.output]


class CommonSklearnModelContainer(RawModelContainer):

def __init__(self, sklearn_model):
Expand Down Expand Up @@ -107,48 +66,6 @@ def add_output(self, variable):
self._outputs.append(variable)


class SklearnModelContainer(CommonSklearnModelContainer):
    """Container for a scikit-learn model; shares the common sklearn behavior."""
    pass


class LibSvmModelContainer(CommonSklearnModelContainer):
    """Container for a LibSVM model; shares the common sklearn behavior."""
    pass


class LightGbmModelContainer(CommonSklearnModelContainer):
    """Container for a LightGBM model; shares the common sklearn behavior."""
    pass


class XGBoostModelContainer(CommonSklearnModelContainer):
    """Container for an XGBoost model; shares the common sklearn behavior."""
    pass


class KerasModelContainer(RawModelContainer):
    """Container for a Keras model.

    Keras models expose tensor names rather than variable objects, so the
    raw input and output names are collected as plain strings while the
    model is being converted.
    """

    def __init__(self, keras_model):
        super(KerasModelContainer, self).__init__(keras_model)
        self._input_raw_names = []
        self._output_raw_names = []

    def add_input_name(self, name):
        # Duplicates are ignored; first-seen order is preserved and
        # determines the final model's input ordering.
        if name not in self._input_raw_names:
            self._input_raw_names.append(name)

    def add_output_name(self, name):
        # Duplicates are ignored; first-seen order is preserved and
        # determines the final model's output ordering.
        if name not in self._output_raw_names:
            self._output_raw_names.append(name)

    @property
    def input_names(self):
        # Return a fresh list so callers cannot mutate internal state.
        return list(self._input_raw_names)

    @property
    def output_names(self):
        # Return a fresh list so callers cannot mutate internal state.
        return list(self._output_raw_names)


class ModelComponentContainer(ModelContainer):
'''
In the conversion phase, this class is used to collect all materials required to build an ONNX GraphProto, which is
Expand Down
3 changes: 2 additions & 1 deletion onnxconverter_common/metadata_props.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def set_denotation(onnx_model, input_name, denotation, target_opset, dimension_d
if dimension_denotation:
dimensions = graph_input.type.tensor_type.shape.dim
if len(dimension_denotation) != len(dimensions):
raise RuntimeError('Wrong number of dimensions: input "{}" has {} dimensions'.format(input_name, len(dimensions)))
raise RuntimeError(
'Wrong number of dimensions: input "{}" has {} dimensions'.format(input_name, len(dimensions)))
for dimension, channel_denotation in zip(dimensions, dimension_denotation):
dimension.denotation = channel_denotation
return onnx_model
Expand Down
12 changes: 8 additions & 4 deletions onnxconverter_common/onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,8 @@ def apply_reshape(scope, input_name, output_name, container, operator_name=None,
desired_shape_name = desired_shape
else:
desired_shape_name = scope.get_unique_variable_name('shape_tensor')
container.add_initializer(desired_shape_name, onnx_proto.TensorProto.INT64, [len(desired_shape)], desired_shape)
container.add_initializer(desired_shape_name, onnx_proto.TensorProto.INT64, [len(desired_shape)],
desired_shape)

# Create ONNX Reshape operator
if isinstance(input_name, list):
Expand All @@ -571,7 +572,8 @@ def apply_reshape(scope, input_name, output_name, container, operator_name=None,
container.add_node('Reshape', input_name, output_name, op_version=5, name=name)


def apply_resize(scope, input_name, output_name, container, operator_name=None, mode='nearest', coordinate_transformation_mode='asymmetric', scales=None):
def apply_resize(scope, input_name, output_name, container, operator_name=None, mode='nearest',
coordinate_transformation_mode='asymmetric', scales=None):
'''
:param mode: "nearest" or "linear"
:param scales: a float tensor for scaling (upsampling or downsampling) all input dimensions
Expand Down Expand Up @@ -806,7 +808,8 @@ def apply_transpose(scope, input_name, output_name, container, operator_name=Non
container.add_node('Transpose', input_name, output_name, name=name, perm=perm)


def apply_upsample(scope, input_name, output_name, container, operator_name=None, mode='nearest', coordinate_transformation_mode='asymmetric', scales=None):
def apply_upsample(scope, input_name, output_name, container, operator_name=None, mode='nearest',
coordinate_transformation_mode='asymmetric', scales=None):
'''
:param mode: nearest or linear
:param scales: an integer list of scaling-up rate of all input dimensions
Expand Down Expand Up @@ -838,4 +841,5 @@ def apply_upsample(scope, input_name, output_name, container, operator_name=None
else:
# Upsample op is deprecated in ONNX opset 10
# We implement Upsample through Resize instead
apply_resize(scope, input_name, output_name, container, operator_name, mode, coordinate_transformation_mode, scales)
apply_resize(scope, input_name, output_name, container, operator_name, mode, coordinate_transformation_mode,
scales)

0 comments on commit 672c5f7

Please sign in to comment.