Add support and testing for tf.assert (as no-op) and tf.no_op to TF Relay frontend.
broune committed Oct 22, 2019
1 parent 6f9d028 commit 55077d8
Showing 3 changed files with 161 additions and 2 deletions.
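In short, graphs containing Assert or NoOp nodes now import instead of failing with an unsupported-op error; both are lowered to a Relay constant 0 for now. A minimal end-to-end sketch of the new behavior (assuming TF 1.x sessions and the same from_tensorflow entry point the tests below use):

import tensorflow as tf
from tvm.relay.frontend.tensorflow import from_tensorflow

g = tf.Graph()
with g.as_default():
    # Both of these ops previously made the importer fail; they now
    # convert to a Relay constant 0 (a stand-in no-op).
    tf.Assert(tf.constant(True), ["never fails"])
    tf.no_op()

mod, params = from_tensorflow(g.as_graph_def(add_shapes=True))
print(mod)  # the Assert/NoOp nodes appear as constants in the module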
27 changes: 25 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
@@ -436,6 +436,24 @@ def _impl(inputs, attr, params):
        return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
    return _impl

def _assert():
    # TODO: In general, people want asserts gone from TensorFlow graphs
    # when they are optimizing them, so converting Assert to a no-op is
    # reasonable. However, it would be nice to have the option to keep
    # asserts once Relay gets a Halt or Assert op.
    return _no_op()

def _no_op():
    def _impl(inputs, attr, params):
        # TODO: This should really be an op that returns nothing, which
        # could be represented as an empty tuple. It turns out that TVM's
        # infrastructure doesn't like running functions that return None,
        # and it also doesn't like running functions that return an empty
        # tuple. So neither works yet; once that is fixed, this could be
        # improved. In the meantime, it is hard to imagine a case where it
        # matters in any real way that a no-op is converted to a constant 0.
        return tvm.relay.const(0)
    return _impl

def _matmul():
    def _impl(inputs, attr, params):
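As a sketch of the direction the TODO above points at: if the executors accepted zero-field tuples, the conversion could return an empty relay.Tuple instead. This is hypothetical (the name _no_op_empty_tuple is invented here); relay.Tuple([]) constructs fine today, but running a function that returns it is exactly what currently fails.

def _no_op_empty_tuple():
    # Hypothetical variant of _no_op: return an empty tuple instead of a
    # constant 0, assuming TVM's executors learn to run functions whose
    # result is a zero-field tuple.
    def _impl(inputs, attr, params):
        return tvm.relay.Tuple([])
    return _impl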
@@ -1319,6 +1337,7 @@ def _impl(inputs, attr, params):
    'All'                               : _reduce('all'),
    'ArgMax'                            : _argx(_op.argmax, 'argmax'),
    'ArgMin'                            : _argx(_op.argmin, 'argmin'),
    'Assert'                            : _assert(),
    'AvgPool'                           : _pooling('avg_pool'),
    'BatchMatMul'                       : _batch_matmul(),
    'BatchMatMulV2'                     : _batch_matmul(),
@@ -1377,6 +1396,7 @@ def _impl(inputs, attr, params):
    'Mod'                               : _elemwise('mod'),
    'Mul'                               : _elemwise('multiply'),
    'Neg'                               : AttrCvt('negative'),
    'NoOp'                              : _no_op(),
    'NotEqual'                          : _broadcast('not_equal'),
    'OneHot'                            : _one_hot(),
    'Pack'                              : _pack(),
@@ -2189,8 +2209,11 @@ def _parse_param(self, key, value, name, shape):
        if np_array.dtype == np.dtype(object):
            # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
            # Just leave it as placeholder.
            if shape:
                var_shape = shape[name]
            else:
                var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
            self._nodes[name] = [_expr.var(name, shape=var_shape, dtype='uint8')]
            return

        array_ndim = len(np_array.shape)
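For context, tensor_util.TensorShapeProtoToList comes from tensorflow.python.framework and simply flattens a TensorShapeProto into a Python list, so the fallback above recovers the placeholder's shape directly from the graph when no shape dict was supplied. A small self-contained check (the 2x3 shape is just an example value):

from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.framework import tensor_util

# Build a 2x3 TensorShapeProto by hand and flatten it to a list.
proto = tensor_shape_pb2.TensorShapeProto(dim=[
    tensor_shape_pb2.TensorShapeProto.Dim(size=2),
    tensor_shape_pb2.TensorShapeProto.Dim(size=3),
])
assert tensor_util.TensorShapeProtoToList(proto) == [2, 3]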
93 changes: 93 additions & 0 deletions tests/python/frontend/tensorflow/test_debugging.py
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
import tensorflow as tf
import numpy as np
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow

def run_relay(graph, *vars):
    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
    ex = relay.create_executor('debug', mod=mod)
    return ex.evaluate()(*vars)

def test_assert_true():
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=())
        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"])

    with tf.Session() as sess:
        x_value = np.random.rand()
        assert sess.run(assert_op, feed_dict={x: x_value}) is None

    # In TVM, tf.assert is converted to a no-op which is actually a 0,
    # though it should probably be None or an empty tuple.
    #
    # TODO: It appears that the frontend converter gets confused here and
    # eliminates all operands from main(), likely because x <= x is always
    # true, so the placeholder can be eliminated. But TF doesn't do that;
    # it happens in Relay, and that optimization shouldn't affect the
    # arity of the main function. Once fixed, x_value would have to be
    # passed in here.
    np.testing.assert_allclose(0, run_relay(g).asnumpy())

def test_assert_true_var_capture():
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(tf.float32, shape=())

        # It turns out that tf.Assert() creates a large and complex
        # subgraph if you capture a variable as part of the error message,
        # so we need to test that, too.
        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed", x])

    with tf.Session() as sess:
        x_value = np.random.rand()
        assert sess.run(assert_op, feed_dict={x: x_value}) is None

    # TODO: The frontend converter gets confused here as well, thinking
    # that it needs to be told what x is twice. It also reports the output
    # of the graph as a boolean, which is actually correct due to the
    # unusual way TF builds the graphdef inside tf.Assert(); regardless,
    # the arity should be 1, not 2.
    np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy())

def test_assert_false():
    g = tf.Graph()
    with g.as_default():
        assert_op = tf.Assert(tf.constant(False), ["it failed"])

    with tf.Session() as sess:
        try:
            print(sess.run(assert_op))
            assert False  # TF should have thrown an exception.
        except tf.errors.InvalidArgumentError as e:
            assert "it failed" in e.message

    # In TVM, tf.assert is converted to a no-op which is actually a 0,
    # though it should probably be None or an empty tuple. For the same
    # reason, there will be no error here, even though the assertion
    # argument is false.
    np.testing.assert_allclose(0, run_relay(g).asnumpy())


if __name__ == "__main__":
    test_assert_true()
    test_assert_true_var_capture()
    test_assert_false()

43 changes: 43 additions & 0 deletions tests/python/frontend/tensorflow/test_no_op.py
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow debugging ops to Relay."""
import tensorflow as tf
import numpy as np
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow

def run_relay(graph):
    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
    ex = relay.create_executor('debug', mod=mod)
    return ex.evaluate()(**params)

def test_no_op():
    g = tf.Graph()
    with g.as_default():
        no_op = tf.no_op()
    with tf.Session() as sess:
        # In TF, running a no-op returns None.
        assert sess.run(no_op) is None

    # In TVM, a no-op is currently translated to a constant 0, though it
    # should probably be None or an empty tuple.
    np.testing.assert_allclose(0, run_relay(g).asnumpy())


if __name__ == "__main__":
    test_no_op()
