Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed May 22, 2023
1 parent 6d5d7b0 commit 5dad64f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ def __iter__(self):
for index, batch in enumerate(super().__iter__()):
if index >= self.skip_batches:
yield batch
self.skip_batches = 0


def skip_first_batches(dataloader, num_batches=0):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,19 @@ def test_skip_batch_sampler(self):
batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)
new_batch_sampler = SkipBatchSampler(batch_sampler, 2)
self.assertListEqual(list(new_batch_sampler), [[8, 9, 10, 11], [12, 13, 14, 15]])
self.assertListEqual(list(new_batch_sampler), [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]])

def test_skip_data_loader(self):
dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2)
self.assertListEqual([t.tolist() for t in dataloader], [[8, 9, 10, 11], [12, 13, 14, 15]])
self.assertListEqual(
[t.tolist() for t in dataloader], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
)

def test_skip_first_batches(self):
dataloader = DataLoader(list(range(16)), batch_size=4)
new_dataloader = skip_first_batches(dataloader, num_batches=2)
self.assertListEqual([t.tolist() for t in new_dataloader], [[8, 9, 10, 11], [12, 13, 14, 15]])
self.assertListEqual(
[t.tolist() for t in new_dataloader], [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
)

0 comments on commit 5dad64f

Please sign in to comment.