-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support and testing for tf.assert (as no-op) and tf.no_op to TF R…
…elay frontend.
- Loading branch information
Showing
3 changed files
with
161 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Unit tests for converting TensorFlow debugging ops to Relay.""" | ||
import tensorflow as tf | ||
import numpy as np | ||
from tvm import relay | ||
from tvm.relay.frontend.tensorflow import from_tensorflow | ||
|
||
def run_relay(graph, *vars): | ||
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) | ||
ex = relay.create_executor('debug', mod=mod) | ||
return ex.evaluate()(*vars) | ||
|
||
def test_assert_true(): | ||
g = tf.Graph() | ||
with g.as_default(): | ||
x = tf.placeholder(tf.float32, shape=()) | ||
assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"]) | ||
|
||
with tf.Session() as sess: | ||
x_value = np.random.rand() | ||
assert sess.run(assert_op, feed_dict={x: x_value}) is None | ||
|
||
# In TVM, tf.assert is converted to a no-op which is actually a 0, | ||
# though it should probably be none or an empty tuple. | ||
# | ||
# ToDo: It appears that the frontend converter gets confused here and | ||
# entirely eliminates all operands from main(). Likely because x <= x | ||
# is always true, so the placeholder can be eliminated. But TF doesn't | ||
# do that, it's happening in Relay, and that optimization shouldn't | ||
# affect the arity of the main function. We should have to pass in | ||
# x_value here. | ||
np.testing.assert_allclose(0, run_relay(g).asnumpy()) | ||
|
||
def test_assert_true_var_capture(): | ||
g = tf.Graph() | ||
with g.as_default(): | ||
x = tf.placeholder(tf.float32, shape=()) | ||
|
||
# It turns out that tf.assert() creates a large and complex subgraph if | ||
# you capture a variable as part of the error message. So we need to | ||
# test that, too. | ||
assert_op = tf.Assert(tf.less_equal(x, x), ["it failed", x]) | ||
|
||
with tf.Session() as sess: | ||
x_value = np.random.rand() | ||
assert sess.run(assert_op, feed_dict={x: x_value}) is None | ||
|
||
# ToDo: The frontend converter gets confused here as well, thinking | ||
# that it needs to be told what x is twice. It also notes the output of | ||
# the graph as a boolean, but this is actually correct, due to the | ||
# strange way that TF creates graphdef from within the tf.assert() | ||
# function. Though regardless the arity should be 1, not 2. | ||
np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy()) | ||
|
||
def test_assert_false(): | ||
g = tf.Graph() | ||
with g.as_default(): | ||
assert_op = tf.Assert(tf.constant(False), ["it failed"]) | ||
|
||
with tf.Session() as sess: | ||
try: | ||
print(sess.run(assert_op)) | ||
assert False # TF should have thrown an exception | ||
except tf.errors.InvalidArgumentError as e: | ||
assert "it failed" in e.message | ||
|
||
# In TVM, tf.assert is converted to a no-op which is actually a 0, | ||
# though it should probably be none or an empty tuple. For the same | ||
# reason, there will not be an error here, even though the assertion | ||
# argument is false. | ||
np.testing.assert_allclose(0, run_relay(g).asnumpy()) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_assert_true() | ||
test_assert_true_var_capture() | ||
test_assert_false() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Unit tests for converting TensorFlow debugging ops to Relay.""" | ||
import tensorflow as tf | ||
import numpy as np | ||
from tvm import relay | ||
from tvm.relay.frontend.tensorflow import from_tensorflow | ||
|
||
def run_relay(graph): | ||
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) | ||
ex = relay.create_executor('debug', mod=mod) | ||
return ex.evaluate()(**params) | ||
|
||
def test_no_op(): | ||
g = tf.Graph() | ||
with g.as_default(): | ||
no_op = tf.no_op() | ||
with tf.Session() as sess: | ||
# In TF, the type of a no-op is None. | ||
assert sess.run(no_op) is None | ||
|
||
# In TVM, no-op is currently translated to 0, though it should | ||
# probably be none or an empty tuple. | ||
np.testing.assert_allclose(0, run_relay(g).asnumpy()) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_no_op() | ||
|