Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update CompositeCheckpointHandler.metadata() to return StepMetadata. #1398

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix namedtuple empty value typestr when experimental support_rich_types is
disabled again after enabling it.

### Changed
- Return `StepMetadata` from `CompositeCheckpointHandler.metadata()`.


## [0.10.1] - 2024-11-22

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,8 @@ def finalize(self, directory: epath.Path) -> None:
def close(self):
"""Closes the CheckpointHandler."""
pass

@property
def typestr(self) -> str:
"""A unique identifier for the CheckpointHandler type."""
return self.__class__.__name__
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,16 @@
from orbax.checkpoint._src.handlers import checkpoint_handler
from orbax.checkpoint._src.handlers import handler_registration
from orbax.checkpoint._src.handlers import proto_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.path import atomicity

CheckpointArgs = checkpoint_args.CheckpointArgs
Future = future.Future
CheckpointArgs = checkpoint_args.CheckpointArgs
CheckpointHandler = checkpoint_handler.CheckpointHandler
StepMetadata = checkpoint.StepMetadata
ItemMetadata = checkpoint.ItemMetadata
AsyncCheckpointHandler = async_checkpoint_handler.AsyncCheckpointHandler
register_with_handler = checkpoint_args.register_with_handler
ProtoCheckpointHandler = proto_checkpoint_handler.ProtoCheckpointHandler
Expand Down Expand Up @@ -494,6 +498,10 @@ def __init__(
self._handler_registry,
)

self._metadata_store = checkpoint.metadata_store(
enable_write=False, blocking_write=False
)

# TODO: b/359524229 - Remove this property as it has been deprecated.
@property
def _known_handlers(self) -> Dict[str, Optional[CheckpointHandler]]:
Expand Down Expand Up @@ -811,7 +819,7 @@ def restore(
)
return CompositeResults(**restored)

def metadata(self, directory: epath.Path) -> CompositeResults:
def metadata(self, directory: epath.Path, **kwargs) -> StepMetadata:
"""Metadata for each item in the checkpoint.

This has much the same logic as `restore`, in the sense that it tries to
Expand All @@ -821,16 +829,51 @@ def metadata(self, directory: epath.Path) -> CompositeResults:

Args:
directory: Path to the checkpoint.
**kwargs: Additional arguments to construct the StepMetadata with during
deserialization. `item_metadata` and `metrics` are supported.

Returns:
CompositeResults
StepMetadata
"""
items_to_handlers = dict(
self._get_all_registered_and_unregistered_items_and_handlers()
)
existing_items = self._existing_items(directory)
try:
existing_items = self._existing_items(directory)
except OSError:
existing_items = []
logging.warning(
'Failed to get existing items from directory %s. Will use items '
'provided during initialization: %s.',
directory, existing_items
)

saved_metadata = checkpoint.StepMetadata()
if not directory.exists():
logging.warning('Directory does not exist: %s', directory)
else:
step_metadata_file_path = checkpoint.step_metadata_file_path(directory)
if not step_metadata_file_path.exists():
logging.warning(
'No step metadata file found in directory %s.', directory
)
else:
logging.info('Found step metadata file in directory %s.', directory)
serialized_metadata = self._metadata_store.read(step_metadata_file_path)
if serialized_metadata is None:
logging.warning(
'Failed to read step metadata from %s.', step_metadata_file_path
)
else:
saved_metadata = step_metadata_serialization.deserialize(
serialized_metadata,
item_metadata=kwargs.get('item_metadata', None),
metrics=kwargs.get('metrics', None),
)
item_handlers = saved_metadata.item_handlers or {}
item_metadata = dict(saved_metadata.item_metadata or {})
assert item_handlers.keys() == item_metadata.keys()

metadata = {}
for item_name in existing_items:
if (
item_name not in items_to_handlers
Expand All @@ -842,14 +885,25 @@ def metadata(self, directory: epath.Path) -> CompositeResults:
' call `restore` with an appropriate `CheckpointArgs` subclass.',
item_name,
)
metadata[item_name] = None
if item_name not in item_handlers:
item_handlers[item_name] = None
if item_name not in item_metadata:
item_metadata[item_name] = None
continue
handler = items_to_handlers[item_name]
assert handler is not None
metadata[item_name] = handler.metadata(
self._get_item_directory(directory, item_name)
)
return CompositeResults(**metadata)
if item_handlers.get(item_name) is None:
item_handlers[item_name] = handler.typestr
if item_metadata.get(item_name) is None:
item_metadata[item_name] = handler.metadata(
self._get_item_directory(directory, item_name)
)

