diff --git a/composer/core/state.py b/composer/core/state.py index 67fb7e493e..4506bdeec9 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -474,11 +474,6 @@ def __init__( if self.load_monolith_rank0_only: assert fsdp_config is not None error_message = '' - if fsdp_config['use_orig_params'] == True: - error_message += textwrap.dedent( - "load_monolith_rank0_only requires fsdp_config['use_orig_params'] to be False. " - "Either set fsdp_config['use_orig_params'] = False or set load_monolith_rank0_only = False. ", - ) if fsdp_config['sync_module_states'] == False: error_message += textwrap.dedent( "load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. " diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index c4820e2658..b903134ab5 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -1511,106 +1511,6 @@ def test_set_dataloaders_to_cur_epoch( # Epoch count starts at O assert trainer.state.train_dataloader.batch_sampler.epoch == max_duration - 1 - @pytest.mark.parametrize( - 'world_size', - [ - pytest.param(2, marks=pytest.mark.world_size(2)), - ], - ) - @pytest.mark.parametrize( - 'device', - [ - pytest.param('gpu', marks=pytest.mark.gpu), - ], - ) - @pytest.mark.parametrize( - 'use_orig_params,sync_module_states,model_1_init_device,model_2_init_device', - [ - pytest.param(False, True, 'cpu', 'cpu'), # success - pytest.param(False, True, 'cpu', 'meta'), # success - pytest.param(True, True, 'cpu', 'cpu'), # fail - pytest.param(False, False, 'cpu', 'cpu'), # fail - pytest.param(False, True, 'meta', 'cpu'), # fail - ], - ) - @pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*') - @pytest.mark.filterwarnings( - 'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*', - ) - def test_fsdp_monolith_resumption( - self, - device: str, - world_size: int, - use_orig_params: bool, - sync_module_states: bool, - model_1_init_device: str, - model_2_init_device: str, - tmp_path: pathlib.Path, - ): - save_interval = '1ba' - save_filename = 'ba{batch}-rank{rank}.pt' - resume_file = 'ba1-rank{rank}.pt' - final_checkpoint = 'latest-rank{rank}.pt' - fsdp_config = { - 'use_orig_params': use_orig_params, - 'sync_module_states': sync_module_states, - 'state_dict_type': 'full', - } - - # All ranks use rank 0 folder - tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path)) - save_folder = pathlib.Path(tmp_paths[0]) - - trainer_1 = self.get_trainer( - save_folder=os.path.join(save_folder, 'first'), - save_filename=save_filename, - save_interval=save_interval, - eval_interval=save_interval, - fsdp_config=fsdp_config, - device=device, - precision='amp_fp16', - max_duration='1ep', - train_subset_num_batches=2, - ) - - trainer_1.fit() - trainer_1.close() - - self._assert_expected_num_checkpoints( - save_folder=os.path.join(save_folder, 'first'), - save_interval=save_interval, - num_epochs=1, # set in get_trainer() - num_batches_per_epoch=2, # set in get_trainer() - is_deepspeed=False, - ) - - resume_file = os.path.join(save_folder, 'first', resume_file) - model_init_device = [model_1_init_device, model_2_init_device][dist.get_global_rank()] - fsdp_config['load_monolith_rank0_only'] = True - - success = use_orig_params == False and sync_module_states == True and model_1_init_device == 'cpu' - with contextlib.nullcontext() if success else pytest.raises(ValueError): - trainer_2 = self.get_trainer( - model_init_device=model_init_device, - save_folder=os.path.join(save_folder, 'second'), - save_filename=save_filename, - save_interval=save_interval, - eval_interval=save_interval, - fsdp_config=fsdp_config, - device=device, - precision='amp_fp16', - max_duration='1ep', - train_subset_num_batches=2, - load_path=resume_file, # <-- resume training from file - ) - trainer_2.fit() - trainer_2.close() - - _assert_checkpoints_equivalent( - save_folder / 'first' / final_checkpoint, - save_folder / 'second' / final_checkpoint, - ) - @pytest.mark.parametrize('spin_dataloaders', [False, True]) def test_spin_dataloaders( self, @@ -1674,8 +1574,8 @@ def test_format_load_path(self, tmp_path: pathlib.Path): os.path.join(save_folder, 'second', 'latest-rank{rank}.pt'), ) + @staticmethod def _assert_expected_num_checkpoints( - self, save_folder: str, save_interval: str, num_epochs: int, diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 9457032e81..93c60d8e97 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -37,6 +37,7 @@ from tests.common import RandomClassificationDataset, deep_compare from tests.common.compare import deep_compare from tests.common.markers import world_size +from tests.trainer.test_checkpoint import TestCheckpointResumption, _assert_checkpoints_equivalent # This model is to be used explicitly for this unit test because some old reference checkpoints @@ -120,7 +121,7 @@ def get_trainer( train_metrics=train_metrics, val_metrics=val_metrics, ) - model.to(model_init_device) + model.module.to(model_init_device) dataset = RandomClassificationDataset(shape=(num_features,), size=128) dataloader = DataLoader( dataset, @@ -325,7 +326,6 @@ def test_fsdp_full_state_dict_load( autoresume=autoresume, optimizer=optimizer, fsdp_config=fsdp_config, - save_weights_only=save_weights_only, ) trainer1.fit() state_dict_from_trainer1 = trainer1.state.state_dict() @@ -1127,3 +1127,83 @@ def set_up_planner( trainer2.fit() trainer2.close() + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('use_orig_params', [True, False]) +@pytest.mark.parametrize('sync_module_states', [True, False]) +@pytest.mark.parametrize('model_1_init_device', ['cpu', 'meta']) +@pytest.mark.parametrize('model_2_init_device', ['cpu', 'meta']) +@pytest.mark.filterwarnings('ignore:An unexpected prefix is detected. This case.*') +@pytest.mark.filterwarnings( + 'ignore:``FullyShardedDataParallel.scatter_full_optim_state_dict``is being deprecated and is replaced by.*', +) +def test_fsdp_monolith_resumption( + world_size: int, + use_orig_params: bool, + sync_module_states: bool, + tmp_path: pathlib.Path, + model_1_init_device: str, + model_2_init_device: str, +): + save_interval = '1ba' + save_filename = 'ba{batch}-rank{rank}.pt' + resume_file = 'ba1-rank{rank}.pt' + final_checkpoint = 'latest-rank{rank}.pt' + fsdp_config = FSDPConfig( + use_orig_params=use_orig_params, + sync_module_states=sync_module_states, + state_dict_type='full', + ) + + # All ranks use rank 0 folder + tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path)) + save_folder = pathlib.Path(tmp_paths[0]) + + trainer_1 = get_trainer( + save_folder=os.path.join(save_folder, 'first'), + save_filename=save_filename, + save_interval=save_interval, + fsdp_config=fsdp_config, + precision='amp_fp16', + max_duration='1ep', + ) + + trainer_1.fit() + trainer_1.close() + + TestCheckpointResumption._assert_expected_num_checkpoints( + save_folder=os.path.join(save_folder, 'first'), + save_interval=save_interval, + num_epochs=1, # set in get_trainer() + num_batches_per_epoch=8, # set in get_trainer() + is_deepspeed=False, + ) + + resume_file = os.path.join(save_folder, 'first', resume_file) + model_init_device = [model_1_init_device, model_2_init_device][dist.get_global_rank()] + fsdp_config_dict = dataclasses.asdict(fsdp_config) + fsdp_config_dict['load_monolith_rank0_only'] = True + fsdp_config = FSDPConfig(**fsdp_config_dict) + + success = (sync_module_states == True and model_1_init_device == 'cpu') + + with (does_not_raise() if success else pytest.raises(ValueError)): + trainer_2 = get_trainer( + model_init_device=model_init_device, + save_folder=os.path.join(save_folder, 'second'), + save_filename=save_filename, + save_interval=save_interval, + fsdp_config=fsdp_config, + precision='amp_fp16', + max_duration='1ep', + load_path=resume_file, # <-- resume training from file + ) + trainer_2.fit() + trainer_2.close() + + _assert_checkpoints_equivalent( + save_folder / 'first' / final_checkpoint, + save_folder / 'second' / final_checkpoint, + )