remove orig_params check (#2981)
* remove orig_params check

* expect cpu-cpu test to fail

* expect cpu-cpu test to fail

* fix formatting

* Update test_checkpoint.py

* expect True,True to succeed?

* remove commented code

* rerun tests

* how about now

* fix tests

* Update test_checkpoint.py

* WIP debugging

* wip debug

* Update test_checkpoint.py

* Update test_checkpoint.py

* Update test_checkpoint.py

* Update test_checkpoint.py

* Update test_checkpoint.py

* Update test_checkpoint.py

* WIP

* counterexample is hanging guh

* have repro

* less composer repro

* found bug

* make counterexample match composer wrapper model situation

* still doesn't work

* update counterexample with printed models to show they're wrapped the same

* trim unnecessary stuff

* counterexample for the record

* simplified counterexample

* fix tests

* fix quality and delete counterexample

* fix

* asdict

* remove self (lol if only)

* fix test

* fix fix

* fix

* reset save folder changes

* remove meta for now

* replace load_monolith_ with load_fsdp_monolith_

* change load_fsdp_monolith_rank0_only to load_monolith_rank0_only

* resolve conflicts?

* expect failure when no sync module states

* fix constructor()

* WIP

* WIP

* Update test_fsdp_checkpoint.py

* make PR ready for review

---------

Co-authored-by: Mihir Patel <[email protected]>
milocress and mvpatel2000 authored May 16, 2024
1 parent c854be4 · commit 435c295
Showing 3 changed files with 83 additions and 108 deletions.
5 changes: 0 additions & 5 deletions composer/core/state.py
@@ -474,11 +474,6 @@ def __init__(
if self.load_monolith_rank0_only:
assert fsdp_config is not None
error_message = ''
if fsdp_config['use_orig_params'] == True:
error_message += textwrap.dedent(
"load_monolith_rank0_only requires fsdp_config['use_orig_params'] to be False. "
"Either set fsdp_config['use_orig_params'] = False or set load_monolith_rank0_only = False. ",
)
if fsdp_config['sync_module_states'] == False:
error_message += textwrap.dedent(
"load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
102 changes: 1 addition & 101 deletions tests/trainer/test_checkpoint.py
@@ -1511,106 +1511,6 @@ def test_set_dataloaders_to_cur_epoch(
# Epoch count starts at O
assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1

@pytest.mark.parametrize(
'world_size',
[
pytest.param(2, marks=pytest.mark.world_size(2)),
],
)
@pytest.mark.parametrize(
'device',
[
pytest.param('gpu', marks=pytest.mark.gpu),
],
)
@pytest.mark.parametrize(
'use_orig_params,sync_module_states,model_1_init_device,model_2_init_device',
[
pytest.param(False, True, 'cpu', 'cpu'), # success
pytest.param(False, True, 'cpu', 'meta'), # success
pytest.param(True, True, 'cpu', 'cpu'), # fail
pytest.param(False, False, 'cpu', 'cpu'), # fail
pytest.param(False, True, 'meta', 'cpu'), # fail
],
)
@pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*')
@pytest.mark.filterwarnings(
'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*',
)
def test_fsdp_monolith_resumption(
self,
device: str,
world_size: int,
use_orig_params: bool,
sync_module_states: bool,
model_1_init_device: str,
model_2_init_device: str,
tmp_path: pathlib.Path,
):
save_interval = '1ba'
save_filename = 'ba{batch}-rank{rank}.pt'
resume_file = 'ba1-rank{rank}.pt'
final_checkpoint = 'latest-rank{rank}.pt'
fsdp_config = {
'use_orig_params': use_orig_params,
'sync_module_states': sync_module_states,
'state_dict_type': 'full',
}

# All ranks use rank 0 folder
tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
save_folder = pathlib.Path(tmp_paths[0])

trainer_1 = self.get_trainer(
save_folder=os.path.join(save_folder, 'first'),
save_filename=save_filename,
save_interval=save_interval,
eval_interval=save_interval,
fsdp_config=fsdp_config,
device=device,
precision='amp_fp16',
max_duration='1ep',
train_subset_num_batches=2,
)

trainer_1.fit()
trainer_1.close()

self._assert_expected_num_checkpoints(
save_folder=os.path.join(save_folder, 'first'),
save_interval=save_interval,
num_epochs=1, # set in get_trainer()
num_batches_per_epoch=2, # set in get_trainer()
is_deepspeed=False,
)

resume_file = os.path.join(save_folder, 'first', resume_file)
model_init_device = [model_1_init_device, model_2_init_device][dist.get_global_rank()]
fsdp_config['load_monolith_rank0_only'] = True

success = use_orig_params == False and sync_module_states == True and model_1_init_device == 'cpu'
with contextlib.nullcontext() if success else pytest.raises(ValueError):
trainer_2 = self.get_trainer(
model_init_device=model_init_device,
save_folder=os.path.join(save_folder, 'second'),
save_filename=save_filename,
save_interval=save_interval,
eval_interval=save_interval,
fsdp_config=fsdp_config,
device=device,
precision='amp_fp16',
max_duration='1ep',
train_subset_num_batches=2,
load_path=resume_file, # <-- resume training from file
)
trainer_2.fit()
trainer_2.close()

_assert_checkpoints_equivalent(
save_folder / 'first' / final_checkpoint,
save_folder / 'second' / final_checkpoint,
)

@pytest.mark.parametrize('spin_dataloaders', [False, True])
def test_spin_dataloaders(
self,
@@ -1674,8 +1574,8 @@ def test_format_load_path(self, tmp_path: pathlib.Path):
os.path.join(save_folder, 'second', 'latest-rank{rank}.pt'),
)

@staticmethod
def _assert_expected_num_checkpoints(
self,
save_folder: str,
save_interval: str,
num_epochs: int,
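The single addition in this file, @staticmethod, is what lets the relocated FSDP test call the helper without constructing a TestCheckpointResumption instance. The new test in tests/trainer/test_fsdp_checkpoint.py (added below) invokes it roughly like this; the save_folder value here is a placeholder:

TestCheckpointResumption._assert_expected_num_checkpoints(
    save_folder='/tmp/first',        # placeholder path
    save_interval='1ba',
    num_epochs=1,                    # set in get_trainer()
    num_batches_per_epoch=8,         # set in get_trainer()
    is_deepspeed=False,
)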
84 changes: 82 additions & 2 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -37,6 +37,7 @@
from tests.common import RandomClassificationDataset, deep_compare
from tests.common.compare import deep_compare
from tests.common.markers import world_size
from tests.trainer.test_checkpoint import TestCheckpointResumption, _assert_checkpoints_equivalent


# This model is to be used explicitly for this unit test because some old reference checkpoints
@@ -120,7 +121,7 @@ def get_trainer(
train_metrics=train_metrics,
val_metrics=val_metrics,
)
model.to(model_init_device)
model.module.to(model_init_device)
dataset = RandomClassificationDataset(shape=(num_features,), size=128)
dataloader = DataLoader(
dataset,
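One behavioral tweak in get_trainer: the init device is now applied to model.module instead of the whole model object. A hedged reading, since the diff does not state the motivation, is that only the inner torch module should land on the requested device (notably 'meta'), leaving other attributes of the wrapper, such as the metric objects constructed just above, where they are:

# Assumption: `model` wraps a torch.nn.Module at `.module` alongside train/val
# metric objects; moving only the inner module keeps a 'meta' init device from
# also claiming the metric state.
model.module.to(model_init_device)  # previously: model.to(model_init_device)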
@@ -325,7 +326,6 @@ def test_fsdp_full_state_dict_load(
autoresume=autoresume,
optimizer=optimizer,
fsdp_config=fsdp_config,
save_weights_only=save_weights_only,
)
trainer1.fit()
state_dict_from_trainer1 = trainer1.state.state_dict()
@@ -1127,3 +1127,83 @@ def set_up_planner(

trainer2.fit()
trainer2.close()


@pytest.mark.gpu
@world_size(2)
@pytest.mark.parametrize('use_orig_params', [True, False])
@pytest.mark.parametrize('sync_module_states', [True, False])
@pytest.mark.parametrize('model_1_init_device', ['cpu', 'meta'])
@pytest.mark.parametrize('model_2_init_device', ['cpu', 'meta'])
@pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*')
@pytest.mark.filterwarnings(
'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*',
)
def test_fsdp_monolith_resumption(
world_size: int,
use_orig_params: bool,
sync_module_states: bool,
tmp_path: pathlib.Path,
model_1_init_device: str,
model_2_init_device: str,
):
save_interval = '1ba'
save_filename = 'ba{batch}-rank{rank}.pt'
resume_file = 'ba1-rank{rank}.pt'
final_checkpoint = 'latest-rank{rank}.pt'
fsdp_config = FSDPConfig(
use_orig_params=use_orig_params,
sync_module_states=sync_module_states,
state_dict_type='full',
)

# All ranks use rank 0 folder
tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path))
save_folder = pathlib.Path(tmp_paths[0])

trainer_1 = get_trainer(
save_folder=os.path.join(save_folder, 'first'),
save_filename=save_filename,
save_interval=save_interval,
fsdp_config=fsdp_config,
precision='amp_fp16',
max_duration='1ep',
)

trainer_1.fit()
trainer_1.close()

TestCheckpointResumption._assert_expected_num_checkpoints(
save_folder=os.path.join(save_folder, 'first'),
save_interval=save_interval,
num_epochs=1, # set in get_trainer()
num_batches_per_epoch=8, # set in get_trainer()
is_deepspeed=False,
)

resume_file = os.path.join(save_folder, 'first', resume_file)
model_init_device = [model_1_init_device, model_2_init_device][dist.get_global_rank()]
fsdp_config_dict = dataclasses.asdict(fsdp_config)
fsdp_config_dict['load_monolith_rank0_only'] = True
fsdp_config = FSDPConfig(**fsdp_config_dict)

success = (sync_module_states == True and model_1_init_device == 'cpu')

with (does_not_raise() if success else pytest.raises(ValueError)):
trainer_2 = get_trainer(
model_init_device=model_init_device,
save_folder=os.path.join(save_folder, 'second'),
save_filename=save_filename,
save_interval=save_interval,
fsdp_config=fsdp_config,
precision='amp_fp16',
max_duration='1ep',
load_path=resume_file, # <-- resume training from file
)
trainer_2.fit()
trainer_2.close()

_assert_checkpoints_equivalent(
save_folder / 'first' / final_checkpoint,
save_folder / 'second' / final_checkpoint,
)
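Net effect of the commit: the combination that the deleted check in composer/core/state.py used to reject is now expected to validate and resume. A sketch using the same FSDPConfig helper as the test above; per the test's success condition, sync_module_states must stay True and rank 0's model must be initialized on 'cpu':

fsdp_config = FSDPConfig(
    use_orig_params=True,            # no longer forced to False for monolith loads
    sync_module_states=True,         # still required
    state_dict_type='full',
    load_monolith_rank0_only=True,   # monolithic rank-0 resumption
)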
