-
Notifications
You must be signed in to change notification settings - Fork 422
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ckpt-rewr] Save state dict API (#3372)
- Loading branch information
1 parent
1dfd3bc
commit f7e17de
Showing
6 changed files
with
377 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright 2024 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Useful functions for saving state dicts to disk.""" | ||
|
||
import logging | ||
import os | ||
import textwrap | ||
import warnings | ||
from pathlib import Path | ||
from typing import Any, Dict, Optional, Union | ||
|
||
import torch | ||
import torch.distributed.checkpoint as DCP | ||
from packaging import version | ||
from torch.distributed._shard.sharded_tensor import ShardedTensor | ||
from torch.distributed._tensor import DTensor | ||
|
||
from composer.utils import dist | ||
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def save_state_dict_to_disk( | ||
state_dict: Dict[str, Any], | ||
destination_file_path: str, | ||
overwrite: bool = False, | ||
save_format: str = 'pt', # or hf, safetensor | ||
) -> Optional[str]: | ||
"""Saves a state dict to local disk. | ||
Args: | ||
state_dict (Dict[str,Any]): The state dict to save. | ||
destination_file_path (str): The path to save the state dict to. If sharded, | ||
this should be the pth to a directory. Otherwise, it should be a path to a file. | ||
overwrite (bool): If True, the file will be overwritten if it exists. | ||
save_format (str): The format to save the state dict in. One of 'pt', 'hf', or 'safetensor'. | ||
Returns: | ||
str: The full path to the saved state dict if (sharded is false and rank 0) or if sharded is true, otherwise None. | ||
""" | ||
if state_dict == {}: | ||
return None | ||
if is_state_dict_sharded(state_dict): | ||
path_saved = _save_sharded_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format) | ||
else: | ||
if dist.get_global_rank() == 0: | ||
path_saved = _save_full_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format) | ||
else: | ||
path_saved = None | ||
|
||
return path_saved | ||
|
||
|
||
def _save_sharded_state_dict_to_disk( | ||
state_dict: Dict[str, Any], | ||
destination_file_path: str, | ||
overwrite: bool = False, | ||
save_format: str = 'pt', | ||
) -> Optional[str]: | ||
|
||
if save_format != 'pt': | ||
raise NotImplementedError( | ||
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].", | ||
) | ||
|
||
if state_dict == {}: | ||
return None | ||
|
||
# If user specifies filename instead of directory suffixes, strip them and warn | ||
if len(Path(destination_file_path).suffixes) > 0: | ||
stripped_path = _strip_suffixes(destination_file_path) | ||
warnings.warn( | ||
textwrap.dedent( | ||
f"""Sharded checkpoints require a directory path not a file path: | ||
{destination_file_path} will have its extensions stripped and checkpoints will be saved in {stripped_path} | ||
as {stripped_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}""", | ||
), | ||
) | ||
destination_file_path = stripped_path | ||
|
||
if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path): | ||
raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.') | ||
|
||
log.debug( | ||
f'Starting saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}', | ||
) | ||
|
||
# For 2.3.0 and above you can use checkpoint_id, but this version works the best for all versions | ||
# of torch (and makes pyright happier) that we support, so we use it for now. | ||
if version.parse(torch.__version__) < version.parse('2.2.0'): | ||
DCP.save_state_dict(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path)) | ||
else: | ||
DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path)) | ||
|
||
return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME | ||
|
||
|
||
def _save_full_state_dict_to_disk( | ||
state_dict: Dict[str, Any], | ||
destination_file_path: str, | ||
overwrite: bool = False, | ||
save_format: str = 'pt', # or hf, safetensor | ||
) -> Optional[str]: | ||
|
||
if save_format != 'pt': | ||
raise NotImplementedError( | ||
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].", | ||
) | ||
|
||
if not overwrite and os.path.exists(destination_file_path): | ||
raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.') | ||
|
||
if dist.get_global_rank() == 0: | ||
_write_checkpoint_file(state_dict=state_dict, filename=destination_file_path) | ||
return destination_file_path | ||
return None | ||
|
||
|
||
def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool: | ||
"""Determines if the state dict is sharded. | ||
Args: | ||
state_dict (Dict[str, Any]): The state dict to check. | ||
Returns: | ||
bool: Whether the state dict is sharded. | ||
""" | ||
for value in state_dict.values(): | ||
if isinstance(value, ShardedTensor) or isinstance(value, DTensor): | ||
return True | ||
if isinstance(value, Dict): | ||
is_sharded = is_state_dict_sharded(value) | ||
if is_sharded: | ||
return True | ||
return False | ||
|
||
|
||
def _strip_suffixes(path: Union[str, Path]) -> str: | ||
path = Path(path) | ||
for _ in path.suffixes: | ||
path = path.with_suffix('') | ||
|
||
return str(path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright 2024 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Any, Dict | ||
|
||
import torch | ||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
from torch.distributed.fsdp.api import CPUOffload | ||
from torch.optim import adam | ||
|
||
from tests.common.models import EvenSimplerMLP, SimpleComposerMLP | ||
|
||
__all__ = [ | ||
'init_model_and_optimizer', | ||
'init_model', | ||
'init_optimizer', | ||
] | ||
|
||
|
||
def init_model_and_optimizer( | ||
use_composer_model: bool, | ||
num_classes=3, | ||
batch_size=5, | ||
num_features=8, | ||
take_step=True, | ||
use_fsdp=False, | ||
tensor_type='sharded_tensor', | ||
device='cuda', | ||
): | ||
model, loss_fn = init_model( | ||
use_composer_model, | ||
num_classes=num_classes, | ||
num_features=num_features, | ||
use_fsdp=use_fsdp, | ||
tensor_type=tensor_type, | ||
device=device, | ||
) | ||
|
||
optimizer = init_optimizer( | ||
model, | ||
loss_fn, | ||
use_composer_model=use_composer_model, | ||
num_classes=num_classes, | ||
batch_size=batch_size, | ||
num_features=num_features, | ||
take_step=take_step, | ||
device=device, | ||
) | ||
|
||
return model, optimizer | ||
|
||
|
||
def init_model( | ||
use_composer_model: bool = False, | ||
num_classes=3, | ||
num_features=8, | ||
use_fsdp=False, | ||
device='cuda', | ||
tensor_type='sharded_tensor', | ||
sync_module_states=True, | ||
cpu_offload=False, | ||
): | ||
if use_composer_model: | ||
model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device=device) | ||
loss_fn = model._loss_fn | ||
else: | ||
model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device=device) | ||
loss_fn = torch.nn.CrossEntropyLoss() | ||
|
||
if use_fsdp: | ||
fsdp_kwargs: Dict[str, Any] = dict( | ||
use_orig_params=True, | ||
sync_module_states=sync_module_states, # To enable easy comparison between rank 0 unsharded model and full state dict | ||
cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None, | ||
device_id=torch.device('cpu') if device == 'cpu' else None, | ||
) | ||
|
||
if tensor_type == 'dtensor': | ||
from torch.distributed.device_mesh import init_device_mesh | ||
device_mesh = init_device_mesh('cuda', (2,)) | ||
fsdp_kwargs['device_mesh'] = device_mesh | ||
|
||
model = FSDP( | ||
model, | ||
**fsdp_kwargs, | ||
) | ||
|
||
return model, loss_fn | ||
|
||
|
||
def init_optimizer( | ||
model, | ||
loss_fn, | ||
use_composer_model: bool = False, | ||
num_classes=3, | ||
batch_size=5, | ||
num_features=8, | ||
take_step=True, | ||
device='cuda', | ||
): | ||
inputs = torch.randn(batch_size, num_features, device=device) | ||
targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device, dtype=torch.long) | ||
batch = (inputs, targets) if use_composer_model else inputs | ||
optimizer = adam.Adam(model.parameters()) | ||
outputs = model(batch) | ||
loss = loss_fn(outputs, targets) | ||
loss.backward() | ||
if take_step: | ||
optimizer.step() | ||
return optimizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Copyright 2024 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import time | ||
import uuid | ||
from copy import deepcopy | ||
from pathlib import Path | ||
|
||
import pytest | ||
import torch | ||
import torch.distributed.checkpoint as DCP | ||
from packaging import version | ||
|
||
from composer.checkpoint.save import save_state_dict_to_disk | ||
from composer.checkpoint.state_dict import get_model_state_dict | ||
from composer.utils import dist | ||
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME | ||
from tests.checkpoint.helpers import init_model | ||
from tests.common.compare import deep_compare | ||
from tests.common.markers import world_size | ||
|
||
|
||
@world_size(1, 2) | ||
@pytest.mark.gpu | ||
@pytest.mark.parametrize('sharded_model', [False, True]) | ||
def test_save_full_state_dict_to_disk(world_size: int, tmp_path: str, sharded_model: bool): | ||
if world_size == 1 and sharded_model: | ||
pytest.skip("Can't have a sharded model for world_size = 1") | ||
destination_file_path = os.path.join(tmp_path, 'test.pt') | ||
use_fsdp = sharded_model | ||
model, _ = init_model(use_fsdp=use_fsdp, device='cuda', sync_module_states=True) | ||
|
||
state_dict = get_model_state_dict(model, sharded_state_dict=False) | ||
path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path) | ||
time.sleep(1) | ||
if dist.get_global_rank() == 0: | ||
assert path_saved is not None | ||
assert path_saved == destination_file_path | ||
assert os.path.exists(destination_file_path), f'{destination_file_path} does not exist' | ||
loaded_state_dict = torch.load(path_saved, map_location='cuda') | ||
deep_compare(state_dict, loaded_state_dict) | ||
else: | ||
assert path_saved is None | ||
|
||
|
||
@world_size(2) | ||
@pytest.mark.gpu | ||
@pytest.mark.parametrize( | ||
'tensor_type', | ||
[ | ||
'sharded_tensor', | ||
pytest.param( | ||
'dtensor', | ||
marks=pytest.mark.skipif( | ||
version.parse(torch.__version__) < version.parse('2.2.0'), | ||
reason='Requires torch>=2.2.0 for dtensor', | ||
), | ||
), | ||
], | ||
) | ||
def test_save_sharded_state_dict_to_disk(world_size: int, tmp_path: str, tensor_type: str): | ||
|
||
destination_file_path = os.path.join(tmp_path, str(uuid.uuid4())[:8]) | ||
# Sync the path across all ranks | ||
destination_file_path = dist.all_gather_object(destination_file_path)[0] | ||
model, _ = init_model(use_fsdp=True, device='cuda', tensor_type=tensor_type) | ||
|
||
state_dict = get_model_state_dict(model, sharded_state_dict=True) | ||
loaded_in_state_dict = deepcopy(state_dict) | ||
path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path, overwrite=True) | ||
assert path_saved == f'{destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}' | ||
assert path_saved is not None | ||
load_path = str(Path(path_saved).parent) | ||
if version.parse(torch.__version__) < version.parse('2.2.0'): | ||
DCP.load_state_dict(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path)) | ||
else: | ||
DCP.load(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path)) | ||
deep_compare(state_dict, loaded_in_state_dict) |
Oops, something went wrong.