From 48043a14ef6860627ad1de3df4d237cdfcdaf04b Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Fri, 15 Nov 2024 11:48:42 -0800 Subject: [PATCH] Removes the build file out of the obm_module and instead processes the serving_configs in the constructor of obm_export to match what is done with the Tensorflow Export. PiperOrigin-RevId: 696955640 --- export/orbax/export/export_manager.py | 13 +- export/orbax/export/jax_module.py | 1 + export/orbax/export/modules/obm_module.py | 136 ++++++------------ .../orbax/export/modules/obm_module_test.py | 2 - export/orbax/export/obm_export.py | 37 +++-- export/orbax/export/utils.py | 2 + 6 files changed, 76 insertions(+), 115 deletions(-) diff --git a/export/orbax/export/export_manager.py b/export/orbax/export/export_manager.py index 3b9da718..72f0ae3c 100644 --- a/export/orbax/export/export_manager.py +++ b/export/orbax/export/export_manager.py @@ -31,7 +31,6 @@ obx_export_config = config.config maybe_reraise = reraise_utils.maybe_reraise - class ExportManager: """Exports a JAXModule with pre- and post-processors.""" @@ -50,11 +49,11 @@ def __init__( version: the version of the export format to use. Defaults to TF_SAVEDMODEL. """ - if version != module.export_version(): + if version != module.export_version: raise ValueError( '`version` and `module.export_version()`' f' must be the same. The former is {version}. The latter is ' - f'{module.export_version()}.' + f'{module.export_version}.' ) self._version = version self._jax_module = module @@ -62,14 +61,6 @@ def __init__( self._serialization_functions = obm_export.ObmExport( self._jax_module, serving_configs ) - obm_module_ = module.orbax_module() - if not isinstance(obm_module_, obm_module.ObmModule): - raise ValueError( - 'module.orbax_module() must return an `ObmModule`. ' - f'Got type: {type(obm_module_)}' - ) - # TODO(bdwalker): Let `ObmExport.__init__() do this `build()` step. - obm_module_.build(serving_configs) else: self._serialization_functions = tensorflow_export.TensorFlowExport( self._jax_module, serving_configs diff --git a/export/orbax/export/jax_module.py b/export/orbax/export/jax_module.py index ba4865fc..6956b60d 100644 --- a/export/orbax/export/jax_module.py +++ b/export/orbax/export/jax_module.py @@ -134,6 +134,7 @@ def model_params(self) -> PyTree: """Returns the model parameters.""" return self._export_module.model_params + @property def export_version(self) -> constants.ExportModelType: """Returns the export version.""" return self._export_version diff --git a/export/orbax/export/modules/obm_module.py b/export/orbax/export/modules/obm_module.py index 3f5a2552..fb4762f0 100644 --- a/export/orbax/export/modules/obm_module.py +++ b/export/orbax/export/modules/obm_module.py @@ -14,24 +14,22 @@ """Wraps JAX functions and parameters into a tf.Module.""" -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Mapping +import copy import logging -from typing import Any, Tuple, Union +from typing import Any, Optional, Union import jax -from jax import export as jax_export from orbax.export import constants -from orbax.export import serving_config as osc from orbax.export import typing as orbax_export_typing -# from orbax.export import utils from orbax.export.modules import orbax_module_base from orbax.export.typing import PyTree import tensorflow as tf - ApplyFn = orbax_export_typing.ApplyFn +# TODO(bdwalker): Remove this function and just check for Jax data types. def _to_jax_dtype(t): if isinstance(t, tf.DType): return t.as_numpy_dtype() @@ -44,12 +42,6 @@ def _to_jax_spec(tree: PyTree) -> PyTree: ) -def _to_sequence(a): - if isinstance(a, Sequence): - return a - return (a,) - - class ObmModule(orbax_module_base.OrbaxModuleBase): """A data module for encapsulating the data for a Jax model to be serialized through the Orbax Model export flow.""" @@ -69,7 +61,6 @@ def __init__( 'native_serialization_platform', 'flatten_signature', 'weights_name'and 'checkpoint_path'. """ - self._params = params # It is possible for jax2obm_kwargs to be None if the key is present. if not jax2obm_kwargs: @@ -78,11 +69,30 @@ def __init__( self._apply_fn_map = self._normalize_apply_fn_map( self._normalize_apply_fn_map(apply_fn) ) + + if len(self._apply_fn_map) != 1: + raise NotImplementedError( + 'ObmModule: Currently the ObmExport only supports a single method' + f' for export. Received: {self._apply_fn_map}' + ) + self._native_serialization_platform = ( jax2obm_kwargs[constants.NATIVE_SERIALIZATION_PLATFORM] if constants.NATIVE_SERIALIZATION_PLATFORM in jax2obm_kwargs else None ) + supported_platforms = [ + platform.name for platform in constants.OrbaxNativeSerializationType + ] + if ( + self._native_serialization_platform is not None + and self._native_serialization_platform not in supported_platforms + ): + raise ValueError( + 'native_serialization_platforms must be a sequence containing a' + f' subset of: {supported_platforms}' + ) + self._flatten_signature = ( jax2obm_kwargs[constants.FLATTEN_SIGNATURE] if constants.FLATTEN_SIGNATURE in jax2obm_kwargs @@ -96,43 +106,10 @@ def __init__( self._params_args_spec = _to_jax_spec(params) + self._checkpoint_path: str = None # Set the Orbax checkpoint path if provided in the jax2obm_kwargs. self._maybe_set_orbax_checkpoint_path(jax2obm_kwargs) - self.built = False - - def build( - self, - serving_configs: Sequence[osc.ServingConfig], - ) -> None: - if self.built: - raise ValueError( - 'The `build` method has already been called.' - ' It can only be called once.' - ) - self._verify_serving_configs(serving_configs) - - # Currently there will only ever be a single item in the mapping. - if len(self._apply_fn_map) != 1: - raise NotImplementedError( - 'ObmModule: Currently the ObmExport only supports a single method' - f' for export. Received: {self._apply_fn_map}' - ) - - model_function_name, jax_fn = next(iter(self._apply_fn_map.items())) - - self._convert_jax_functions_to_obm_functions( - jax_fn=jax_fn, - jax_fn_name=model_function_name, - params_args_spec=self._params_args_spec, - serving_config=serving_configs[0], - native_serialization_platform=self._native_serialization_platform, - flatten_signature=self._flatten_signature, - support_tf_resources=self._support_tf_resources, - ) - - self.built = True - def _normalize_apply_fn_map( self, apply_fn: Union[ApplyFn, Mapping[str, ApplyFn]] ) -> Mapping[str, ApplyFn]: @@ -147,56 +124,13 @@ def _normalize_apply_fn_map( apply_fn_map = apply_fn return apply_fn_map - def _verify_serving_configs( - self, serving_configs: Sequence[osc.ServingConfig] - ): - if not serving_configs or len(serving_configs) != 1: - raise ValueError( - 'ObmModule: A single serving_config must be provided for Orbax' - ' Model export.' - ) - - if not serving_configs[0].input_signature: - # TODO(wangpeng): Infer input_signature from tf_preprocessor. - raise ValueError( - 'ObmModule: The serving_config must have an input_signature set.' - ) - - if not serving_configs[0].signature_key: - raise ValueError( - 'ObmModule: The serving_config must have a signature_key set.' - ) - - def _convert_jax_functions_to_obm_functions( - self, - *, - jax_fn, - jax_fn_name: str, - params_args_spec: PyTree, - serving_config: osc.ServingConfig, - native_serialization_platform, - flatten_signature: bool, - support_tf_resources: bool, - ): - """Converts the JAX functions to OrbaxModel functions.""" - if serving_config.input_signature is None: - raise ValueError('serving_config.input_signature is required.') - if ( - not support_tf_resources - and serving_config.extra_trackable_resources is not None - ): - raise ValueError( - 'serving_config.extra_trackable_resources can only be set when' - ' support_tf_resources is True.' - ) - def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs): if constants.CHECKPOINT_PATH not in jax2obm_kwargs: return # TODO: b/374195447 - Add a version check for the Orbax checkpointer. - checkpoint_path = jax2obm_kwargs[constants.CHECKPOINT_PATH] - weights_name = ( + self._checkpoint_path = jax2obm_kwargs[constants.CHECKPOINT_PATH] + self._weights_name = ( jax2obm_kwargs[constants.WEIGHTS_NAME] if constants.WEIGHTS_NAME in jax2obm_kwargs else constants.DEFAULT_WEIGHTS_NAME @@ -212,20 +146,34 @@ def apply_fn_map(self) -> Mapping[str, ApplyFn]: """Returns the apply_fn_map from function name to jit'd apply function.""" return self._apply_fn_map + @property + def native_serialization_platform(self) -> Optional[str]: + """Returns the native serialization platform.""" + return self._native_serialization_platform + + @property + def flatten_signature(self) -> bool: + """Returns the flatten signature.""" + return self._flatten_signature + @property def export_version(self) -> constants.ExportModelType: """Returns the export version.""" return constants.ExportModelType.ORBAX_MODEL + def support_tf_resources(self) -> bool: + """Returns True if the model supports TF resources.""" + return self._support_tf_resources + @property def model_params(self) -> PyTree: """Returns the model parameter specs.""" - return self._params + return self._params_args_spec def obm_module_to_jax_exported_map( self, model_inputs: PyTree, - ) -> Mapping[str, jax_export.Exported]: + ) -> Mapping[str, jax.export.Exported]: """Converts the OrbaxModel to jax_export.Exported.""" raise NotImplementedError( 'ObmModule.methods not implemented yet. See b/363061755.' diff --git a/export/orbax/export/modules/obm_module_test.py b/export/orbax/export/modules/obm_module_test.py index 44c453a7..2f866794 100644 --- a/export/orbax/export/modules/obm_module_test.py +++ b/export/orbax/export/modules/obm_module_test.py @@ -17,9 +17,7 @@ import jax import jax.numpy as jnp from orbax.export import constants -from orbax.export import serving_config as osc from orbax.export.modules import obm_module -import tensorflow as tf class ObmModuleTest(parameterized.TestCase): diff --git a/export/orbax/export/obm_export.py b/export/orbax/export/obm_export.py index 1c8f9e77..4cdbe112 100644 --- a/export/orbax/export/obm_export.py +++ b/export/orbax/export/obm_export.py @@ -14,17 +14,44 @@ """Export class that implements the save and load abstract class defined in Export Base for use with the Orbax Model export format.""" -from typing import Any, Callable, Mapping, cast +from typing import Any, Callable, Dict, Mapping, Sequence, Tuple, cast from absl import logging +import jax from orbax.export import constants from orbax.export import export_base +from orbax.export import jax_module +from orbax.export import serving_config as osc +from orbax.export import utils from orbax.export.modules import obm_module +from orbax.export.typing import PyTree +import tensorflow as tf + + +def _to_sequence(a): + if isinstance(a, Sequence): + return a + return (a,) class ObmExport(export_base.ExportBase): """Defines the save and load methods for exporting a model using Orbax Model export.""" + def __init__( + self, + module: jax_module.JaxModule, + serving_configs: Sequence[osc.ServingConfig], + ): + """Initializes the ObmExport class.""" + if module.export_version != constants.ExportModelType.ORBAX_MODEL: + raise ValueError( + "JaxModule export version is not of type ORBAX_MODEL. Please use the" + " correct export_version. Expected ORBAX_MODEL, got" + f" {module.export_version}" + ) + + obm_model_module = module.export_module() + def save( self, model_path: str, @@ -38,18 +65,12 @@ def save( arguments are `save_options` and `serving_signatures`. """ - if self._module.export_version() != constants.ExportModelType.ORBAX_MODEL: - raise ValueError( - "JaxModule is not of type ORBAX_MODEL. Please use the correct" - " export_version. Expected ORBAX_MODEL, got" - f" {self._module.export_version()}" - ) - def load(self, model_path: str, **kwargs: Any): """Loads the model previously saved in the Orbax Model export format.""" logging.info("Loading model using Orbax Export Model.") raise NotImplementedError("ObmExport.load not implemented yet.") + @property def serving_signatures(self) -> Mapping[str, Callable[..., Any]]: """Returns a map of signature keys to serving functions.""" diff --git a/export/orbax/export/utils.py b/export/orbax/export/utils.py index 2ecd7419..c1c6d0f4 100644 --- a/export/orbax/export/utils.py +++ b/export/orbax/export/utils.py @@ -491,3 +491,5 @@ def make_e2e_inference_fn( return with_default_args( infer_step_func_map[signature_key], serving_config.get_input_signature() ) + +