Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shuffle option to chainermn.scatter_dataset #92

Merged
merged 10 commits into from
Aug 24, 2017
15 changes: 11 additions & 4 deletions chainermn/dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import chainer.datasets
import numpy
import warnings


def scatter_dataset(dataset, comm):
def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add an explanation of the new arguments (`root`, `shuffle`, `seed`) to the docstring.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

"""Scatter the given dataset to the workers in the communicator.

The dataset of worker 0 (i.e., the worker whose ``comm.rank`` is 0) is
Expand All @@ -27,15 +28,21 @@ def scatter_dataset(dataset, comm):

# We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug.
# For large datasets, when using `mpi_comm.scatter`, it causes MemoryError.
if comm.rank == 0:
if comm.rank == root:
mine = None
n_total_samples = len(dataset)
n_sub_samples = (n_total_samples + comm.size - 1) // comm.size

if shuffle:
order = numpy.random.RandomState(seed).permutation(n_total_samples)
else:
order = numpy.arange(n_total_samples)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`order = None` will also work, and it would probably bring a (slight) improvement in memory/network usage.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reflected the comment. Thanks.


for i in range(comm.size):
b = n_total_samples * i // comm.size
e = b + n_sub_samples
subds = chainer.datasets.SubDataset(dataset, b, e)
if i == 0:
subds = chainer.datasets.SubDataset(dataset, b, e, order)
if i == root:
mine = subds
else:
comm.send(subds, dest=i)
Expand Down
28 changes: 16 additions & 12 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def setUp(self):
self.mpi_comm = mpi4py.MPI.COMM_WORLD
self.communicator = NaiveCommunicator(self.mpi_comm)

def check_scatter_dataset(self, original_dataset):
def check_scatter_dataset(self, original_dataset, shuffle=False, root=0):
my_dataset = chainermn.scatter_dataset(
original_dataset, self.communicator)
original_dataset, self.communicator,
shuffle=shuffle, root=root)
sub_datasets = self.mpi_comm.gather(my_dataset)

if self.mpi_comm.rank == 0:
if self.mpi_comm.rank == root:
# Test the sizes
sub_sizes = [len(sub_dataset) for sub_dataset in sub_datasets]
self.assertEqual(len(set(sub_sizes)), 1)
Expand All @@ -36,12 +37,15 @@ def check_scatter_dataset(self, original_dataset):
def test_scatter_dataset(self):
n = self.communicator.size

self.check_scatter_dataset([])
self.check_scatter_dataset([0])
self.check_scatter_dataset(list(range(n)))
self.check_scatter_dataset(list(range(n * 5 - 1)))

self.check_scatter_dataset(np.array([]))
self.check_scatter_dataset(np.array([0]))
self.check_scatter_dataset(np.arange(n))
self.check_scatter_dataset(np.arange(n * 5 - 1))
for shuffle in [True, False]:
for root in range(self.communicator.size):
self.check_scatter_dataset([], root, shuffle)
self.check_scatter_dataset([0], root, shuffle)
self.check_scatter_dataset(list(range(n)), root, shuffle)
self.check_scatter_dataset(list(range(n * 5 - 1)),
root, shuffle)

self.check_scatter_dataset(np.array([]), root, shuffle)
self.check_scatter_dataset(np.array([0]), root, shuffle)
self.check_scatter_dataset(np.arange(n), root, shuffle)
self.check_scatter_dataset(np.arange(n * 5 - 1), root, shuffle)
4 changes: 2 additions & 2 deletions tests/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def test_mnist(self, display_log=True):
else:
train, test = None, None

train = chainermn.scatter_dataset(train, comm)
test = chainermn.scatter_dataset(test, comm)
train = chainermn.scatter_dataset(train, comm, shuffle=True)
test = chainermn.scatter_dataset(test, comm, shuffle=True)

train_iter = chainer.iterators.SerialIterator(train, batchsize)
test_iter = chainer.iterators.SerialIterator(test, batchsize,
Expand Down