-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds JAX-->TFjs converter #6744
Conversation
76a515d
to
f06078f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Marc, can you add a section for JAX conversion in the README file https://github.com/tensorflow/tfjs/blob/master/tfjs-converter/README.md#python-to-javascript
Reviewable status: 0 of 1 approvals obtained (waiting on @marcvanzee)
tfjs-converter/python/requirements.txt
line 3 at r1 (raw file):
flax>=0.5.3 jax>=0.3.15 importlib_resources>=5.9.0
I am wondering if the order matters, can you put importlib_resources to the top?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to run the python tests, you need to following
cd tfjs-converter && yarn run-python-tests
Reviewable status: 0 of 1 approvals obtained (waiting on @marcvanzee)
tfjs-converter/python/tensorflowjs/BUILD
line 53 at r1 (raw file):
# We expect JAX to already be installed on the system, e.g. via # `pip install jax`. deps = [requirement("jax")],
add requirement("importlib_resources")
to the deps list, this should remove the test error.
tfjs-converter/python/tensorflowjs/converters/jax_conversion.py
line 22 at r1 (raw file):
from jax.experimental.jax2tf import shape_poly import tensorflow as tf from tensorflowjs.converters import convert_tf_saved_model
I am not sure about this, but importing from the init.py file does not seem to work with the bazel
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The jax tests failed with a shape check
https://source.cloud.google.com/results/invocations/b8b3019a-5630-43b6-92c0-a0f15b2235c6/targets/%2F%2Ftfjs-converter%2Fpython%2Ftensorflowjs%2Fconverters:jax_conversion_test/tests;group=__main__.JaxConversionTest;test=test_convert_flax_bn;row=1
Reviewed 8 of 8 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @marcvanzee)
Indeed, I made a recent change to jax2tf that causes average pooling to fail, see: jax-ml/jax#11804 I am currently working on a fix, and once that it in and JAX has rolled out a new version, we can proceed with this PR. Apologies for the delay! |
@pyu10055 I fixed the bug and everything is passing now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks I will merge this first.
Reviewed 2 of 2 files at r2.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @marcvanzee)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But we won't be able to release a new version of tfjs converter before the jax release is available.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @marcvanzee)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the github dep is not recommended for our pip package.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @marcvanzee)
I will ask the JAX team when they are planning to release a new version, and if they would consider doing this soon. |
@pyu10055 after our offline discussion, I've now changed it so that we install the latest JAX release and disabled the test with a TODO. I also realized that I forgot to add content to the README, which I did as well. I also made a few minor changes to the section headers, so please take a look and make sure you are okay with it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 2 of 2 files at r3.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @marcvanzee)
Very excited about this new feature, thank you Marc! |
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is