Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds JAX-->TFjs converter #6744

Merged
merged 5 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tfjs-converter/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ py_wheel(
license = "Apache 2.0",
python_tag = "py3",
requires = [
"flax>=0.5.3",
"importlib_resources>=5.9.0",
"jax>=0.3.15",
"protobuf<3.20,>=3.9.2",
"tensorflow>=2.1.0,<3",
"six>=1.12.0,<2",
Expand Down
3 changes: 3 additions & 0 deletions tfjs-converter/python/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
flax>=0.5.3
jax>=0.3.15
importlib_resources>=5.9.0
protobuf<3.20,>=3.9.2
tensorflow>=2.1.0,<3
six>=1.12.0,<2
Expand Down
27 changes: 27 additions & 0 deletions tfjs-converter/python/tensorflowjs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ py_library(
],
)

py_library(
name = "expect_flax_installed",
# This is a dummy rule used as a Flax dependency in open-source.
# We expect JAX to already be installed on the system, e.g. via
# `pip install flax`.
deps = [requirement("flax")],
)

py_library(
name = "expect_h5py_installed",
# This is a dummy rule used as a h5py dependency in open-source.
Expand All @@ -37,6 +45,25 @@ py_library(
deps = [requirement("h5py")],
)

py_library(
name = "expect_jax_installed",
# This is a dummy rule used as a JAX dependency in open-source.
# We expect JAX to already be installed on the system, e.g. via
# `pip install jax`.
deps = [
requirement("jax"),
requirement("importlib_resources"),
],
)

py_library(
name = "expect_jax2tf_installed",
# This is a dummy rule used as a jax2tf dependency in open-source.
# We expect jax2tf to already be installed on the system, e.g. via
# `pip install jax`.
deps = [requirement("jax")],
)

py_library(
name = "expect_numpy_installed",
# This is a dummy rule used as a numpy dependency in open-source.
Expand Down
26 changes: 26 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,32 @@ py_library(
],
)

py_test(
name = "jax_conversion_test",
srcs = ["jax_conversion_test.py"],
imports = ["../.."],
srcs_version = "PY3",
tags = ["ci"],
deps = [
":jax_conversion",
"//tfjs-converter/python/tensorflowjs:expect_flax_installed",
"//tfjs-converter/python/tensorflowjs:expect_jax_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
],
)

py_library(
name = "jax_conversion",
srcs = ["jax_conversion.py"],
srcs_version = "PY3",
deps = [
":tf_saved_model_conversion_v2",
"//tfjs-converter/python/tensorflowjs:expect_jax2tf_installed",
"//tfjs-converter/python/tensorflowjs:expect_jax_installed",
"//tfjs-converter/python/tensorflowjs:expect_tensorflow_installed",
],
)

py_test(
name = "wizard_test",
srcs = ["wizard_test.py"],
Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
from tensorflowjs.converters.keras_tfjs_loader import deserialize_keras_model
from tensorflowjs.converters.keras_tfjs_loader import load_keras_model
from tensorflowjs.converters.tf_saved_model_conversion_v2 import convert_tf_saved_model
from tensorflowjs.converters.jax_conversion import convert_jax
149 changes: 149 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/jax_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a JAX function to TensorFlow.js web format."""
import tempfile
from typing import Any, Callable, Optional, Sequence, Tuple, Union

from jax.experimental import jax2tf
from jax.experimental.jax2tf import shape_poly
import tensorflow as tf
from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion


_TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
Array = Any
DType = Any
PolyShape = shape_poly.PolyShape


class _ReusableSavedModelWrapper(tf.train.Checkpoint):
"""Wraps a function and its parameters for saving to a SavedModel.

Implements the interface described at
https://www.tensorflow.org/hub/reusable_saved_models.
"""

def __init__(self, tf_graph, param_vars):
"""Initializes a _ReusableSavedModelWrapper.

Args:
tf_graph: a tf.function taking one argument (the inputs), which can be
be tuples/lists/dictionaries of np.ndarray or tensors. The function
may have references to the tf.Variables in `param_vars`.
param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
to be saved as the variables of the SavedModel.
"""
super().__init__()
self.variables = tf.nest.flatten(param_vars)
self.trainable_variables = [v for v in self.variables if v.trainable]
# If you intend to prescribe regularization terms for users of the model,
# add them as @tf.functions with no inputs to this list. Else drop this.
self.regularization_losses = []
self.__call__ = tf_graph


def convert_jax(
apply_fn: Callable[..., Any],
params: Array,
*,
input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]],
model_dir: str,
polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None):
"""Converts a JAX function `jax_apply_fn` and model parameters to a TensorflowJS model.

Example usage for a Flax Module:

```
import numpy as np
from flax import linen as nn
from jax import random
import jax.numpy as jnp
from tensorflowjs.converters.jax_conversion import convert_jax

module = nn.Dense(features=4)
inputs = jnp.ones((3, 4))
params = module.init(random.PRNKey(0), inputs)['params']

convert_jax(
apply_fn=module.apply,
params=params,
input_signatures=[((3, 4), np.float32)]
model_dir=tfjs_model_dir)
```

Note that when using dynamic shapes, an additional argument
`polymorphic_shapes` should be provided specifying values for the dynamic
("polymorphic") dimensions). See here for more details:
https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion

This is an adaption of the original implementation in jax2tf here:
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py

Arguments:
apply_fn: A JAX function that has one or more arguments, of which the first
argument are the model parameters. This function typically is the forward
pass of the network (e.g., `Module.apply()` in Flax).
params: A Pytree containing the parameters of the module. These will all be
converted to TF.Variables.
input_signatures: the input signatures for the second and remaining
arguments to `apply_fn` (the input). A signature must be a
`tensorflow.TensorSpec` instance, or a (nested) tuple/list/dictionary
thereof with a structure matching the second argument of `apply_fn`.
model_dir: Directory where the TensorflowJS model will be written to.
polymorphic_shapes: If given then it will be used as the
`polymorphic_shapes` argument for the second parameter of `apply_fn`. In
this case, a single `input_signatures` is supported, and should have
`None` in the polymorphic (dynamic) dimensions.
"""
if polymorphic_shapes is not None:
# If polymorphic shapes are provided, add a polymorphic spec for the
# first argument to `apply_fn`, which are the parameters.
polymorphic_shapes = [None, *polymorphic_shapes]

tf_fn = jax2tf.convert(
apply_fn,
# Gradients must be included as 'PreventGradient' is not supported.
with_gradient=True,
polymorphic_shapes=polymorphic_shapes,
# Do not use TFXLA Ops because these aren't supported by TFjs, but use
# workarounds instead. More information:
# https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops
enable_xla=False)

# Create tf.Variables for the parameters. If you want more useful variable
# names, you can use `tree.map_structure_with_path` from the `dm-tree`
# package.
param_vars = tf.nest.map_structure(
lambda param: tf.Variable(param, trainable=True), params)
# Do not use TF's jit compilation on the function.
tf_graph = tf.function(
lambda *xs: tf_fn(param_vars, *xs), autograph=False, jit_compile=False)

# This signature is needed for TensorFlow Serving use.
signatures = {
_TF_SERVING_KEY: tf_graph.get_concrete_function(*input_signatures)
}

wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
saved_model_options = tf.saved_model.SaveOptions(
experimental_custom_gradients=True)

with tempfile.TemporaryDirectory() as saved_model_dir:
tf.saved_model.save(
wrapper,
saved_model_dir,
signatures=signatures,
options=saved_model_options)
saved_model_conversion.convert_tf_saved_model(saved_model_dir, model_dir)
144 changes: 144 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/jax_conversion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for converting JAX to TensorFlow.js web format."""
import functools

from flax import linen as nn
from jax import random
import jax.numpy as jnp
import tensorflow as tf
from tensorflowjs.converters import jax_conversion


class FlaxModule(nn.Module):
"""A simple Flax Module containing a few Dense layers and ReLUs."""

@nn.compact
def __call__(self, x):
x = nn.Dense(features=20)(x)
x = nn.relu(x)
for _ in range(5):
x = nn.Dense(features=10)(x)
x = nn.relu(x)

x = nn.Dense(features=2)(x)
x = nn.sigmoid(x)
return x


class FlaxModuleBatchNorm(nn.Module):
"""A simple CNN-like Flax model with BatchNorm."""

@nn.compact
def __call__(self, x, *, training=True):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.BatchNorm(use_running_average=not training)(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.BatchNorm(use_running_average=not training)(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x


class JaxConversionTest(tf.test.TestCase):

def test_convert_simple(self):
apply_fn = lambda params, x: jnp.sum(x) * params['w']
jax_conversion.convert_jax(
apply_fn,
{'w': 0.5},
input_signatures=[tf.TensorSpec((2, 3), tf.float32)],
model_dir=self.get_temp_dir())

def test_convert_poly(self):
apply_fn = lambda params, x: jnp.sum(x) * params['w']
jax_conversion.convert_jax(
apply_fn,
{'w': 0.5},
input_signatures=[tf.TensorSpec((None, 3), tf.float32)],
polymorphic_shapes=['(b, 3)'],
model_dir=self.get_temp_dir())

def test_convert_tf_poly_mismatch_raises(self):
apply_fn = lambda params, x: jnp.sum(x) * params['w']
with self.assertRaisesRegex(
ValueError, 'polymorphic shape.* must match .* for argument shape'):
jax_conversion.convert_jax(
apply_fn,
{'w': 0.5},
input_signatures=[tf.TensorSpec((None, 3), tf.float32)],
polymorphic_shapes=['(b, 4)'],
model_dir=self.get_temp_dir())

def test_convert_multiargs(self):
apply_fn = lambda params, x, y: jnp.sum(x) * jnp.sum(y) * params['w']
jax_conversion.convert_jax(
apply_fn,
{'w': 0.5},
input_signatures=[tf.TensorSpec((2, 3), tf.float32),
tf.TensorSpec((5, 6), tf.float32)],
model_dir=self.get_temp_dir())

def test_convert_multiarg_poly(self):
apply_fn = lambda params, x, y: jnp.sum(x) * jnp.sum(y) * params['w']
jax_conversion.convert_jax(
apply_fn,
{'w': 0.5},
input_signatures=[tf.TensorSpec((None, 3), tf.float32),
tf.TensorSpec((None, 6), tf.float32)],
polymorphic_shapes=['(b, 3)', '(b, 6)'],
model_dir=self.get_temp_dir())

def test_convert_flax(self):
m, x = FlaxModule(), jnp.zeros((3, 4))
variables = m.init(random.PRNGKey(0), x)
jax_conversion.convert_jax(
m.apply,
variables,
input_signatures=[tf.TensorSpec((3, 4), tf.float32)],
model_dir=self.get_temp_dir())

def test_convert_flax_poly(self):
m, x = FlaxModule(), jnp.zeros((3, 4))
variables = m.init(random.PRNGKey(0), x)
jax_conversion.convert_jax(
m.apply,
variables,
input_signatures=[tf.TensorSpec((None, 4), tf.float32)],
polymorphic_shapes=['(b, 4)'],
model_dir=self.get_temp_dir())

def test_convert_flax_bn(self):
m, x = FlaxModuleBatchNorm(), jnp.zeros((1, 32, 32, 3))
variables = m.init(random.PRNGKey(0), x)
# Note: if we don't pass training=False here, we will get an error during
# conversion since `batch_stats` is mutated while it is not passed as a
# mutable variable collections (we currently do not support this).
apply_fn = functools.partial(m.apply, training=False)
jax_conversion.convert_jax(
apply_fn,
variables,
input_signatures=[tf.TensorSpec((1, 32, 32, 3), tf.float32)],
model_dir=self.get_temp_dir())


if __name__ == '__main__':
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

Used primarily to convert saved weights, or saved_models from their
hdf5 format to a JSON + binary weights format that the TS codebase can use.
."""
"""

from __future__ import absolute_import
from __future__ import division
Expand Down