diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d3498a9..68763b66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Added `olmo_core.distributed.checkpoint.get_checkpoint_metadata()` function. + ### Fixed - Old ephemeral checkpoints won't be removed until after the latest ephemeral checkpoint is saved successfully. diff --git a/src/olmo_core/distributed/checkpoint/__init__.py b/src/olmo_core/distributed/checkpoint/__init__.py index 4ca57bfc..18b4dd3b 100644 --- a/src/olmo_core/distributed/checkpoint/__init__.py +++ b/src/olmo_core/distributed/checkpoint/__init__.py @@ -35,6 +35,7 @@ import torch.distributed.checkpoint as dist_cp import torch.distributed.checkpoint.state_dict as dist_cp_sd import torch.nn as nn +from torch.distributed.checkpoint.metadata import Metadata from olmo_core.aliases import PathOrStr from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path @@ -49,6 +50,7 @@ "async_save_model_and_optim_state", "load_model_and_optim_state", "unshard_checkpoint", + "get_checkpoint_metadata", ] log = logging.getLogger(__name__) @@ -310,6 +312,17 @@ def save(state_dict: Dict[str, Any], path: Path): return model_path, optim_path +def get_checkpoint_metadata(dir: PathOrStr) -> Metadata: + """ + Load the metadata from a checkpoint. + + :param dir: The path/URL to the checkpoint. + """ + dir = normalize_path(dir) + storage_reader = RemoteFileSystemReader(dir) + return storage_reader.read_metadata() + + def _prepare_env_for_save( dir: PathOrStr, process_group: Optional[dist.ProcessGroup] = None,