Skip to content

Commit

Permalink
Fix skip first batch being permanent (#1466)
Browse files Browse the repository at this point in the history
* Better version of fix

* Failing diff test

* Special str
Loading branch information
muellerzr authored May 22, 2023
1 parent bf3cd30 commit 7092089
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 9 deletions.
14 changes: 13 additions & 1 deletion docs/source/usage_guides/checkpoint.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,17 @@ train_dataloader = accelerator.prepare(train_dataloader)
accelerator.load_state("my_state")

# Assume the checkpoint was saved 100 steps into the epoch
accelerator.skip_first_batches(train_dataloader, 100)
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)

# After the first iteration, go back to `train_dataloader`

# First epoch
for batch in skipped_dataloader:
# Do something
pass

# Second epoch
for batch in train_dataloader:
# Do something
pass
```
7 changes: 5 additions & 2 deletions examples/by_feature/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,12 @@ def training_function(config, args):
# New Code #
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
overall_step += resume_step
for step, batch in enumerate(train_dataloader):
else:
# After the first iteration though, we need to go back to the original dataloader
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
outputs = model(**batch)
Expand Down
7 changes: 5 additions & 2 deletions examples/complete_cv_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,12 @@ def training_function(config, args):
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
overall_step += resume_step
for batch in train_dataloader:
else:
# After the first iteration though, we need to go back to the original dataloader
active_dataloader = train_dataloader
for batch in active_dataloader:
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
inputs = (batch["image"] - mean) / std
Expand Down
7 changes: 5 additions & 2 deletions examples/complete_nlp_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,12 @@ def collate_fn(examples):
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
overall_step += resume_step
for step, batch in enumerate(train_dataloader):
else:
# After the first iteration though, we need to go back to the original dataloader
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
outputs = model(**batch)
Expand Down
10 changes: 8 additions & 2 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2750,13 +2750,19 @@ def skip_first_batches(self, dataloader, num_batches: int = 0):
>>> accelerator = Accelerator()
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
>>> for input, target in accelerator.skip_first_batches(dataloader, num_batches=2):
>>> skipped_dataloader = accelerator.skip_first_batches(dataloader, num_batches=2)
>>> # for the first epoch only
>>> for input, target in skipped_dataloader:
... optimizer.zero_grad()
... output = model(input)
... loss = loss_func(output, target)
... accelerator.backward(loss)
... optimizer.step()
>>> # subsequent epochs
>>> for input, target in dataloader:
... optimizer.zero_grad()
... ...
```
"""
return skip_first_batches(dataloader, num_batches=num_batches)
1 change: 1 addition & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def test_cv_examples(self):
" " * 16 + "},\n\n",
" " * 16 + "step=epoch,\n",
" " * 12,
" " * 8 + "for step, batch in enumerate(active_dataloader):\n",
]
self.one_complete_example("complete_cv_example.py", True, cv_path, special_strings)
self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings)
Expand Down

0 comments on commit 7092089

Please sign in to comment.