Skip to content

Commit

Permalink
Merge pull request #92 from chainer/add-shuffled-scatter
Browse files Browse the repository at this point in the history
Add `shuffle` option to `chainermn.scatter_dataset`
  • Loading branch information
iwiwi authored Aug 24, 2017
2 parents da39cb0 + 9830897 commit ede666c
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 23 deletions.
25 changes: 20 additions & 5 deletions chainermn/dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import chainer.datasets
import numpy
import warnings


def scatter_dataset(dataset, comm):
def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None):
"""Scatter the given dataset to the workers in the communicator.
The dataset of worker 0 (i.e., the worker whose ``comm.rank`` is 0) is
Expand All @@ -15,6 +16,14 @@ def scatter_dataset(dataset, comm):
dataset: A dataset (e.g., ``list``, ``numpy.ndarray``,
``chainer.datasets.TupleDataset``, ...).
comm: ChainerMN communicator or MPI4py communicator.
shuffle (bool): If ``True``, the order of examples is shuffled
before being scattered.
root (int): The root process of the scatter operation.
seed (int): Seed the generator used for the permutation of indexes.
If an integer being convertible to 32 bit unsigned integers is
specified, it is guaranteed that each sample
in the given dataset always belongs to a specific subset.
If ``None``, the permutation is changed randomly.
Returns:
Scattered dataset.
Expand All @@ -24,24 +33,30 @@ def scatter_dataset(dataset, comm):
comm = comm.mpi_comm
assert hasattr(comm, 'send')
assert hasattr(comm, 'recv')
assert 0 <= root and root < comm.size

# We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug.
# For large datasets, when using `mpi_comm.scatter`, it causes MemoryError.
if comm.rank == 0:
if comm.rank == root:
mine = None
n_total_samples = len(dataset)
n_sub_samples = (n_total_samples + comm.size - 1) // comm.size
order = None

if shuffle:
order = numpy.random.RandomState(seed).permutation(n_total_samples)

for i in range(comm.size):
b = n_total_samples * i // comm.size
e = b + n_sub_samples
subds = chainer.datasets.SubDataset(dataset, b, e)
if i == 0:
subds = chainer.datasets.SubDataset(dataset, b, e, order)
if i == root:
mine = subds
else:
comm.send(subds, dest=i)
return mine
else:
return comm.recv(source=0)
return comm.recv(source=root)


def get_n_iterations_for_one_epoch(dataset, local_batch_size, comm):
Expand Down
2 changes: 1 addition & 1 deletion examples/imagenet/train_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def main():
else:
train = None
val = None
train = chainermn.scatter_dataset(train, comm)
train = chainermn.scatter_dataset(train, comm, shuffle=True)
val = chainermn.scatter_dataset(val, comm)

# We need to change the start method of multiprocessing module if we are
Expand Down
4 changes: 2 additions & 2 deletions examples/mnist/train_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def main():
train, test = chainer.datasets.get_mnist()
else:
train, test = None, None
train = chainermn.scatter_dataset(train, comm)
test = chainermn.scatter_dataset(test, comm)
train = chainermn.scatter_dataset(train, comm, shuffle=True)
test = chainermn.scatter_dataset(test, comm, shuffle=True)

train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
Expand Down
30 changes: 17 additions & 13 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def setUp(self):
self.mpi_comm = mpi4py.MPI.COMM_WORLD
self.communicator = NaiveCommunicator(self.mpi_comm)

def check_scatter_dataset(self, original_dataset):
def check_scatter_dataset(self, original_dataset, shuffle=False, root=0):
my_dataset = chainermn.scatter_dataset(
original_dataset, self.communicator)
sub_datasets = self.mpi_comm.gather(my_dataset)
original_dataset, self.communicator,
shuffle=shuffle, root=root)
sub_datasets = self.mpi_comm.gather(my_dataset, root=root)

if self.mpi_comm.rank == 0:
if self.mpi_comm.rank == root:
# Test the sizes
sub_sizes = [len(sub_dataset) for sub_dataset in sub_datasets]
self.assertEqual(len(set(sub_sizes)), 1)
Expand All @@ -36,12 +37,15 @@ def check_scatter_dataset(self, original_dataset):
def test_scatter_dataset(self):
n = self.communicator.size

self.check_scatter_dataset([])
self.check_scatter_dataset([0])
self.check_scatter_dataset(list(range(n)))
self.check_scatter_dataset(list(range(n * 5 - 1)))

self.check_scatter_dataset(np.array([]))
self.check_scatter_dataset(np.array([0]))
self.check_scatter_dataset(np.arange(n))
self.check_scatter_dataset(np.arange(n * 5 - 1))
for shuffle in [True, False]:
for root in range(self.communicator.size):
self.check_scatter_dataset([], shuffle, root)
self.check_scatter_dataset([0], shuffle, root)
self.check_scatter_dataset(list(range(n)), shuffle, root)
self.check_scatter_dataset(list(range(n * 5 - 1)),
shuffle, root)

self.check_scatter_dataset(np.array([]), shuffle, root)
self.check_scatter_dataset(np.array([0]), shuffle, root)
self.check_scatter_dataset(np.arange(n), shuffle, root)
self.check_scatter_dataset(np.arange(n * 5 - 1), shuffle, root)
4 changes: 2 additions & 2 deletions tests/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def test_mnist(self, display_log=True):
else:
train, test = None, None

train = chainermn.scatter_dataset(train, comm)
test = chainermn.scatter_dataset(test, comm)
train = chainermn.scatter_dataset(train, comm, shuffle=True)
test = chainermn.scatter_dataset(test, comm, shuffle=True)

train_iter = chainer.iterators.SerialIterator(train, batchsize)
test_iter = chainer.iterators.SerialIterator(test, batchsize,
Expand Down

0 comments on commit ede666c

Please sign in to comment.