Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removes the `build()` method from `ObmModule` and instead processes the `serving_configs` in the constructor of `ObmExport`, matching what is done for the TensorFlow export. #1383

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions export/orbax/export/export_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
obx_export_config = config.config
maybe_reraise = reraise_utils.maybe_reraise


class ExportManager:
"""Exports a JAXModule with pre- and post-processors."""

Expand All @@ -50,26 +49,18 @@ def __init__(
version: the version of the export format to use. Defaults to
TF_SAVEDMODEL.
"""
if version != module.export_version():
if version != module.export_version:
raise ValueError(
'`version` and `module.export_version()`'
f' must be the same. The former is {version}. The latter is '
f'{module.export_version()}.'
f'{module.export_version}.'
)
self._version = version
self._jax_module = module
if self._version == constants.ExportModelType.ORBAX_MODEL:
self._serialization_functions = obm_export.ObmExport(
self._jax_module, serving_configs
)
obm_module_ = module.orbax_module()
if not isinstance(obm_module_, obm_module.ObmModule):
raise ValueError(
'module.orbax_module() must return an `ObmModule`. '
f'Got type: {type(obm_module_)}'
)
# TODO(bdwalker): Let `ObmExport.__init__()` do this `build()` step.
obm_module_.build(serving_configs)
else:
self._serialization_functions = tensorflow_export.TensorFlowExport(
self._jax_module, serving_configs
Expand Down
1 change: 1 addition & 0 deletions export/orbax/export/jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def model_params(self) -> PyTree:
"""Returns the model parameters."""
return self._export_module.model_params

@property
def export_version(self) -> constants.ExportModelType:
"""Returns the export version."""
return self._export_version
Expand Down
147 changes: 42 additions & 105 deletions export/orbax/export/modules/obm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,42 +14,23 @@

"""Wraps JAX functions and parameters into a tf.Module."""

from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Mapping
import copy
import logging
from typing import Any, Tuple, Union
from typing import Any, Optional, Union

import jax
from jax import export as jax_export
from orbax.export import constants
from orbax.export import serving_config as osc
from orbax.export import typing as orbax_export_typing
from orbax.export import utils
# from orbax.export import utils
from orbax.export.modules import orbax_module_base
from orbax.export.typing import PyTree
import tensorflow as tf


ApplyFn = orbax_export_typing.ApplyFn


def _to_jax_dtype(t):
if isinstance(t, tf.DType):
return t.as_numpy_dtype()
return t


def _to_jax_spec(tree: PyTree) -> PyTree:
return jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, _to_jax_dtype(x.dtype)), tree
)


def _to_sequence(a):
if isinstance(a, Sequence):
return a
return (a,)


class ObmModule(orbax_module_base.OrbaxModuleBase):
"""A data module for encapsulating the data for a Jax model to be serialized through the Orbax Model export flow."""

Expand All @@ -69,7 +50,6 @@ def __init__(
'native_serialization_platform', 'flatten_signature', 'weights_name' and
'checkpoint_path'.
"""
self._params = params

# It is possible for jax2obm_kwargs to be None if the key is present.
if not jax2obm_kwargs:
Expand All @@ -78,11 +58,30 @@ def __init__(
self._apply_fn_map = self._normalize_apply_fn_map(
self._normalize_apply_fn_map(apply_fn)
)

if len(self._apply_fn_map) != 1:
raise NotImplementedError(
'ObmModule: Currently the ObmExport only supports a single method'
f' for export. Received: {self._apply_fn_map}'
)

self._native_serialization_platform = (
jax2obm_kwargs[constants.NATIVE_SERIALIZATION_PLATFORM]
if constants.NATIVE_SERIALIZATION_PLATFORM in jax2obm_kwargs
else None
)
supported_platforms = [
platform.name for platform in constants.OrbaxNativeSerializationType
]
if (
self._native_serialization_platform is not None
and self._native_serialization_platform not in supported_platforms
):
raise ValueError(
'native_serialization_platforms must be a sequence containing a'
f' subset of: {supported_platforms}'
)

self._flatten_signature = (
jax2obm_kwargs[constants.FLATTEN_SIGNATURE]
if constants.FLATTEN_SIGNATURE in jax2obm_kwargs
Expand All @@ -94,45 +93,12 @@ def __init__(
if self._support_tf_resources is None:
self._support_tf_resources = False

self._params_args_spec = _to_jax_spec(params)
self._params_args_spec = utils.to_jax_spec(params)

self._checkpoint_path: str = None
# Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs)

self.built = False

def build(
self,
serving_configs: Sequence[osc.ServingConfig],
) -> None:
if self.built:
raise ValueError(
'The `build` method has already been called.'
' It can only be called once.'
)
self._verify_serving_configs(serving_configs)

# Currently there will only ever be a single item in the mapping.
if len(self._apply_fn_map) != 1:
raise NotImplementedError(
'ObmModule: Currently the ObmExport only supports a single method'
f' for export. Received: {self._apply_fn_map}'
)

model_function_name, jax_fn = next(iter(self._apply_fn_map.items()))

self._convert_jax_functions_to_obm_functions(
jax_fn=jax_fn,
jax_fn_name=model_function_name,
params_args_spec=self._params_args_spec,
serving_config=serving_configs[0],
native_serialization_platform=self._native_serialization_platform,
flatten_signature=self._flatten_signature,
support_tf_resources=self._support_tf_resources,
)

self.built = True

def _normalize_apply_fn_map(
self, apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]]
) -> Mapping[str, ApplyFn]:
Expand All @@ -147,56 +113,13 @@ def _normalize_apply_fn_map(
apply_fn_map = apply_fn
return apply_fn_map

def _verify_serving_configs(
self, serving_configs: Sequence[osc.ServingConfig]
):
if not serving_configs or len(serving_configs) != 1:
raise ValueError(
'ObmModule: A single serving_config must be provided for Orbax'
' Model export.'
)

if not serving_configs[0].input_signature:
# TODO(wangpeng): Infer input_signature from tf_preprocessor.
raise ValueError(
'ObmModule: The serving_config must have an input_signature set.'
)

if not serving_configs[0].signature_key:
raise ValueError(
'ObmModule: The serving_config must have a signature_key set.'
)

def _convert_jax_functions_to_obm_functions(
self,
*,
jax_fn,
jax_fn_name: str,
params_args_spec: PyTree,
serving_config: osc.ServingConfig,
native_serialization_platform,
flatten_signature: bool,
support_tf_resources: bool,
):
"""Converts the JAX functions to OrbaxModel functions."""
if serving_config.input_signature is None:
raise ValueError('serving_config.input_signature is required.')
if (
not support_tf_resources
and serving_config.extra_trackable_resources is not None
):
raise ValueError(
'serving_config.extra_trackable_resources can only be set when'
' support_tf_resources is True.'
)

def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
if constants.CHECKPOINT_PATH not in jax2obm_kwargs:
return

# TODO: b/374195447 - Add a version check for the Orbax checkpointer.
checkpoint_path = jax2obm_kwargs[constants.CHECKPOINT_PATH]
weights_name = (
self._checkpoint_path = jax2obm_kwargs[constants.CHECKPOINT_PATH]
self._weights_name = (
jax2obm_kwargs[constants.WEIGHTS_NAME]
if constants.WEIGHTS_NAME in jax2obm_kwargs
else constants.DEFAULT_WEIGHTS_NAME
Expand All @@ -212,15 +135,29 @@ def apply_fn_map(self) -> Mapping[str, ApplyFn]:
"""Returns the apply_fn_map from function name to jit'd apply function."""
return self._apply_fn_map

@property
def native_serialization_platform(self) -> Optional[str]:
"""Returns the native serialization platform."""
return self._native_serialization_platform

@property
def flatten_signature(self) -> bool:
"""Returns the flatten signature."""
return self._flatten_signature

@property
def export_version(self) -> constants.ExportModelType:
"""Returns the export version."""
return constants.ExportModelType.ORBAX_MODEL

def support_tf_resources(self) -> bool:
"""Returns True if the model supports TF resources."""
return self._support_tf_resources

@property
def model_params(self) -> PyTree:
"""Returns the model parameter specs."""
return self._params
return self._params_args_spec

def obm_module_to_jax_exported_map(
self,
Expand Down
3 changes: 3 additions & 0 deletions export/orbax/export/modules/obm_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from orbax.export import constants
from orbax.export import obm_export
from orbax.export import serving_config as osc
from orbax.export.modules import obm_module
from orbax.export import jax_module
import tensorflow as tf


Expand Down
37 changes: 29 additions & 8 deletions export/orbax/export/obm_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,44 @@

"""Export class that implements the save and load abstract class defined in Export Base for use with the Orbax Model export format."""

from typing import Any, Callable, Mapping, cast
from typing import Any, Callable, Dict, Mapping, Sequence, Tuple, Union, cast

from absl import logging
import jax
from orbax.export import constants
from orbax.export import export_base
from orbax.export import jax_module
from orbax.export import serving_config as osc
from orbax.export import utils
from orbax.export.modules import obm_module
from orbax.export.typing import PyTree
import tensorflow as tf


def _to_sequence(a):
if isinstance(a, Sequence):
return a
return (a,)


class ObmExport(export_base.ExportBase):
"""Defines the save and load methods for exporting a model using Orbax Model export."""

def __init__(
self,
module: jax_module.JaxModule,
serving_configs: Sequence[osc.ServingConfig],
):
"""Initializes the ObmExport class."""
if module.export_version != constants.ExportModelType.ORBAX_MODEL:
raise ValueError(
"JaxModule export version is not of type ORBAX_MODEL. Please use the"
" correct export_version. Expected ORBAX_MODEL, got"
f" {module.export_version}"
)

obm_model_module = module.export_module()

def save(
self,
model_path: str,
Expand All @@ -38,18 +65,12 @@ def save(
arguments are `save_options` and `serving_signatures`.
"""

if self._module.export_version() != constants.ExportModelType.ORBAX_MODEL:
raise ValueError(
"JaxModule is not of type ORBAX_MODEL. Please use the correct"
" export_version. Expected ORBAX_MODEL, got"
f" {self._module.export_version()}"
)

def load(self, model_path: str, **kwargs: Any):
"""Loads the model previously saved in the Orbax Model export format."""
logging.info("Loading model using Orbax Export Model.")
raise NotImplementedError("ObmExport.load not implemented yet.")


@property
def serving_signatures(self) -> Mapping[str, Callable[..., Any]]:
"""Returns a map of signature keys to serving functions."""
Expand Down
13 changes: 13 additions & 0 deletions export/orbax/export/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,3 +491,16 @@ def make_e2e_inference_fn(
return with_default_args(
infer_step_func_map[signature_key], serving_config.get_input_signature()
)


def _to_jax_dtype(t):
  """Converts a `tf.DType` to its NumPy type; passes anything else through.

  Args:
    t: A `tf.DType`, or any other dtype-like value.

  Returns:
    `t.as_numpy_dtype` when `t` is a `tf.DType`, otherwise `t` itself.
  """
  if isinstance(t, tf.DType):
    # `as_numpy_dtype` is a property that already yields the NumPy type.
    # Calling its result would instantiate a scalar of that type (for example
    # `np.float32()` evaluates to `0.0`) instead of returning the dtype.
    return t.as_numpy_dtype
  return t


def to_jax_spec(tree: PyTree) -> PyTree:
  """Maps every leaf of `tree` to a `jax.ShapeDtypeStruct`.

  Each leaf must expose `shape` and `dtype` attributes; the dtype is passed
  through `_to_jax_dtype` before the struct is built.

  Args:
    tree: A pytree whose leaves have `shape` and `dtype`.

  Returns:
    A pytree with the same structure whose leaves are `jax.ShapeDtypeStruct`s.
  """

  def _leaf_spec(leaf):
    return jax.ShapeDtypeStruct(leaf.shape, _to_jax_dtype(leaf.dtype))

  return jax.tree_util.tree_map(_leaf_spec, tree)