# Patch summary (commentary only; it precedes the first "diff --git" header and
# is not part of any hunk).
#
# src/accelerate/data_loader.py
#   One line is added to an `__iter__` that yields only the batches whose index
#   is >= self.skip_batches. The new line sets self.skip_batches to 0.
#
# tests/test_data_loader.py
#   test_skip_batch_sampler, test_skip_data_loader and test_skip_first_batches
#   each gain a second assertion: iterating the same object again must produce
#   all four batches ([0..3], [4..7], [8..11], [12..15]) instead of only the
#   last two.
#
# NOTE(review): line breaks and indentation below were reconstructed from the
#   hunk line counts (-825,6 +825,7 and -365,12 +365,19) and from Python syntax.
#   The added `self.skip_batches = 0` is presumed to sit at method-body level,
#   i.e. to run after the loop finishes - confirm against the upstream file.
# NOTE(review): the only source hunk touches a single `__iter__`; whether that
#   alone makes all three new second-pass assertions hold cannot be determined
#   from this patch - verify against the rest of data_loader.py.
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index d8acf69539c..2e0cc71b2e0 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -825,6 +825,7 @@ def __iter__(self):
         for index, batch in enumerate(super().__iter__()):
             if index >= self.skip_batches:
                 yield batch
+        self.skip_batches = 0
 
 
 def skip_first_batches(dataloader, num_batches=0):
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py
index 70600ea17ef..079423f50ca 100644
--- a/tests/test_data_loader.py
+++ b/tests/test_data_loader.py
@@ -365,12 +365,19 @@ def test_skip_batch_sampler(self):
         batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)
         new_batch_sampler = SkipBatchSampler(batch_sampler, 2)
         self.assertListEqual(list(new_batch_sampler), [[8, 9, 10, 11], [12, 13, 14, 15]])
+        self.assertListEqual(list(new_batch_sampler), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]])
 
     def test_skip_data_loader(self):
         dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2)
         self.assertListEqual([t.tolist() for t in dataloader], [[8, 9, 10, 11], [12, 13, 14, 15]])
+        self.assertListEqual(
+            [t.tolist() for t in dataloader], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+        )
 
     def test_skip_first_batches(self):
         dataloader = DataLoader(list(range(16)), batch_size=4)
         new_dataloader = skip_first_batches(dataloader, num_batches=2)
         self.assertListEqual([t.tolist() for t in new_dataloader], [[8, 9, 10, 11], [12, 13, 14, 15]])
+        self.assertListEqual(
+            [t.tolist() for t in new_dataloader], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+        )