fix resuming from checkpoint (#2001)
SumanthRH authored Sep 29, 2023
1 parent 80da9cf commit 658492f
Showing 1 changed file with 10 additions and 3 deletions.
examples/by_feature/deepspeed_with_config_support.py: 13 changes (10 additions & 3 deletions)
@@ -602,15 +602,22 @@ def group_texts(examples):
             resume_step -= starting_epoch * num_update_steps_per_epoch
             completed_steps = resume_step
 
+    # update progress bar if resumed from checkpoint
+    progress_bar.update(completed_steps)
+
     for epoch in range(starting_epoch, args.num_train_epochs):
         model.train()
         if args.with_tracking:
             total_loss = 0
 
         # skip new `skip_first_batches` to skip the batches when resuming from ckpt
-        if args.resume_from_checkpoint:
-            train_dataloader = accelerator.skip_first_batches(train_dataloader, num_batches=resume_step)
-        for step, batch in enumerate(train_dataloader):
+        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
+            # We need to skip steps until we reach the resumed step
+            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+        else:
+            # After the first iteration though, we need to go back to the original dataloader
+            active_dataloader = train_dataloader
+        for step, batch in enumerate(active_dataloader):
             # In particular, DeepSpeed handles `gradient_accumulation` via `DeepSpeedEngine`.
             # Below, we use `accelerator.accumulate` if the user
             # wants to switch to other approaches such as plain DDP, PyTorch FSDP ...
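The change limits batch skipping to the epoch actually being resumed: before it, `skip_first_batches` re-wrapped `train_dataloader` on every pass through the epoch loop whenever `--resume_from_checkpoint` was set, so later epochs also dropped their leading batches. A minimal sketch of the fixed pattern, runnable on CPU; the dataset, batch size, and resume values are hypothetical, and only the `accelerator.skip_first_batches` call and the `active_dataloader` logic come from the diff:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
# 80 samples in batches of 8 -> 10 batches per epoch (hypothetical toy data)
train_dataloader = accelerator.prepare(
    DataLoader(TensorDataset(torch.arange(80, dtype=torch.float32)), batch_size=8)
)

num_train_epochs = 3    # hypothetical
starting_epoch = 1      # epoch in which the checkpoint was saved (hypothetical)
resume_step = 4         # batches already consumed in that epoch (hypothetical)
resume_from_checkpoint = True

for epoch in range(starting_epoch, num_train_epochs):
    if resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
        # Only the resumed epoch skips the batches that were already trained on
        active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
    else:
        # Every later epoch goes back to the full dataloader
        active_dataloader = train_dataloader
    for step, batch in enumerate(active_dataloader):
        pass  # forward/backward/optimizer step would go here

With these numbers the resumed epoch yields 6 of its 10 batches and the following epoch yields all 10, whereas the removed code would have applied the skip again on every epoch.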
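The bookkeeping at the top of the hunk turns a global update step into a within-epoch offset, which is what `skip_first_batches` then consumes. A quick numeric check with hypothetical values; the first two assignments stand in for earlier lines of the script that sit above the shown hunk and are assumed here:

num_update_steps_per_epoch = 100   # hypothetical
resume_step = 250                  # global step restored from the checkpoint name (assumption)
starting_epoch = resume_step // num_update_steps_per_epoch    # 2 full epochs already finished (assumed preceding line)
resume_step -= starting_epoch * num_update_steps_per_epoch    # 50, as in the line shown in the hunk
# resume_step is now the number of batches to skip inside the resumed epoch
assert (starting_epoch, resume_step) == (2, 50)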
