Fix batch_sampler maybe None error #3025

Merged: 6 commits, Aug 23, 2024
src/accelerate/data_loader.py (1 addition, 1 deletion)
@@ -744,7 +744,7 @@ def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
-       if hasattr(self.batch_sampler.sampler, "set_epoch"):
+       if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)
        elif hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
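Why the extra guard is needed: when a `DataLoader` is created with `batch_size=None` (the common setup for an `IterableDataset` that already yields batches), PyTorch disables automatic batching and leaves `dataloader.batch_sampler` as `None`, so the old check dereferenced `None.sampler` and raised `AttributeError: 'NoneType' object has no attribute 'sampler'` as soon as `set_epoch` ran. A minimal sketch in plain PyTorch, outside Accelerate (`TinyIterable` is an illustrative dataset, not part of this PR):

```python
# Minimal sketch: with batch_size=None the DataLoader has no batch sampler,
# which is exactly the case the new hasattr(self.batch_sampler, "sampler") guard covers.
import torch
from torch.utils.data import DataLoader, IterableDataset


class TinyIterable(IterableDataset):
    def __iter__(self):
        for _ in range(4):
            yield torch.rand(1)


loader = DataLoader(TinyIterable(), batch_size=None)   # automatic batching disabled
print(loader.batch_sampler)                             # None
print(hasattr(loader.batch_sampler, "sampler"))         # False, so the patched check short-circuits
# loader.batch_sampler.sampler                          # would raise AttributeError, like the old check did
```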
tests/test_data_loader.py (24 additions, 0 deletions)
@@ -15,6 +15,7 @@
import random
import unittest

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

from accelerate import Accelerator
@@ -44,6 +45,21 @@ def __iter__(self):
            stop = random.random() < self.p_stop


class SimpleIterableDataset(IterableDataset):
    def __init__(self, num_samples=1000):
        self.num_samples = num_samples

    def __iter__(self):
        for _ in range(self.num_samples):
            yield torch.rand(1)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


class DataLoaderTester(unittest.TestCase):
    def check_batch_sampler_shards(self, batch_sampler, expected, split_batches=False, even_batches=True):
        batch_sampler_shards = [
@@ -364,6 +380,14 @@ def test_iterable_dataset_shard(self):
        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=False, split_batches=True)
        self.check_iterable_dataset_shards(dataset, seed, batch_size=4, drop_last=True, split_batches=True)

    def test_iterable_dataset_using_none_batch_size(self):
        dataset = SimpleIterableDataset(100)
        accelerator = Accelerator()
        dataloader = DataLoader(dataset, batch_size=None)
        dataloader = accelerator.prepare(dataloader)
        for d in dataloader:
            assert isinstance(d, torch.Tensor)

    def test_skip_batch_sampler(self):
        batch_sampler = BatchSampler(range(16), batch_size=4, drop_last=False)
        new_batch_sampler = SkipBatchSampler(batch_sampler, 2)
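Beyond the iteration check in the new test, here is a hedged standalone sketch of the path the fix unblocks: calling `set_epoch` on a prepared loader whose `batch_sampler` is `None` now falls through to the dataset's own `set_epoch` instead of crashing. `EpochAwareIterable` below is a stand-in mirroring the `SimpleIterableDataset` added above, and the exact wrapping of the prepared loader depends on the Accelerate configuration.

```python
# Hedged sketch (not part of the PR): exercising set_epoch with batch_size=None.
import torch
from torch.utils.data import DataLoader, IterableDataset

from accelerate import Accelerator


class EpochAwareIterable(IterableDataset):
    def __init__(self, num_samples=100):
        self.num_samples = num_samples
        self.epoch = 0

    def __iter__(self):
        for _ in range(self.num_samples):
            yield torch.rand(1)

    def set_epoch(self, epoch):
        self.epoch = epoch


if __name__ == "__main__":
    accelerator = Accelerator()
    loader = accelerator.prepare(DataLoader(EpochAwareIterable(), batch_size=None))
    # Before this fix, the call below raised
    # AttributeError: 'NoneType' object has no attribute 'sampler'
    # because a batch_size=None loader has no batch sampler.
    loader.set_epoch(1)
    for batch in loader:
        assert isinstance(batch, torch.Tensor)
```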