Skip to content

Commit

Permalink
[checkpoint v2] Download api (#3447)
Browse files Browse the repository at this point in the history
* a

* a

* fix lint and test

* lint

* comments

* comment
  • Loading branch information
bigning authored Jul 9, 2024
1 parent 38a3334 commit 7656db4
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 0 deletions.
2 changes: 2 additions & 0 deletions composer/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Module for checkpointing API."""

from composer.checkpoint.download import download_monolithic_checkpoint
from composer.checkpoint.state_dict import (
get_metadata_state_dict,
get_model_state_dict,
Expand All @@ -15,4 +16,5 @@
'get_optim_state_dict',
'get_metadata_state_dict',
'get_resumption_state_dict',
'download_monolithic_checkpoint',
]
85 changes: 85 additions & 0 deletions composer/checkpoint/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for load checkpoints from remote object store or local disk."""

import logging
from typing import Optional

from composer.utils import (
dist,
extract_path_from_symlink,
maybe_create_object_store_from_uri,
parse_uri,
retry,
)

log = logging.getLogger(__name__)


def download_file(
    source_uri: str,
    destination_path: str,
    node_ranks: Optional[list[int]] = None,
    num_attempts: int = 5,
) -> None:
    """Download a single file (object) from an object store URI to a local file path.

    Only the process with local rank 0 performs the download, and only on the
    nodes selected by ``node_ranks``. Every other process returns immediately
    without downloading anything.

    Args:
        source_uri (str): The URI to download the file from. If it ends with
            ``.symlink``, the real object path is first resolved with
            ``extract_path_from_symlink``.
        destination_path (str): The local file path the object is written to. It is
            handed to the object store as the target ``filename``.
        node_ranks (list[int], optional): The ranks of the nodes that will download
            the file. If ``None``, all nodes download the file.
        num_attempts (int): Number of attempts passed to the ``retry`` decorator that
            wraps the object store download. Defaults to 5.

    Raises:
        ValueError: If no object store could be created from ``source_uri``.
    """
    # Only local rank 0 downloads.
    local_rank = dist.get_local_rank()
    if local_rank != 0:
        return

    # Skip nodes that were not asked to download.
    node_rank = dist.get_node_rank()
    if node_ranks is not None and node_rank not in node_ranks:
        return

    object_store = maybe_create_object_store_from_uri(source_uri)
    # Validate explicitly (instead of with ``assert``) so the check survives
    # ``python -O`` and fires before the object store is used for anything.
    if object_store is None:
        raise ValueError(
            f'Unable to create an object store from source_uri={source_uri!r}; '
            'download_file requires a remote object store URI.',
        )

    _, _, source_path = parse_uri(source_uri)
    if source_uri.endswith('.symlink'):
        # The symlink file holds the path of the real object; follow it.
        source_path = extract_path_from_symlink(source_path, object_store)

    @retry(num_attempts=num_attempts)
    def _download():
        object_store.download_object(
            object_name=source_path,
            filename=destination_path,
        )

    log.debug('Downloading %s to %s', source_path, destination_path)
    _download()
    log.debug('Finished downloading %s to %s', source_path, destination_path)


def download_monolithic_checkpoint(
    source_uri: str,
    destination_path: str,
    global_rank_zero_only: bool = True,
) -> Optional[str]:
    """Download a monolithic checkpoint from the specified URI to a local file path.

    The actual transfer is delegated to ``download_file``.

    Args:
        source_uri (str): The URI to download the checkpoint from, or a symlink that
            points to the URI.
        destination_path (str): The local path to download the checkpoint to.
        global_rank_zero_only (bool): If ``True``, the download is restricted to node 0
            and only global rank 0 receives the destination path back. If ``False``,
            no node restriction is applied and every rank receives the destination
            path back.

    Returns:
        Optional[str]: ``destination_path`` when ``global_rank_zero_only`` is ``False``
        or the caller is global rank 0; ``None`` otherwise.
    """
    # Restrict the download to node 0 when only global rank 0 needs the file.
    node_ranks = [0] if global_rank_zero_only else None
    download_file(
        source_uri=source_uri,
        destination_path=destination_path,
        node_ranks=node_ranks,
    )

    # Non-zero ranks get no path back in rank-zero-only mode.
    if global_rank_zero_only and dist.get_global_rank() != 0:
        return None
    return destination_path
2 changes: 2 additions & 0 deletions composer/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
create_symlink_file,
ensure_folder_has_no_conflicting_files,
ensure_folder_is_empty,
extract_path_from_symlink,
format_name_with_dist,
format_name_with_dist_and_time,
get_file,
Expand Down Expand Up @@ -158,6 +159,7 @@
'ParallelismConfig',
'MLFLOW_EXPERIMENT_ID_FORMAT_KEY',
'MLFLOW_RUN_ID_FORMAT_KEY',
'extract_path_from_symlink',
'RemoteUploader',
'validate_credentials',
'build_remote_backend',
Expand Down
11 changes: 11 additions & 0 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
'maybe_create_object_store_from_uri',
'maybe_create_remote_uploader_downloader_from_uri',
'parse_uri',
'extract_path_from_symlink',
'validate_credentials',
]

Expand All @@ -57,6 +58,16 @@ def extract_path_from_symlink(
source_path: str,
object_store: Optional[Union[LoggerDestination, ObjectStore]] = None,
) -> str:
"""Returns the checkpont path from symlink file.
Args:
source_path(str): The remote symlink path.
object_store(LoggerDestination | ObjectStore, optional): The object store
used to download the remote symlink file
Returns:
str: The content of the remote symlink file.
"""
if object_store is not None:
with tempfile.TemporaryDirectory() as tmpdir:
_, _, source_path = parse_uri(source_path)
Expand Down
56 changes: 56 additions & 0 deletions tests/checkpoint/test_download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import os
import tempfile
from unittest.mock import patch

import pytest
import torch

from composer.checkpoint import download_monolithic_checkpoint
from composer.utils import dist
from tests.checkpoint.helpers import init_model
from tests.common.markers import world_size
from tests.utils.test_remote_uploader import DummyObjectStore


@world_size(1, 2)
@pytest.mark.gpu
@pytest.mark.parametrize('rank_zero_only', [True, False])
def test_download_monolithic_checkpoint(world_size: int, rank_zero_only: bool):
    """Save a full model state dict, then download it through a patched S3 object store.

    Checks the value returned by ``download_monolithic_checkpoint`` on each rank and
    that the downloaded file exists on global rank 0.
    """
    # Write a checkpoint
    tmp_dir = tempfile.TemporaryDirectory()
    # FSDP is only used when there is more than one rank.
    use_fsdp = world_size > 1
    fsdp_model, _ = init_model(use_fsdp=use_fsdp)

    from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
    state = get_model_state_dict(fsdp_model, options=StateDictOptions(full_state_dict=True))

    checkpoint_filename = 'state_dict'
    save_filename = os.path.join(tmp_dir.name, checkpoint_filename)
    # Only global rank 0 writes the checkpoint file.
    if dist.get_global_rank() == 0:
        torch.save(state, save_filename)

    class DummyS3ObjectStore(DummyObjectStore):

        def get_tmp_dir(self):
            return tmp_dir

    # Download a monolithic checkpoint
    local_file_name = 'state_dict.download'
    with patch('composer.utils.file_helpers.S3ObjectStore', DummyS3ObjectStore):
        ret = download_monolithic_checkpoint(
            source_uri=f's3://bucket_name/{checkpoint_filename}',
            destination_path=local_file_name,
            global_rank_zero_only=rank_zero_only,
        )
    # Wait for every rank to finish before checking results.
    dist.barrier()

    if rank_zero_only and dist.get_global_rank() != 0:
        assert ret is None
    if dist.get_global_rank() == 0:
        assert ret == local_file_name
        assert os.path.isfile(local_file_name)
2 changes: 2 additions & 0 deletions tests/utils/test_remote_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def download_object(
overwrite: bool = False,
callback: Optional[Callable[[int, int], None]] = None,
):
if overwrite is False and os.path.isfile(filename):
raise FileExistsError(f'The file at {filename} already exists and overwrite is set to False.')
object_path = pathlib.Path(self.root) / object_name
shutil.copy2(object_path, filename)

Expand Down

0 comments on commit 7656db4

Please sign in to comment.