Add save_state_dict function
epwalsh committed Oct 10, 2024
1 parent 9c25aed commit 71bc5c8
Showing 2 changed files with 42 additions and 4 deletions.
29 changes: 29 additions & 0 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -44,6 +44,7 @@
from .filesystem import RemoteFileSystemReader, RemoteFileSystemWriter

__all__ = [
"save_state_dict",
"save_model_and_optim_state",
"async_save_model_and_optim_state",
"load_model_and_optim_state",
@@ -53,6 +54,34 @@
log = logging.getLogger(__name__)


@torch.no_grad()
def save_state_dict(
    dir: PathOrStr,
    state_dict: Dict[str, Any],
    process_group: Optional[dist.ProcessGroup] = None,
    save_overwrite: bool = False,
):
    """
    Save an arbitrary state dictionary to a distributed format that can be loaded again with
    a different distributed topology.

    .. important::
        Please use :func:`save_model_and_optim_state` to save model/optimizer state dicts instead
        unless you know what you're doing.

    :param dir: Path/URL to save to.
    :param state_dict: The state dict to save.
    :param process_group: The process group to use for distributed collectives.
    :param save_overwrite: Overwrite existing files.
    """
    dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
    dist_cp.state_dict_saver.save(
        state_dict,
        storage_writer=RemoteFileSystemWriter(dir),
        process_group=process_group,
    )


@torch.no_grad()
def save_model_and_optim_state(
    dir: PathOrStr,
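For context, a minimal usage sketch of the new function (not part of the commit): the toy Linear module, the /tmp/ckpt path, and the single-process setup are illustrative assumptions. The pattern of namespacing the state dict under a "model" key and reading it back with load_model_and_optim_state mirrors the test change below.

import torch

from olmo_core.distributed.checkpoint import load_model_and_optim_state, save_state_dict

# Any nn.Module works here; Linear is just a stand-in for illustration.
model = torch.nn.Linear(8, 8)

# Save an arbitrary state dict (namespaced under "model") in the distributed checkpoint format.
save_state_dict("/tmp/ckpt", {"model": model.state_dict()}, save_overwrite=True)

# Later, possibly under a different distributed topology, load it back into a fresh model.
model2 = torch.nn.Linear(8, 8)
load_model_and_optim_state("/tmp/ckpt", model2)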
17 changes: 13 additions & 4 deletions src/test/distributed/checkpoint_test.py
@@ -15,6 +15,7 @@
    async_save_model_and_optim_state,
    load_model_and_optim_state,
    save_model_and_optim_state,
    save_state_dict,
    unshard_checkpoint,
)

@@ -106,6 +107,10 @@ def run_save_and_load_tensor_parallel_model(dir, take_step_before_checkpoint, ru
    tp_mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),))

    feed_forward = FeedForward().to(get_default_device())

    # Save a checkpoint from the unsharded model.
    save_state_dict(dir / "unsharded", {"model": feed_forward.state_dict()})

    parallelize_module(
        feed_forward,
        tp_mesh,
@@ -131,9 +136,9 @@

    # Save checkpoint.
    if run_async:
-       async_save_model_and_optim_state(dir, feed_forward, optim).result()
+       async_save_model_and_optim_state(dir / "sharded", feed_forward, optim).result()
    else:
-       save_model_and_optim_state(dir, feed_forward, optim)
+       save_model_and_optim_state(dir / "sharded", feed_forward, optim)

    # Create another sharded model, load the checkpoint and make sure the state matches.
    feed_forward2 = FeedForward().to(get_default_device())
@@ -149,14 +154,18 @@ def run_save_and_load_tensor_parallel_model(dir, take_step_before_checkpoint, ru
        },
    )
    optim2 = torch.optim.AdamW(feed_forward2.parameters())
-   load_model_and_optim_state(dir, feed_forward2, optim2)
+   load_model_and_optim_state(dir / "sharded", feed_forward2, optim2)
    torch.testing.assert_close(feed_forward.state_dict(), feed_forward2.state_dict())
    torch.testing.assert_close(optim.state_dict(), optim2.state_dict())

    # Now load the checkpoint with a different topology, in this case an unsharded model.
    unsharded_feed_forward = FeedForward().to(get_default_device())
    unsharded_optim = torch.optim.AdamW(unsharded_feed_forward.parameters())
-   load_model_and_optim_state(dir, unsharded_feed_forward, unsharded_optim)
+   load_model_and_optim_state(dir / "sharded", unsharded_feed_forward, unsharded_optim)

    # Now make sure we can load the checkpoint saved from the original unsharded model into both models.
    load_model_and_optim_state(dir / "unsharded", unsharded_feed_forward)
    load_model_and_optim_state(dir / "unsharded", feed_forward2)


@pytest.mark.parametrize("backend", BACKENDS)
