diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index eb67cf24b81e8..5fc7c4ed68bcc 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -436,6 +436,24 @@ def _impl(inputs, attr, params): return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr) return _impl +def _assert(): + # ToDo: In general people want asserts to be gone from TensorFlow graphs + # when they are optimizing them, so converting it to a no-op is + # reasonable. However, it would be nice to have the option to keep them + # once Relay gets a Halt or Assert op. + return _no_op() + +def _no_op(): + def _impl(inputs, attr, params): + # ToDo: This should really be an op that returns nothing, which could + # be represented as an empty tuple. It turns out that TVM + # infrastructure doesn't like running functions that return None and + # also don't like running functions that return an empty tuple. So it + # doesn't work, but it should be made to work and then this could be + # improved. In the mean time, it is hard to imagine a case where it + # matters in any real way that a no-op is converted to a constant 0. + return tvm.relay.const(0) + return _impl def _matmul(): def _impl(inputs, attr, params): @@ -1319,6 +1337,7 @@ def _impl(inputs, attr, params): 'All' : _reduce('all'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), + 'Assert' : _assert(), 'AvgPool' : _pooling('avg_pool'), 'BatchMatMul' : _batch_matmul(), 'BatchMatMulV2' : _batch_matmul(), @@ -1377,6 +1396,7 @@ def _impl(inputs, attr, params): 'Mod' : _elemwise('mod'), 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), + 'NoOp' : _no_op(), 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), 'Pack' : _pack(), @@ -2189,8 +2209,11 @@ def _parse_param(self, key, value, name, shape): if np_array.dtype == np.dtype(object): # Object types are generally tensorflow DT_STRING (DecodeJpeg op). # Just leave it as placeholder. - self._nodes[name] = [_expr.var(name, shape=shape[name], dtype='uint8')] - + if shape: + var_shape = shape[name] + else: + var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape) + self._nodes[name] = [_expr.var(name, shape=var_shape, dtype='uint8')] return array_ndim = len(np_array.shape) diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py new file mode 100644 index 0000000000000..98e97d1c67018 --- /dev/null +++ b/tests/python/frontend/tensorflow/test_debugging.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for converting TensorFlow debugging ops to Relay.""" +import tensorflow as tf +import numpy as np +from tvm import relay +from tvm.relay.frontend.tensorflow import from_tensorflow + +def run_relay(graph, *vars): + mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + ex = relay.create_executor('debug', mod=mod) + return ex.evaluate()(*vars) + +def test_assert_true(): + g = tf.Graph() + with g.as_default(): + x = tf.placeholder(tf.float32, shape=()) + assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"]) + + with tf.Session() as sess: + x_value = np.random.rand() + assert sess.run(assert_op, feed_dict={x: x_value}) is None + + # In TVM, tf.assert is converted to a no-op which is actually a 0, + # though it should probably be none or an empty tuple. + # + # ToDo: It appears that the frontend converter gets confused here and + # entirely eliminates all operands from main(). Likely because x <= x + # is always true, so the placeholder can be eliminated. But TF doesn't + # do that, it's happening in Relay, and that optimization shouldn't + # affect the arity of the main function. We should have to pass in + # x_value here. + np.testing.assert_allclose(0, run_relay(g).asnumpy()) + +def test_assert_true_var_capture(): + g = tf.Graph() + with g.as_default(): + x = tf.placeholder(tf.float32, shape=()) + + # It turns out that tf.assert() creates a large and complex subgraph if + # you capture a variable as part of the error message. So we need to + # test that, too. + assert_op = tf.Assert(tf.less_equal(x, x), ["it failed", x]) + + with tf.Session() as sess: + x_value = np.random.rand() + assert sess.run(assert_op, feed_dict={x: x_value}) is None + + # ToDo: The frontend converter gets confused here as well, thinking + # that it needs to be told what x is twice. It also notes the output of + # the graph as a boolean, but this is actually correct, due to the + # strange way that TF creates graphdef from within the tf.assert() + # function. Though regardless the arity should be 1, not 2. + np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy()) + +def test_assert_false(): + g = tf.Graph() + with g.as_default(): + assert_op = tf.Assert(tf.constant(False), ["it failed"]) + + with tf.Session() as sess: + try: + print(sess.run(assert_op)) + assert False # TF should have thrown an exception + except tf.errors.InvalidArgumentError as e: + assert "it failed" in e.message + + # In TVM, tf.assert is converted to a no-op which is actually a 0, + # though it should probably be none or an empty tuple. For the same + # reason, there will not be an error here, even though the assertion + # argument is false. + np.testing.assert_allclose(0, run_relay(g).asnumpy()) + + +if __name__ == "__main__": + test_assert_true() + test_assert_true_var_capture() + test_assert_false() + diff --git a/tests/python/frontend/tensorflow/test_no_op.py b/tests/python/frontend/tensorflow/test_no_op.py new file mode 100644 index 0000000000000..0d09cf4b8949a --- /dev/null +++ b/tests/python/frontend/tensorflow/test_no_op.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for converting TensorFlow debugging ops to Relay.""" +import tensorflow as tf +import numpy as np +from tvm import relay +from tvm.relay.frontend.tensorflow import from_tensorflow + +def run_relay(graph): + mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) + ex = relay.create_executor('debug', mod=mod) + return ex.evaluate()(**params) + +def test_no_op(): + g = tf.Graph() + with g.as_default(): + no_op = tf.no_op() + with tf.Session() as sess: + # In TF, the type of a no-op is None. + assert sess.run(no_op) is None + + # In TVM, no-op is currently translated to 0, though it should + # probably be none or an empty tuple. + np.testing.assert_allclose(0, run_relay(g).asnumpy()) + + +if __name__ == "__main__": + test_no_op() +