
Public API to force load custom ops #1151

Closed
guillaumekln opened this issue Feb 25, 2020 · 3 comments · Fixed by #1193

Comments

@guillaumekln
Contributor

guillaumekln commented Feb 25, 2020

Currently, it is inconvenient to load in Python a SavedModel that includes Addons custom ops. Consider the example below:

  • save.py
import tensorflow as tf
import tensorflow_addons as tfa

class Model(tf.keras.Model):
    @tf.function(input_signature=(tf.TensorSpec(shape=[None, 32], dtype=tf.float32),))
    def call(self, x):
        return tfa.activations.gelu(x)

model = Model()
tf.saved_model.save(model, '/tmp/model', signatures=model.call)
  • load.py
import tensorflow as tf
tf.saved_model.load("/tmp/model")

The load will fail because the Addons custom ops are not registered with the TensorFlow runtime. This is expected: tf.load_op_library must first be invoked on the custom op libraries.

However, with the new work on lazy loading (#855) it has become harder to force this op registration. For this model, the user must currently run the following, which relies on internal APIs:

from tensorflow_addons.activations.gelu import _activation_so
_activation_so.ops

If the custom ops are not loaded during the main import (i.e. during import tensorflow_addons), then the package should expose a public API that registers all custom ops.

Any thoughts?

@gabrieldemarmiesse
Member

gabrieldemarmiesse commented Feb 25, 2020

That's definitely a big problem.

I'd be in favor of avoiding an explicit function to force the loading of the ops. It's not great UX, as users will likely expect this to be automatic.

We could try to load all the shared objects (.so files) at import time and throw a warning if something goes wrong, with a mechanism similar to #1137. But then we should have a way of disabling the warnings for users who don't care about custom ops.
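The warn-and-continue idea could be sketched roughly like this. This is a minimal illustration, not TFA's implementation: the helper name and the indirection through load_fn are hypothetical, used so the sketch runs without TensorFlow installed; in practice load_fn would be tf.load_op_library and the caught error would be tf.errors.NotFoundError.

```python
import warnings


def try_load_op_library(load_fn, path):
    """Attempt to load a custom op shared object; warn instead of raising.

    load_fn is a stand-in for tf.load_op_library (hypothetical indirection
    so this sketch runs without TensorFlow). Returns the loaded library,
    or None if loading failed.
    """
    try:
        return load_fn(path)
    except Exception as exc:  # tf.errors.NotFoundError in practice
        warnings.warn(
            f"Could not load custom op library {path!r}; "
            f"custom ops from it will be unavailable: {exc}"
        )
        return None
```

Disabling the warning would then just be a matter of the usual warnings filters (or a package-level flag), rather than a hard failure at import time.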

In a perfect world, one would be able to save a model using the .so, and another user could reload the same model using the python-only equivalent op. Is that possible? Maybe with the keras model saving?

EDIT: Actually, I'm not so sure registering everything at import time is the best UX. If we made a function to register all ops, would it also register the Keras objects, i.e. the register_keras_serializable part?

@guillaumekln
Contributor Author

From the user's perspective, loading everything at import time sounds like the least surprising behavior. In the example above,

import tensorflow as tf
tf.saved_model.load("/tmp/model")

adding the Addons import would seem a natural fix for the SavedModel loading issue:

import tensorflow as tf
import tensorflow_addons as tfa
tf.saved_model.load("/tmp/model")

We can try to load all the SO at import time and then throw a warning if something goes wrong?

IIRC, the custom op errors we faced were mostly segmentation faults on import, right? In that case, it would be difficult to catch the error. We could fork the process to load the custom op safely, but that seems overkill.

Or we could load custom ops by default and add an option to disable this automatic loading for users facing issues.
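The opt-out idea could look something like the following sketch. The environment variable name and helper are hypothetical (the real flag, if adopted, would be chosen by the maintainers); the point is only that the automatic loading path checks a user-controlled switch before touching any .so file.

```python
import os

# Hypothetical opt-out variable name, for illustration only.
_DISABLE_VAR = "TFA_DISABLE_CUSTOM_OP_LOADING"


def should_auto_load_custom_ops(environ=None):
    """Return True unless the user explicitly opted out of loading .so files.

    environ defaults to os.environ; it is a parameter here so the
    behavior is easy to test without mutating the real environment.
    """
    environ = os.environ if environ is None else environ
    return environ.get(_DISABLE_VAR, "") not in ("1", "true", "True")
```

At import time, tensorflow_addons would call should_auto_load_custom_ops() and skip tf.load_op_library entirely when it returns False.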

@gabrieldemarmiesse
Member

When loading a .so, if there is any kind of issue, TensorFlow throws a tensorflow.errors.NotFoundError rather than segfaulting. So it's easy to emit a warning and continue if there is a problem at import time.

I believe we need to discuss this statement further:

From the user perspective, loading everything at import time sounds like the less surprising behavior.

If we take the "load the .so files at import time" approach:

import tensorflow as tf
import tensorflow_addons as tfa
tf.saved_model.load("/tmp/model")

  • If the user is using PyCharm, it's going to tell them: tensorflow_addons imported but unused
  • If the user is using Visual Studio Code, it's going to tell them: tensorflow_addons imported but unused
  • If the user is using flake8, it's going to tell them: tensorflow_addons imported but unused
  • If the user has a colleague who has never worked with Addons, the colleague is going to say that tensorflow_addons is imported but unused, and push a commit removing the import.

This doesn't happen if we go for:

import tensorflow as tf
import tensorflow_addons as tfa
tfa.register_all_ops()
tf.saved_model.load("/tmp/model")

When reading code written with a library that has a well-designed API (like keras/pygithub/numpy...), it's possible to follow a piece of code without reading the documentation beforehand. In this case it's really clear what is happening, unlike when the shared objects are loaded at import time. A piece of code is read many more times than it's written, so we should optimize for readability.

If we follow the principle of least surprise, doing any significant work at import time seems strange and unintuitive to many users, as demonstrated by the warnings from flake8, PyCharm, and Visual Studio Code. It's also against the Zen of Python:

Explicit is better than implicit

There are other benefits to this approach: we catch errors early, and we avoid adding an environment variable for people who don't want to load the .so files or who want to debug what happens at import time. Because the function call means "load everything" and not "maybe load everything", we can throw a hard error, which is easier to debug. People who want to use TFA with a version of TensorFlow other than the recommended one can simply not call the function, and they won't have issues with the shared objects.

I would recommend making a new function as proposed by @guillaumekln in his first message:

def register_all_ops(keras_objects=True, custom_kernels=True):
    if keras_objects:
        # _addons_keras_objects: a module-level list of the package's
        # Keras classes and functions (sketch; not an existing name).
        for obj in _addons_keras_objects:
            tf.keras.utils.register_keras_serializable("Addons")(obj)
    if custom_kernels:
        # _custom_kernel_paths: a module-level list of paths to the
        # compiled .so files (sketch; not an existing name).
        for custom_kernel_path in _custom_kernel_paths:
            tf.load_op_library(custom_kernel_path)
