Skip to content

Commit

Permalink
Merge pull request #471 from onnx/gs/fix-random-uniform
Browse files Browse the repository at this point in the history
fix dynamic shape in tf.random_uniform for some cases
  • Loading branch information
guschmue authored Apr 22, 2019
2 parents d03e469 + 8d29126 commit 27f5e67
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 14 deletions.
27 changes: 27 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,6 +1181,33 @@ def test_randomuniform_int(self):
# since results are random, compare the shapes only
self._run_test_case([_OUTPUT], {}, check_value=False, check_shape=True)

@skip_caffe2_backend()
def test_randomuniform_dyn_shape(self):
# test for dynamic shape coming from a shape op
x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
x = tf.placeholder(x_val.dtype, name=_TFINPUT)
x_ = tf.stack([x, x])
x_ = tf.identity(x_)
x_ = tf.shape(x_, name="shape")
x_ = tf.random_uniform(x_, name="rand", dtype=tf.float32)
x_ = tf.identity(x_)
_ = tf.identity(x_, name=_TFOUTPUT)
# since results are random, compare the shapes only
self._run_test_case([_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)

@skip_caffe2_backend()
def test_randomuniform_calc_shape(self):
# test for dynamic shape coming from some subgraph
x_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
x = tf.placeholder(x_val.dtype, [None, 3], name=_TFINPUT)
x_ = tf.identity(x)
x_ = tf.shape(x_, name="shape")[1:]
x_ = tf.random_uniform(x_, name="rand", dtype=tf.float32)
x_ = tf.identity(x_)
_ = tf.identity(x_, name=_TFOUTPUT)
# since results are random, compare the shapes only
self._run_test_case([_OUTPUT], {_INPUT: x_val}, check_value=False, check_shape=True)

@skip_caffe2_backend()
def test_argminmax(self):
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
Expand Down
54 changes: 40 additions & 14 deletions tf2onnx/rewriter/random_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
"""
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx random_uniform op
"""
import numpy as np
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
from tf2onnx import utils
from tf2onnx import utils, handler


# pylint: disable=missing-docstring
Expand All @@ -29,10 +30,10 @@ def rewrite_random_uniform(g, ops):
# max is on input 0
tmax = input2.inputs[0].get_tensor_value()
tmin = input2.inputs[1].get_tensor_value()

new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output)
to_delete = list(set(match.get_nodes()))
new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
for n in set(match.get_nodes()):
for n in to_delete:
g.remove_node(n.name)

return ops
Expand All @@ -59,25 +60,50 @@ def rewrite_random_uniform_fold_const(g, ops):
tmax_minus_tmin = mul.inputs[1].get_tensor_value()
tmin = output.inputs[1].get_tensor_value()
tmax = tmin + tmax_minus_tmin
new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output)
to_delete = list(set(match.get_nodes()))
new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
for n in set(match.get_nodes()):
for n in to_delete:
g.remove_node(n.name)

return ops


def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output):
def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete):
dtype = g.get_dtype(output.output[0])
op_name = utils.make_name("RandomUniform")
if ru_op.inputs[0].type == "Shape":
shape_node = ru_op.inputs[0]
new_node = g.make_node("RandomUniformLike", inputs=[shape_node.input[0]], name=op_name,
attr={"low": tmin, "high": tmax, "dtype": dtype},
shapes=shape_node.output_shapes, dtypes=[dtype])
else:
shape = g.get_shape(output.output[0])
shape_node = ru_op.inputs[0]
shape = g.get_shape(output.output[0])
if shape_node.is_const():
# if the tensorflow input (aka the shape) is const we can use the RandomUniform op
new_node = g.make_node("RandomUniform", [], name=op_name,
attr={"low": tmin, "high": tmax, "dtype": dtype, "shape": shape},
shapes=[shape], dtypes=[dtype])
else:
if shape_node.type == "Shape":
# if shape is dynamic - in tensorflow shape comes as tensor VALUE,
# in onnx RandomUniformLike finds takes the shape from the tensor itself.
# In many cases there is a shape op in tensorflow before RandomUniform and
# to make that work for onnx we just need to remove the shape op.
new_node = g.make_node("RandomUniformLike", inputs=[shape_node.input[0]], name=op_name,
attr={"low": tmin, "high": tmax, "dtype": dtype},
shapes=shape, dtypes=[dtype])
else:
# if the shape is calculated we need to create a tensor so RandomUniformLike
# can take the shape from there. Pre opset9 this is somewhat hacky because there is
# no real fill op in onnx. In general this is not going to help performance but the tensors
# created are expected to be small.

# tell the caller to not delete the shape node
to_delete.remove(shape_node)
# create a fill op with the shape of the value of the input tensor
zero = g.make_const(utils.make_name("zero"), np.zeros((), dtype=np.float32))
fill_node = g.make_node("Fill", inputs=[shape_node.output[0], zero.name],
shapes=shape, dtypes=[dtype])
func, _ = handler.tf_op.find_effective_op("Fill")
func(g, fill_node)
# and use RandomUniformLike to create the random tensor
new_node = g.make_node("RandomUniformLike", inputs=[fill_node.output[0]], name=op_name,
attr={"low": tmin, "high": tmax, "dtype": dtype},
shapes=shape, dtypes=[dtype])
return new_node

0 comments on commit 27f5e67

Please sign in to comment.