-
Notifications
You must be signed in to change notification settings - Fork 422
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[checkpoint v2] Download api (#3447)
* a * a * fix lint and test * lint * comments * comment
- Loading branch information
Showing
6 changed files
with
158 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright 2024 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Useful functions for load checkpoints from remote object store or local disk.""" | ||
|
||
import logging | ||
from typing import Optional | ||
|
||
from composer.utils import ( | ||
dist, | ||
extract_path_from_symlink, | ||
maybe_create_object_store_from_uri, | ||
parse_uri, | ||
retry, | ||
) | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def download_file( | ||
source_uri: str, | ||
destination_path: str, | ||
node_ranks: Optional[list[int]] = None, | ||
num_attempts: int = 5, | ||
): | ||
"""Downloads a file (object) from the specified URI to the specified directory. | ||
Args: | ||
source_uri (str): The URI to download the file from or a symlink to the URI. | ||
destination_path (str): The directory to download the file to. | ||
node_ranks (list[int]): The ranks of the nodes that will download the file. If None, all nodes will download the file. | ||
num_attempts (int): Retry for object store downloads. Default to 5. | ||
""" | ||
# Only local rank 0 downloads | ||
local_rank = dist.get_local_rank() | ||
if local_rank != 0: | ||
return | ||
|
||
node_rank = dist.get_node_rank() | ||
if node_ranks is not None and node_rank not in node_ranks: | ||
return | ||
|
||
object_store = maybe_create_object_store_from_uri(source_uri) | ||
_, _, source_path = parse_uri(source_uri) | ||
if source_uri.endswith('.symlink'): | ||
source_path = extract_path_from_symlink(source_path, object_store) | ||
assert object_store is not None | ||
|
||
@retry(num_attempts=num_attempts) | ||
def _download(): | ||
object_store.download_object( | ||
object_name=source_path, | ||
filename=destination_path, | ||
) | ||
|
||
log.debug(f'Downloading {source_path} to {destination_path}') | ||
_download() | ||
log.debug(f'Finished downloading {source_path} to {destination_path}') | ||
|
||
|
||
def download_monolithic_checkpoint( | ||
source_uri: str, | ||
destination_path: str, | ||
global_rank_zero_only: bool = True, | ||
): | ||
"""Downloads a monolithic checkpoint from the specified URI to the specified directory. | ||
Args: | ||
source_uri (str): The URI to download the checkpoint from or symlink that points to the URI. | ||
destination_path (str): The directory to download the checkpoint to. | ||
global_rank_zero_only (bool): If True, only rank 0 will download the checkpoint. | ||
broadcast_files_to_other_nodes (bool): If True, the downloaded checkpoint will be broadcast to all other nodes. | ||
If torch syncs modules states this is unnecessary. | ||
""" | ||
node_ranks = None | ||
if global_rank_zero_only: | ||
node_ranks = [0] | ||
download_file( | ||
source_uri=source_uri, | ||
destination_path=destination_path, | ||
node_ranks=node_ranks, | ||
) | ||
if not global_rank_zero_only or (global_rank_zero_only and dist.get_global_rank() == 0): | ||
return destination_path | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Copyright 2024 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import tempfile | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
import torch | ||
|
||
from composer.checkpoint import download_monolithic_checkpoint | ||
from composer.utils import dist | ||
from tests.checkpoint.helpers import init_model | ||
from tests.common.markers import world_size | ||
from tests.utils.test_remote_uploader import DummyObjectStore | ||
|
||
|
||
@world_size(1, 2) | ||
@pytest.mark.gpu | ||
@pytest.mark.parametrize('rank_zero_only', [True, False]) | ||
def test_download_monolithic_checkpoint(world_size: int, rank_zero_only: bool): | ||
# Write a checkpoint | ||
tmp_dir = tempfile.TemporaryDirectory() | ||
use_fsdp = False | ||
if world_size > 1: | ||
use_fsdp = True | ||
fsdp_model, _ = init_model(use_fsdp=use_fsdp) | ||
|
||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict | ||
state = get_model_state_dict(fsdp_model, options=StateDictOptions(full_state_dict=True)) | ||
|
||
checkpoint_filename = 'state_dict' | ||
save_filename = os.path.join(tmp_dir.name, checkpoint_filename) | ||
if dist.get_global_rank() == 0: | ||
torch.save(state, save_filename) | ||
|
||
class DummyS3ObjectStore(DummyObjectStore): | ||
|
||
def get_tmp_dir(self): | ||
return tmp_dir | ||
|
||
# Download a monolithic checkpoint | ||
local_file_name = 'state_dict.download' | ||
with patch('composer.utils.file_helpers.S3ObjectStore', DummyS3ObjectStore): | ||
ret = download_monolithic_checkpoint( | ||
source_uri=f's3://bucket_name/{checkpoint_filename}', | ||
destination_path=local_file_name, | ||
global_rank_zero_only=rank_zero_only, | ||
) | ||
dist.barrier() | ||
|
||
if rank_zero_only and dist.get_global_rank() != 0: | ||
assert ret == None | ||
if dist.get_global_rank() == 0: | ||
assert ret == local_file_name | ||
assert os.path.isfile(local_file_name) == True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters