Skip to content

Commit

Permalink
Fix dataloader skipping other batches
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed May 22, 2023
1 parent bfa74e5 commit 6d5d7b0
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def __iter__(self):
self.gradient_state._remove_dataloader(self)
if batch_index >= self.skip_batches:
yield current_batch
self.skip_batches = 0
break

@property
Expand Down Expand Up @@ -565,6 +566,7 @@ def __iter__(self):
if stop_iteration:
self.gradient_state._remove_dataloader(self)
self.gradient_state._set_remainder(observed_batch_size)
self.skip_batches = 0
if batch_index >= self.skip_batches:
yield batch
batch_index += 1
Expand Down Expand Up @@ -792,6 +794,7 @@ def __iter__(self):
for index, samples in enumerate(self.batch_sampler):
if index >= self.skip_batches:
yield samples
self.skip_batches = 0

@property
def total_length(self):
Expand Down

0 comments on commit 6d5d7b0

Please sign in to comment.