Fix skip first batch being permanent #1466

Merged
merged 3 commits into from
May 22, 2023
Merged
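Before this change, the docs and example scripts reassigned the original dataloader, e.g. `train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)`, so every epoch after the resumed one also skipped its first `resume_step` batches. The fix keeps the skipped view in a separate variable and falls back to the original dataloader from the next epoch on. A minimal sketch of the corrected pattern (the toy dataloader, epoch count, and resume point are placeholders, not from this PR):

```python
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
train_dataloader = accelerator.prepare(DataLoader(list(range(10)), batch_size=2))

num_epochs, starting_epoch, resume_step = 2, 0, 3  # placeholder resume point

for epoch in range(starting_epoch, num_epochs):
    if epoch == starting_epoch and resume_step is not None:
        # First epoch after resuming: skip the batches already seen.
        active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
    else:
        # Later epochs: iterate the full, original dataloader.
        active_dataloader = train_dataloader
    for batch in active_dataloader:
        pass  # training step goes here
```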
14 changes: 13 additions & 1 deletion docs/source/usage_guides/checkpoint.mdx
@@ -77,5 +77,17 @@ train_dataloader = accelerator.prepare(train_dataloader)
accelerator.load_state("my_state")

# Assume the checkpoint was saved 100 steps into the epoch
-accelerator.skip_first_batches(train_dataloader, 100)
+skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)
+
+# After the first iteration, go back to `train_dataloader`
+
+# First epoch
+for batch in skipped_dataloader:
+    # Do something
+    pass
+
+# Second epoch
+for batch in train_dataloader:
+    # Do something
+    pass
```
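The point of the updated snippet is that `skip_first_batches` returns a new dataloader and leaves the one passed in untouched, so the original remains correct for later epochs. A small sanity check of that behavior (a single-process run and a toy dataset are assumed here):

```python
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
train_dataloader = accelerator.prepare(DataLoader(list(range(8)), batch_size=1))

skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 3)
assert skipped_dataloader is not train_dataloader  # a separate wrapper object
assert len(list(skipped_dataloader)) == 5          # first 3 batches skipped
assert len(list(train_dataloader)) == 8            # original loader untouched
```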
7 changes: 5 additions & 2 deletions examples/by_feature/checkpointing.py
@@ -219,9 +219,12 @@ def training_function(config, args):
        # New Code #
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We need to skip steps until we reach the resumed step
-            train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
            overall_step += resume_step
-        for step, batch in enumerate(train_dataloader):
+        else:
+            # After the first iteration though, we need to go back to the original dataloader
+            active_dataloader = train_dataloader
+        for step, batch in enumerate(active_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch.to(accelerator.device)
            outputs = model(**batch)
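For orientation, `resume_step` and `starting_epoch` in these example scripts are recovered from the checkpoint before this loop runs; the sketch below shows one plausible version of that bookkeeping. The `step_350` name and the arithmetic are illustrative, not part of this diff:

```python
import os

# Hypothetical naming scheme: checkpoints saved as directories "step_<overall_step>"
checkpoint_name = os.path.basename("checkpoints/step_350")  # -> "step_350"
overall_step = int(checkpoint_name.replace("step_", ""))

num_batches_per_epoch = 100  # stand-in for len(train_dataloader)
starting_epoch = overall_step // num_batches_per_epoch               # 350 -> epoch 3
resume_step = overall_step - starting_epoch * num_batches_per_epoch  # -> step 50
```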
7 changes: 5 additions & 2 deletions examples/complete_cv_example.py
@@ -205,9 +205,12 @@ def training_function(config, args):
        total_loss = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We need to skip steps until we reach the resumed step
-            train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
            overall_step += resume_step
-        for batch in train_dataloader:
+        else:
+            # After the first iteration though, we need to go back to the original dataloader
+            active_dataloader = train_dataloader
+        for batch in active_dataloader:
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch = {k: v.to(accelerator.device) for k, v in batch.items()}
            inputs = (batch["image"] - mean) / std
7 changes: 5 additions & 2 deletions examples/complete_nlp_example.py
@@ -195,9 +195,12 @@ def collate_fn(examples):
        total_loss = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We need to skip steps until we reach the resumed step
-            train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
            overall_step += resume_step
-        for step, batch in enumerate(train_dataloader):
+        else:
+            # After the first iteration though, we need to go back to the original dataloader
+            active_dataloader = train_dataloader
+        for step, batch in enumerate(active_dataloader):
            # We could avoid this line since we set the accelerator with `device_placement=True`.
            batch.to(accelerator.device)
            outputs = model(**batch)
10 changes: 8 additions & 2 deletions src/accelerate/accelerator.py
@@ -2750,13 +2750,19 @@ def skip_first_batches(self, dataloader, num_batches: int = 0):

        >>> accelerator = Accelerator()
        >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)

-        >>> for input, target in accelerator.skip_first_batches(dataloader, num_batches=2):
+        >>> skipped_dataloader = accelerator.skip_first_batches(dataloader, num_batches=2)
+        >>> # for the first epoch only
+        >>> for input, target in skipped_dataloader:
        ...     optimizer.zero_grad()
        ...     output = model(input)
        ...     loss = loss_func(output, target)
        ...     accelerator.backward(loss)
        ...     optimizer.step()
+
+        >>> # subsequent epochs
+        >>> for input, target in dataloader:
+        ...     optimizer.zero_grad()
+        ...     ...
        ```
        """
        return skip_first_batches(dataloader, num_batches=num_batches)
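As the `return` statement above shows, the `Accelerator` method delegates to a module-level `skip_first_batches` helper, so the same behavior should be reachable without an `Accelerator` instance. Assuming the helper lives in `accelerate.data_loader` (where it is defined in this codebase), usage would look like:

```python
from accelerate.data_loader import skip_first_batches
from torch.utils.data import DataLoader

dataloader = DataLoader(list(range(10)), batch_size=2)
skipped_dataloader = skip_first_batches(dataloader, num_batches=2)  # same semantics as the method
```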
1 change: 1 addition & 0 deletions tests/test_examples.py
@@ -118,6 +118,7 @@ def test_cv_examples(self):
" " * 16 + "},\n\n",
" " * 16 + "step=epoch,\n",
" " * 12,
" " * 8 + "for step, batch in enumerate(active_dataloader):\n",
]
self.one_complete_example("complete_cv_example.py", True, cv_path, special_strings)
self.one_complete_example("complete_cv_example.py", False, cv_path, special_strings)