Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557561606
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Aug 16, 2023
1 parent 5947368 commit ebb8a8e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
22 changes: 12 additions & 10 deletions checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
TypeHandler = type_handlers.TypeHandler
AggregateHandler = aggregate_handlers.AggregateHandler
MsgpackHandler = aggregate_handlers.MsgpackHandler
TransformFn = Callable[[PyTree, PyTree, PyTree], Tuple[PyTree, PyTree]]
LegacyTransformFn = Callable[[PyTree, PyTree, PyTree], Tuple[PyTree, PyTree]]
Transform = transform_utils.Transform
RestoreTransform = transform_utils.RestoreTransform
JsonCheckpointHandler = json_checkpoint_handler.JsonCheckpointHandler
Expand Down Expand Up @@ -849,7 +849,7 @@ def restore(
restore_args: Optional[PyTree] = None,
transforms: Optional[PyTree] = None,
transforms_default_to_original: bool = True,
transform_fn: Optional[TransformFn] = None,
legacy_transform_fn: Optional[LegacyTransformFn] = None,
) -> PyTree:
"""Restores a PyTree from the checkpoint directory at the given path.
Expand Down Expand Up @@ -940,9 +940,9 @@ class TrainState:
completely.
See `transform_utils` for further information.
transforms_default_to_original: See transform_utils.apply_transformations.
transform_fn: WARNING: NOT GENERALLY SUPPORTED. A function which accepts
the `item` argument, a PyTree checkpoint structure and a PyTree of
ParamInfos based on the checkpoint. Returns a transformed PyTree
legacy_transform_fn: WARNING: NOT GENERALLY SUPPORTED. A function which
accepts the `item` argument, a PyTree checkpoint structure and a PyTree
of ParamInfos based on the checkpoint. Returns a transformed PyTree
matching the desired return tree structure, and a matching ParamInfo
tree.
Expand Down Expand Up @@ -982,10 +982,12 @@ async def _create_byte_limiter():
transforms_default_to_original=transforms_default_to_original,
)

if transform_fn is not None and transforms is not None:
raise ValueError('Cannot provide both `transforms` and `transform_fn`.')
if transform_fn is not None:
structure, param_infos = transform_fn(item, structure, param_infos)
if legacy_transform_fn is not None and transforms is not None:
raise ValueError(
'Cannot provide both `transforms` and `legacy_transform_fn`.'
)
if legacy_transform_fn is not None:
structure, param_infos = legacy_transform_fn(item, structure, param_infos)
if restore_args is None:
restore_args = jax.tree_util.tree_map(lambda x: RestoreArgs(), item)
checkpoint_restore_args = restore_args
Expand All @@ -1009,7 +1011,7 @@ def _maybe_set_default_restore_types(
self._maybe_deserialize(structure, param_infos, checkpoint_restore_args)
)

if not transform_fn:
if not legacy_transform_fn:
restored_item = _transform_checkpoint(
item,
restored_item,
Expand Down
8 changes: 2 additions & 6 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,7 @@ def __init__(
'Must provide a ts.Context if use_ocdbt is True. Ensure that the'
' context contains a coordinator address.'
)
self._ts_context = ts_context or ts.Context(
{'file_io_concurrency': {'limit': 128}}
)
self._ts_context = ts_context or serialization.TS_CONTEXT

def enable_ocdbt(self, ts_context: ts.Context) -> None:
self._use_ocdbt = True
Expand Down Expand Up @@ -651,9 +649,7 @@ def __init__(
'Must provide a ts.Context if use_ocdbt is True. Ensure that the'
' context contains a coordinator address.'
)
self._ts_context = ts_context or ts.Context(
{'file_io_concurrency': {'limit': 128}}
)
self._ts_context = ts_context or serialization.TS_CONTEXT

def enable_ocdbt(self, ts_context: ts.Context) -> None:
self._use_ocdbt = True
Expand Down
4 changes: 3 additions & 1 deletion checkpoint/orbax/checkpoint/value_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ class ArrayMetadata(Metadata):
Tuple of integers describing the array shape.
shards:
Tuple of integers indicating how many shards each dimension is divided
into. May be None if the array is not sharded.
into. E.g. a dimension may be 1 if it is unsharded, or 2 if it is divided
into 2 chunks.
May be None if the array is not sharded.
dtype:
Dtype of array elements.
"""
Expand Down

0 comments on commit ebb8a8e

Please sign in to comment.