Adds JAX-->TFjs converter #6744

Merged · 5 commits · Aug 17, 2022
62 changes: 57 additions & 5 deletions tfjs-converter/README.md
@@ -17,7 +17,7 @@ using an already hosted model (e.g. MobileNet), skip this step.
2. [JavaScript API](./src/executor/tf_model.ts), for loading and running
inference.

## Step 1: Converting a [TensorFlow SavedModel](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md), [TensorFlow Hub module](https://www.tensorflow.org/hub/), [Keras HDF5](https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model) or [tf.keras SavedModel](https://www.tensorflow.org/api_docs/python/tf/contrib/saved_model/save_keras_model) to a web-friendly format
## Step 1: Converting a [TensorFlow SavedModel](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md), [TensorFlow Hub module](https://www.tensorflow.org/hub/), [Keras HDF5](https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model), [tf.keras SavedModel](https://www.tensorflow.org/api_docs/python/tf/contrib/saved_model/save_keras_model), or [Flax/JAX model](http://github.com/google/flax) to a web-friendly format

__0. Please make sure that you run in a Docker container or a virtual environment.__

@@ -54,10 +54,13 @@ Install the library with interactive CLI:

__2. Run the converter script provided by the pip package:__

There are two way to trigger the model conversion:
There are three ways to trigger the model conversion, explained below:

- The conversion wizard: `tensorflowjs_wizard`
- Regular conversion script: `tensorflowjs_converter`
- The conversion wizard: `tensorflowjs_wizard` ([go to section](#conversion-wizard-tensorflowjswizard))
- Regular conversion script: `tensorflowjs_converter` ([go to section](#regular-conversion-script-tensorflowjsconverter))
- Calling a converter function in Python (Flax/JAX) ([go to section](#calling-a-converter-function-in-python))

## Conversion wizard: `tensorflowjs_wizard`

To start the conversion wizard:
```bash
tensorflowjs_wizard
```
@@ -81,7 +84,7 @@ tensorflowjs_wizard --dryrun
To convert a batch of models or integrate the conversion process into your own
script, you should use the tensorflowjs_converter script.
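
For example, a small driver script can loop over several models and invoke the converter for each one. Below is a minimal sketch, assuming `tf_saved_model` inputs and illustrative paths; the flag and positional arguments used here are the ones described in the Conversion Flags table further down.

```py
import subprocess

# Hypothetical list of (input, output) directories to convert in one batch.
MODELS = [
    ("/models/a/saved_model", "/models/a/web_model"),
    ("/models/b/saved_model", "/models/b/web_model"),
]

for saved_model_dir, web_model_dir in MODELS:
    # Equivalent to running tensorflowjs_converter by hand for each model.
    subprocess.run(
        ["tensorflowjs_converter",
         "--input_format=tf_saved_model",
         saved_model_dir,
         web_model_dir],
        check=True)
```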

## Conversion flags
## Regular conversion script: `tensorflowjs_converter`

The converter expects a __TensorFlow SavedModel__, __TensorFlow Hub module__,
__TensorFlow.js JSON__ format, __Keras HDF5 model__, or __tf.keras SavedModel__
@@ -141,6 +144,8 @@ Note that the input path used above is a subfolder that has a Unix epoch
time (1542211770) and is generated automatically by tensorflow when it
saved a tf.keras model in the SavedModel format.

### Conversion Flags

|Positional Arguments | Description |
|---|---|
|`input_path` | Full path of the saved model directory or TensorFlow Hub module handle or path.|
@@ -271,6 +276,53 @@ following location:
https://storage.cloud.google.com/tfjs-models/savedmodel/mobilenet_v1_1.0_224/group1-shard5of5
```

## Calling a Converter Function in Python (Flax/JAX)

You can also convert your model to web format in Python by calling one of the
conversion functions. This is currently the only way to convert a Flax or JAX
model, since no standard serialization format exists to store a Module (only the
checkpoints).

Here we provide an example of how to convert a Flax model using the
conversion function `jax_conversion.convert_jax()`.

```py
import numpy as np
from flax import linen as nn
from jax import random
import jax.numpy as jnp
from tensorflowjs.converters import jax_conversion

module = nn.Dense(features=4)
inputs = jnp.ones((3, 4))
params = module.init(random.PRNGKey(0), inputs)['params']

jax_conversion.convert_jax(
    apply_fn=module.apply,
    params=params,
    input_signatures=[((3, 4), np.float32)],
    model_dir=tfjs_model_dir)
```
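
After the call returns, the `tfjs_model_dir` used above contains the web-friendly artifacts: a `model.json` graph file plus one or more binary weight shards. A quick sanity check (the file names below are only illustrative):

```py
import os

# Lists the converted artifacts, e.g. ['group1-shard1of1.bin', 'model.json'].
print(sorted(os.listdir(tfjs_model_dir)))
```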

Note that when using dynamic shapes, an additional argument `polymorphic_shapes`
should be provided specifying values for the dynamic ("polymorphic")
dimensions. So, to convert the same model as before, but now with a dynamic
first dimension, call `convert_jax` as follows:

```py
jax_conversion.convert_jax(
    apply_fn=module.apply,
    params=params,
    input_signatures=[((None, 4), np.float32)],
    polymorphic_shapes=["(b, 4)"],
    model_dir=tfjs_model_dir)
```

See
[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion)
for more details on the exact syntax for this argument.


## Step 2: Loading and running in the browser

If the original model was a `SavedModel`, use
3 changes: 3 additions & 0 deletions tfjs-converter/python/BUILD
@@ -48,6 +48,9 @@ py_wheel(
license = "Apache 2.0",
python_tag = "py3",
requires = [
"flax>=0.5.3",
"importlib_resources>=5.9.0",
"jax>=0.3.16",
"protobuf<3.20,>=3.9.2",
"tensorflow>=2.1.0,<3",
"six>=1.12.0,<2",
3 changes: 3 additions & 0 deletions tfjs-converter/python/requirements.txt
@@ -1,3 +1,6 @@
flax>=0.5.3
jax>=0.3.16
importlib_resources>=5.9.0
protobuf<3.20,>=3.9.2
tensorflow>=2.1.0,<3
six>=1.12.0,<2
27 changes: 27 additions & 0 deletions tfjs-converter/python/tensorflowjs/BUILD
@@ -29,6 +29,14 @@ py_library(
    ],
)

py_library(
    name = "expect_flax_installed",
    # This is a dummy rule used as a Flax dependency in open-source.
    # We expect Flax to already be installed on the system, e.g. via
    # `pip install flax`.
    deps = [requirement("flax")],
)

py_library(
    name = "expect_h5py_installed",
    # This is a dummy rule used as a h5py dependency in open-source.
@@ -37,6 +45,25 @@ py_library(
    deps = [requirement("h5py")],
)

py_library(
    name = "expect_jax_installed",
    # This is a dummy rule used as a JAX dependency in open-source.
    # We expect JAX to already be installed on the system, e.g. via
    # `pip install jax`.
    deps = [
        requirement("jax"),
        requirement("importlib_resources"),
    ],
)

py_library(
    name = "expect_jax2tf_installed",
    # This is a dummy rule used as a jax2tf dependency in open-source.
    # We expect jax2tf to already be installed on the system, e.g. via
    # `pip install jax`.
    deps = [requirement("jax")],
)

py_library(
    name = "expect_numpy_installed",
    # This is a dummy rule used as a numpy dependency in open-source.
26 changes: 26 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
@@ -195,6 +195,32 @@ py_library(
    ],
)

py_test(
    name = "jax_conversion_test",
    srcs = ["jax_conversion_test.py"],
    imports = ["../.."],
    srcs_version = "PY3",
    tags = ["ci"],
    deps = [
        ":jax_conversion",
        "//tfjs-converter/python/tensorflowjs:expect_flax_installed",
        "//tfjs-converter/python/tensorflowjs:expect_jax_installed",
        "//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
    ],
)

py_library(
    name = "jax_conversion",
    srcs = ["jax_conversion.py"],
    srcs_version = "PY3",
    deps = [
        ":tf_saved_model_conversion_v2",
        "//tfjs-converter/python/tensorflowjs:expect_jax2tf_installed",
        "//tfjs-converter/python/tensorflowjs:expect_jax_installed",
        "//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
    ],
)

py_test(
    name = "wizard_test",
    srcs = ["wizard_test.py"],
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/__init__.py
@@ -23,3 +23,4 @@
from tensorflowjs.converters.keras_tfjs_loader import deserialize_keras_model
from tensorflowjs.converters.keras_tfjs_loader import load_keras_model
from tensorflowjs.converters.tf_saved_model_conversion_v2 import convert_tf_saved_model
from tensorflowjs.converters.jax_conversion import convert_jax
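
With this re-export in place, `convert_jax` can be imported straight from the `tensorflowjs.converters` package as well as from its defining module; a minimal check, assuming the pip package is installed:

```py
# Both import paths resolve to the same function after this change.
from tensorflowjs.converters import convert_jax
from tensorflowjs.converters.jax_conversion import convert_jax as convert_jax_direct

assert convert_jax is convert_jax_direct
```
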
149 changes: 149 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/jax_conversion.py
@@ -0,0 +1,149 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a JAX function to TensorFlow.js web format."""
import tempfile
from typing import Any, Callable, Optional, Sequence, Tuple, Union

from jax.experimental import jax2tf
from jax.experimental.jax2tf import shape_poly
import tensorflow as tf
from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion


_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
Array = Any
DType = Any
PolyShape = shape_poly.PolyShape


class _ReusableSavedModelWrapper(tf.train.Checkpoint):
  """Wraps a function and its parameters for saving to a SavedModel.

  Implements the interface described at
  https://www.tensorflow.org/hub/reusable_saved_models.
  """

  def __init__(self, tf_graph, param_vars):
    """Initializes a _ReusableSavedModelWrapper.

    Args:
      tf_graph: a tf.function taking one argument (the inputs), which can be
        tuples/lists/dictionaries of np.ndarray or tensors. The function
        may have references to the tf.Variables in `param_vars`.
      param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
        to be saved as the variables of the SavedModel.
    """
    super().__init__()
    self.variables = tf.nest.flatten(param_vars)
    self.trainable_variables = [v for v in self.variables if v.trainable]
    # If you intend to prescribe regularization terms for users of the model,
    # add them as @tf.functions with no inputs to this list. Else drop this.
    self.regularization_losses = []
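    # Expose the traced computation as `__call__` so that tf.saved_model.save
    # exports it as the SavedModel's callable, per the reusable SavedModel
    # interface linked above.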
    self.__call__ = tf_graph


def convert_jax(
    apply_fn: Callable[..., Any],
    params: Array,
    *,
    input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]],
    model_dir: str,
    polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None):
  """Converts a JAX function `apply_fn` and model parameters to a TensorFlow.js model.

  Example usage for a Flax Module:

  ```
  import numpy as np
  from flax import linen as nn
  from jax import random
  import jax.numpy as jnp
  from tensorflowjs.converters.jax_conversion import convert_jax

  module = nn.Dense(features=4)
  inputs = jnp.ones((3, 4))
  params = module.init(random.PRNGKey(0), inputs)['params']

  convert_jax(
      apply_fn=module.apply,
      params=params,
      input_signatures=[((3, 4), np.float32)],
      model_dir=tfjs_model_dir)
  ```

  Note that when using dynamic shapes, an additional argument
  `polymorphic_shapes` should be provided specifying values for the dynamic
  ("polymorphic") dimensions. See here for more details:
  https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion

  This is an adaptation of the original implementation in jax2tf here:
  https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py

  Arguments:
    apply_fn: A JAX function that has one or more arguments, of which the
      first argument is the model parameters. This function typically is the
      forward pass of the network (e.g., `Module.apply()` in Flax).
    params: A Pytree containing the parameters of the module. These will all
      be converted to tf.Variables.
    input_signatures: the input signatures for the second and remaining
      arguments to `apply_fn` (the inputs). Each signature is a
      `(shape, dtype)` pair, e.g. `((3, 4), np.float32)`; dynamic dimensions
      are given as `None`.
    model_dir: Directory where the TensorFlow.js model will be written to.
    polymorphic_shapes: If given then it will be used as the
      `polymorphic_shapes` argument for the second parameter of `apply_fn`. In
      this case, a single `input_signatures` is supported, and should have
      `None` in the polymorphic (dynamic) dimensions.
  """
  if polymorphic_shapes is not None:
    # If polymorphic shapes are provided, add a polymorphic spec for the
    # first argument to `apply_fn`, which is the model parameters.
    polymorphic_shapes = [None, *polymorphic_shapes]

  tf_fn = jax2tf.convert(
      apply_fn,
      # Gradients must be included as 'PreventGradient' is not supported.
      with_gradient=True,
      polymorphic_shapes=polymorphic_shapes,
      # Do not use TFXLA Ops because these aren't supported by TFjs, but use
      # workarounds instead. More information:
      # https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops
      enable_xla=False)

  # Create tf.Variables for the parameters. If you want more useful variable
  # names, you can use `tree.map_structure_with_path` from the `dm-tree`
  # package.
  param_vars = tf.nest.map_structure(
      lambda param: tf.Variable(param, trainable=True), params)
  # Do not use TF's jit compilation on the function.
  tf_graph = tf.function(
      lambda *xs: tf_fn(param_vars, *xs), autograph=False, jit_compile=False)

  # This signature is needed for TensorFlow Serving use. Build TensorSpecs
  # from the (shape, dtype) pairs in `input_signatures`.
  signatures = {
      _TF_SERVING_KEY: tf_graph.get_concrete_function(
          *(tf.TensorSpec(shape, dtype) for shape, dtype in input_signatures))
  }

  wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
  saved_model_options = tf.saved_model.SaveOptions(
      experimental_custom_gradients=True)

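  # Round-trip through a temporary SavedModel on disk, then convert that
  # SavedModel into the TensorFlow.js web format at `model_dir`.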
  with tempfile.TemporaryDirectory() as saved_model_dir:
    tf.saved_model.save(
        wrapper,
        saved_model_dir,
        signatures=signatures,
        options=saved_model_options)
    saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir)