Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds JAX-->TFjs converter #6744

Merged
merged 5 commits into from
Aug 17, 2022
Merged

Adds JAX-->TFjs converter #6744

merged 5 commits into from
Aug 17, 2022

Conversation

marcvanzee
Copy link
Contributor

@marcvanzee marcvanzee commented Aug 11, 2022

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

@marcvanzee marcvanzee force-pushed the jax2tfjs branch 12 times, most recently from 76a515d to f06078f Compare August 11, 2022 12:45
@rthadur rthadur requested a review from pyu10055 August 11, 2022 14:54
Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Marc, can you add a section for JAX conversion in the README file https://github.com/tensorflow/tfjs/blob/master/tfjs-converter/README.md#python-to-javascript

Reviewable status: 0 of 1 approvals obtained (waiting on @marcvanzee)


tfjs-converter/python/requirements.txt line 3 at r1 (raw file):

flax>=0.5.3
jax>=0.3.15
importlib_resources>=5.9.0

I am wondering if the order matters, can you put importlib_resources to the top?

Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to run the python tests, you need to following

cd tfjs-converter && yarn run-python-tests

Reviewable status: 0 of 1 approvals obtained (waiting on @marcvanzee)


tfjs-converter/python/tensorflowjs/BUILD line 53 at r1 (raw file):

    # We expect JAX to already be installed on the system, e.g. via
    # `pip install jax`.
    deps = [requirement("jax")],

add requirement("importlib_resources") to the deps list, this should remove the test error.


tfjs-converter/python/tensorflowjs/converters/jax_conversion.py line 22 at r1 (raw file):

from jax.experimental.jax2tf import shape_poly
import tensorflow as tf
from tensorflowjs.converters import convert_tf_saved_model

I am not sure about this, but importing from the init.py file does not seem to work with the bazel

Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@marcvanzee
Copy link
Contributor Author

The jax tests failed with a shape check https://source.cloud.google.com/results/invocations/b8b3019a-5630-43b6-92c0-a0f15b2235c6/targets/%2F%2Ftfjs-converter%2Fpython%2Ftensorflowjs%2Fconverters:jax_conversion_test/tests;group=__main__.JaxConversionTest;test=test_convert_flax_bn;row=1

Reviewed 8 of 8 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @marcvanzee)

Indeed, I made a recent change to jax2tf that causes average pooling to fail, see: jax-ml/jax#11804

I am currently working on a fix, and once that it in and JAX has rolled out a new version, we can proceed with this PR.

Apologies for the delay!

@marcvanzee
Copy link
Contributor Author

marcvanzee commented Aug 16, 2022

@pyu10055 I fixed the bug and everything is passing now.
Note I am now installing JAX from github since the fix will only be in in the next release, and they just released a new version. Is this okay with you? Given that we are writing a blog post that we want to publish soon (1-2 weeks), it would be great if we could have this code ready. I'm happy to update the import once JAX releases a new version.

@marcvanzee marcvanzee requested a review from pyu10055 August 16, 2022 14:55
Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks I will merge this first.

Reviewed 2 of 2 files at r2.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @marcvanzee)

Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we won't be able to release a new version of tfjs converter before the jax release is available.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @marcvanzee)

Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the github dep is not recommended for our pip package.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @marcvanzee)

@pyu10055 pyu10055 requested a review from Linchenn August 17, 2022 00:32
@marcvanzee
Copy link
Contributor Author

I will ask the JAX team when they are planning to release a new version, and if they would consider doing this soon.

@marcvanzee
Copy link
Contributor Author

@pyu10055 after our offline discussion, I've now changed it so that we install the latest JAX release and disabled the test with a TODO. I also realized that I forgot to add content to the README, which I did as well. I also made a few minor changes to the section headers, so please take a look and make sure you are okay with it.

Copy link
Collaborator

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 2 of 2 files at r3.
Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @marcvanzee)

@pyu10055
Copy link
Collaborator

Very excited about this new feature, thank you Marc!

@pyu10055 pyu10055 merged commit 167fd6f into tensorflow:master Aug 17, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants