Skip to content

Commit

Permalink
Pre-download only relevant files per rank. (#124)
Browse files (browse the repository at this point in the history)
  • Loading branch information
2015aroras authored Dec 19, 2024
1 parent d67f767 commit aa4afb3
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ license = { file = "LICENSE" }
dependencies = [
"numpy<2.0",
"torch>=2.5.1",
"cached-path",
"cached-path@git+https://github.com/allenai/cached_path@shanea/cache-gs-client",
"requests",
"packaging",
"rich",
Expand Down
13 changes: 10 additions & 3 deletions src/olmo_core/distributed/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.state_dict as dist_cp_sd
import torch.nn as nn
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import Metadata

from olmo_core.aliases import PathOrStr
Expand Down Expand Up @@ -119,10 +120,12 @@ def save_model_and_optim_state(
"""
dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
state_dict = _prepare_state_dict(model, optim=optim, process_group=process_group)
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
process_group=process_group,
planner=planner,
)


Expand All @@ -142,10 +145,12 @@ def async_save_model_and_optim_state(
"""
dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
state_dict = _prepare_state_dict(model, optim=optim, process_group=process_group)
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
return dist_cp.state_dict_saver.async_save(
state_dict,
storage_writer=RemoteFileSystemWriter(dir),
process_group=process_group,
planner=planner,
)


Expand All @@ -157,6 +162,7 @@ def load_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
key_mapping: Optional[Dict[str, str]] = None,
pre_download: bool = False,
):
"""
Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
Expand Down Expand Up @@ -195,7 +201,7 @@ def load_model_and_optim_state(
"""
dir = normalize_path(dir)
state_dict = _prepare_state_dict(model, optim, process_group=process_group)
reader = RemoteFileSystemReader(dir)
reader = RemoteFileSystemReader(dir, pre_download=pre_download)

if key_mapping is not None:
metadata = reader.read_metadata()
Expand Down Expand Up @@ -264,6 +270,7 @@ def unshard_checkpoint(
optim: Optional[bool] = None,
save_overwrite: bool = False,
use_safetensors: bool = False,
pre_download: bool = False,
) -> Tuple[Path, Optional[Path]]:
"""
Convert a checkpoint saved via :func:`save_model_and_optim_state()` into unsharded
Expand Down Expand Up @@ -331,7 +338,7 @@ def save(state_dict: Dict[str, Any], path: Path):
model_sd: Dict[str, Any] = {}
_load_state_dict(
model_sd,
storage_reader=RemoteFileSystemReader(dir),
storage_reader=RemoteFileSystemReader(dir, pre_download=pre_download),
planner=_EmptyStateDictLoadPlanner(keys=["model"]),
no_dist=True,
)
Expand All @@ -345,7 +352,7 @@ def save(state_dict: Dict[str, Any], path: Path):
optim_sd: Dict[str, Any] = {}
_load_state_dict(
optim_sd,
storage_reader=RemoteFileSystemReader(dir),
storage_reader=RemoteFileSystemReader(dir, pre_download=pre_download),
planner=_EmptyStateDictLoadPlanner(keys=["optim"]),
no_dist=True,
)
Expand Down
11 changes: 9 additions & 2 deletions src/olmo_core/distributed/checkpoint/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,18 +272,25 @@ class RemoteFileSystemReader(dist_cp.StorageReader):
that can read data directly from cloud storage as well as a local directory.
"""

def __init__(self, path: PathOrStr, *, thread_count: Optional[int] = None):
def __init__(
    self,
    path: PathOrStr,
    *,
    thread_count: Optional[int] = None,
    pre_download: bool = False,
):
    """
    :param path: Local directory or remote URL holding the checkpoint files.
    :param thread_count: Number of reader threads to use; must be at least 1
        when given, otherwise falls back to :func:`get_default_thread_count()`.
    :param pre_download: When ``True``, remote files are first resolved to a
        local copy (via ``resource_path``) before byte ranges are read from them.
    """
    super().__init__()
    if thread_count is not None and thread_count <= 0:
        raise ValueError("thread count must be at least 1")
    self.path = normalize_path(path)
    self.thread_count = thread_count or get_default_thread_count()
    self.pre_download = pre_download
    self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
    self.load_id = generate_uuid()
    self._metadata: Optional[Metadata] = None

def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
    """
    Read ``length`` bytes starting at ``offset`` from ``relative_path`` under
    ``self.path``. With ``pre_download`` enabled the file is first resolved
    through ``resource_path`` (fetching a local copy of the remote file) and
    the byte range is read from that local path instead of the remote URL.
    """
    target = (
        str(resource_path(self.path, relative_path))
        if self.pre_download
        else f"{self.path}/{relative_path}"
    )
    return get_bytes_range(target, offset, length)

def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
sinfo = self.storage_data[read_item.storage_index]
Expand Down
16 changes: 1 addition & 15 deletions src/olmo_core/train/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,21 +157,6 @@ def load(
"""
dir = normalize_path(dir)

if is_url(dir) and self.pre_download:
target = self.work_dir / "load" / os.path.basename(dir)
log.info(f"Pre-downloading checkpoint from '{dir}' to '{target}'...")
if get_fs_local_rank() == 0:
copy_dir(dir, target, save_overwrite=self.save_overwrite)
barrier(self.process_group)
return self.load(
target,
model,
optim,
load_optimizer_state=load_optimizer_state,
load_trainer_state=load_trainer_state,
key_mapping=key_mapping,
)

# Maybe load trainer state.
trainer_state: Optional[Dict[str, Any]] = None
if load_trainer_state:
Expand All @@ -193,6 +178,7 @@ def load(
optim if load_optimizer_state else None,
process_group=self.process_group,
key_mapping=key_mapping,
pre_download=is_url(dir) and self.pre_download,
)

return trainer_state
Expand Down

0 comments on commit aa4afb3

Please sign in to comment.