return dataclasses.replace(
saved_metadata,
item_handlers=item_handlers,
item_metadata=ItemMetadata(**item_metadata),
)

def finalize(self, directory: epath.Path):
if not self._current_temporary_paths:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
from orbax.checkpoint._src.handlers import json_checkpoint_handler
from orbax.checkpoint._src.handlers import proto_checkpoint_handler
from orbax.checkpoint._src.handlers import standard_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.metadata import step_metadata_serialization
from orbax.checkpoint._src.metadata import value as value_metadata
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import step
from orbax.checkpoint.logging import step_statistics

CompositeArgs = composite_checkpoint_handler.CompositeArgs
JsonCheckpointHandler = json_checkpoint_handler.JsonCheckpointHandler
Expand Down Expand Up @@ -654,11 +657,7 @@ def test_metadata(self):
handler = CompositeCheckpointHandler(
'extra',
state=StandardCheckpointHandler(),
metadata=JsonCheckpointHandler(),
)
metadata = handler.metadata(self.directory)
self.assertEmpty(metadata.keys())

state = {'a': 1, 'b': 2}
self.save(
handler,
Expand All @@ -667,9 +666,9 @@ def test_metadata(self):
state=args_lib.StandardSave(state),
),
)
metadata = handler.metadata(self.directory)
step_metadata = handler.metadata(self.directory)
self.assertDictEqual(
metadata.state,
step_metadata.item_metadata.state,
{
'a': value_metadata.ScalarMetadata(
name='a', directory=self.directory / 'state', dtype=jnp.int64
Expand All @@ -680,25 +679,34 @@ def test_metadata(self):
},
)
expected_elements = ['state']
self.assertSameElements(metadata.keys(), expected_elements)
self.assertSameElements(
step_metadata.item_metadata.keys(), expected_elements
)

def test_metadata_no_save(self):
handler = CompositeCheckpointHandler(
'extra',
state=StandardCheckpointHandler(),
)
step_metadata = handler.metadata(self.directory)
self.assertIsNone(step_metadata.format)
self.assertEmpty(step_metadata.item_handlers)
self.assertEmpty(step_metadata.item_metadata)
self.assertEmpty(step_metadata.metrics)
self.assertEqual(
step_metadata.performance_metrics, step_statistics.SaveStepStatistics()
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)

def test_metadata_handler_registry(self):
registry = handler_registration.DefaultCheckpointHandlerRegistry()
state_handler = StandardCheckpointHandler()
metadata_handler = JsonCheckpointHandler()
registry.add('state', args_lib.StandardSave, state_handler)
registry.add('state', args_lib.StandardRestore, state_handler)
registry.add(
'metadata',
args_lib.JsonSave,
metadata_handler,
)
registry.add('metadata', args_lib.JsonRestore, metadata_handler)

handler = CompositeCheckpointHandler(handler_registry=registry)
metadata = handler.metadata(self.directory)
self.assertEmpty(metadata.keys())

state = {'a': 1, 'b': 2}
self.save(
handler,
Expand All @@ -707,9 +715,17 @@ def test_metadata_handler_registry(self):
state=args_lib.StandardSave(state),
),
)
metadata = handler.metadata(self.directory)
step_metadata = handler.metadata(self.directory)
self.assertIsNone(step_metadata.format)
self.assertEqual(
step_metadata.item_handlers,
{
'state': 'StandardCheckpointHandler',
'descriptor': 'ProtoCheckpointHandler',
}
)
self.assertDictEqual(
metadata.state,
step_metadata.item_metadata.state,
{
'a': value_metadata.ScalarMetadata(
name='a', directory=self.directory / 'state', dtype=jnp.int64
Expand All @@ -719,8 +735,136 @@ def test_metadata_handler_registry(self):
),
},
)
self.assertEmpty(step_metadata.metrics)
self.assertEqual(
step_metadata.performance_metrics, step_statistics.SaveStepStatistics()
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)
expected_elements = ['state']
self.assertSameElements(metadata.keys(), expected_elements)
self.assertSameElements(
step_metadata.item_metadata.keys(), expected_elements
)

def test_metadata_handler_registry_no_save(self):
registry = handler_registration.DefaultCheckpointHandlerRegistry()
state_handler = StandardCheckpointHandler()
registry.add('state', args_lib.StandardSave, state_handler)
registry.add('state', args_lib.StandardRestore, state_handler)

handler = CompositeCheckpointHandler(handler_registry=registry)
step_metadata = handler.metadata(self.directory)
self.assertIsNone(step_metadata.format)
self.assertEmpty(step_metadata.item_handlers)
self.assertEmpty(step_metadata.item_metadata)
self.assertEmpty(step_metadata.metrics)
self.assertEqual(
step_metadata.performance_metrics, step_statistics.SaveStepStatistics()
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)

def test_metadata_after_step_metadata_write(self):
handler = CompositeCheckpointHandler(
'extra',
state=StandardCheckpointHandler(),
)
step_metadata = handler.metadata(self.directory)
self.assertIsNone(step_metadata.format)
self.assertEmpty(step_metadata.item_handlers)
self.assertEmpty(step_metadata.item_metadata)
self.assertEmpty(step_metadata.metrics)
self.assertEqual(
step_metadata.performance_metrics, step_statistics.SaveStepStatistics()
)
self.assertIsNone(step_metadata.init_timestamp_nsecs)
self.assertIsNone(step_metadata.commit_timestamp_nsecs)
self.assertEmpty(step_metadata.custom)

metadata_to_write = checkpoint.StepMetadata(
format='orbax',
item_handlers={
'state': 'StandardCheckpointHandler',
},
item_metadata={
'state': 123,
},
metrics={
'loss': 1.0,
'accuracy': 0.5,
},
performance_metrics=step_statistics.SaveStepStatistics(
preemption_received_at=1.0,
),
init_timestamp_nsecs=1000,
commit_timestamp_nsecs=2000,
custom={
'custom_key': 'custom_value',
},
)
checkpoint.metadata_store(enable_write=True, blocking_write=True).write(
checkpoint.step_metadata_file_path(self.directory),
step_metadata_serialization.serialize(metadata_to_write)
)

step_metadata = handler.metadata(self.directory)
self.assertEqual(step_metadata.format, 'orbax')
self.assertDictEqual(
step_metadata.item_handlers, {'state': 'StandardCheckpointHandler'}
)
self.assertDictEqual(dict(step_metadata.item_metadata), {'state': None})
self.assertDictEqual(step_metadata.metrics, {'loss': 1.0, 'accuracy': 0.5})
self.assertEqual(
step_metadata.performance_metrics,
step_statistics.SaveStepStatistics(
preemption_received_at=1.0,
)
)
self.assertEqual(step_metadata.init_timestamp_nsecs, 1000)
self.assertEqual(step_metadata.commit_timestamp_nsecs, 2000)
self.assertEqual(step_metadata.custom, {'custom_key': 'custom_value'})

def test_metadata_existing_items_updates_step_metadata(self):
handler = CompositeCheckpointHandler(
'extra',
state=StandardCheckpointHandler(),
)
metadata_to_write = checkpoint.StepMetadata(
item_handlers={
'state': 'StandardCheckpointHandler',
},
item_metadata={
'state': 123,
},
)
checkpoint.metadata_store(enable_write=True, blocking_write=True).write(
checkpoint.step_metadata_file_path(self.directory),
step_metadata_serialization.serialize(metadata_to_write)
)

state = {'a': 1, 'b': 2}
self.save(
handler,
self.directory,
CompositeArgs(
state=args_lib.StandardSave(state),
),
)

step_metadata = handler.metadata(self.directory)
self.assertDictEqual(
step_metadata.item_handlers,
{
'state': 'StandardCheckpointHandler',
'descriptor': 'ProtoCheckpointHandler'
}
)
self.assertSameElements(
step_metadata.item_metadata.keys(), ['state', 'descriptor']
)
self.assertIsNotNone(step_metadata.item_metadata['state'])

def test_finalize(self):
state_handler = mock.create_autospec(StandardCheckpointHandler)
Expand Down
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,8 @@ def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]:
read_step_directory = self._get_read_step_directory(step, self.directory)

result = self._checkpointer.metadata(read_step_directory)
if isinstance(result, checkpoint.StepMetadata):
result = result.item_metadata
if self._default_item is None:
self._default_item = _determine_default_item_mode_from_directory(
read_step_directory
Expand Down
Loading