Skip to content

Commit

Permalink
Avoid duplicated checkpoint save (#7555) (#7566)
Browse files Browse the repository at this point in the history
Signed-off-by: Mikołaj Błaż <[email protected]>
Co-authored-by: mikolajblaz <[email protected]>
  • Loading branch information
github-actions[bot] and mikolajblaz authored Oct 11, 2023
1 parent 67dc816 commit 7e5bce4
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ def save_checkpoint(
checkpoint_dir = ckpt_to_dir(filepath)

fs = get_filesystem(checkpoint_dir)
if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
return

if is_global_rank_zero():
fs.makedirs(checkpoint_dir, exist_ok=True)

Expand Down Expand Up @@ -477,19 +481,24 @@ def save_to(self, model, save_path: str):
# model weights is a directory
dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
fs = get_filesystem(dist_ckpt_dir)
if is_global_rank_zero():
fs.makedirs(dist_ckpt_dir, exist_ok=True)
sharded_state_dict = model.sharded_state_dict()
# dist checkpoint needs torch.distributed to save the checkpoint
if parallel_state.is_unitialized():

def dummy():
return

if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=dist_ckpt_dir)

if fs.isdir(dist_ckpt_dir) and dist_checkpointing.check_is_distributed_checkpoint(dist_ckpt_dir):
logging.info(f'Distributed checkpoint at path {dist_ckpt_dir} already exists, skipping saving')
else:
if is_global_rank_zero():
fs.makedirs(dist_ckpt_dir, exist_ok=True)

sharded_state_dict = model.sharded_state_dict()
# dist checkpoint needs torch.distributed to save the checkpoint
if parallel_state.is_unitialized():

def dummy():
return

if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=dist_ckpt_dir)

else:

Expand Down

0 comments on commit 7e5bce4

Please sign in to comment.