Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shuffle option to chainermn.scatter_dataset #92

Merged
merged 10 commits into from
Aug 24, 2017
22 changes: 17 additions & 5 deletions chainermn/dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import chainer.datasets
import numpy
import warnings


def scatter_dataset(dataset, comm):
def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add explanation of new arguments to the docstring

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

"""Scatter the given dataset to the workers in the communicator.

The dataset of worker 0 (i.e., the worker whose ``comm.rank`` is 0) is
Expand All @@ -15,6 +16,8 @@ def scatter_dataset(dataset, comm):
dataset: A dataset (e.g., ``list``, ``numpy.ndarray``,
``chainer.datasets.TupleDataset``, ...).
comm: ChainerMN communicator or MPI4py communicator.
shuffle: Shuffle the dataset before being scattered.
root: The root process of the scatter operation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add explanation.

  • I checked Chainer's dataset.py and the sentence of shuffle is improved.
  • Added type specifications to shuffle and root.
  • Added exaplanation of seed
    Thanks!


Returns:
Scattered dataset.
Expand All @@ -24,24 +27,33 @@ def scatter_dataset(dataset, comm):
comm = comm.mpi_comm
assert hasattr(comm, 'send')
assert hasattr(comm, 'recv')
assert 0 <= root and root < comm.size

# We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug.
# For large datasets, when using `mpi_comm.scatter`, it causes MemoryError.
if comm.rank == 0:
# import sys
# sys.stderr.write("scatter_dataset(): root={}".format(root))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove unused code 😉

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed!

if comm.rank == root:
mine = None
n_total_samples = len(dataset)
n_sub_samples = (n_total_samples + comm.size - 1) // comm.size

if shuffle:
order = numpy.random.RandomState(seed).permutation(n_total_samples)
else:
order = numpy.arange(n_total_samples)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

order = None will also work, and it would probably bring (slight) memory/network usage improvement.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reflected the comment. Thanks.


for i in range(comm.size):
b = n_total_samples * i // comm.size
e = b + n_sub_samples
subds = chainer.datasets.SubDataset(dataset, b, e)
if i == 0:
subds = chainer.datasets.SubDataset(dataset, b, e, order)
if i == root:
mine = subds
else:
comm.send(subds, dest=i)
return mine
else:
return comm.recv(source=0)
return comm.recv(source=root)


def get_n_iterations_for_one_epoch(dataset, local_batch_size, comm):
Expand Down
4 changes: 2 additions & 2 deletions examples/mnist/train_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def main():
train, test = chainer.datasets.get_mnist()
else:
train, test = None, None
train = chainermn.scatter_dataset(train, comm)
test = chainermn.scatter_dataset(test, comm)
train = chainermn.scatter_dataset(train, comm, shuffle=True)
test = chainermn.scatter_dataset(test, comm, shuffle=True)

train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
Expand Down
30 changes: 17 additions & 13 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def setUp(self):
self.mpi_comm = mpi4py.MPI.COMM_WORLD
self.communicator = NaiveCommunicator(self.mpi_comm)

def check_scatter_dataset(self, original_dataset):
def check_scatter_dataset(self, original_dataset, shuffle=False, root=0):
my_dataset = chainermn.scatter_dataset(
original_dataset, self.communicator)
sub_datasets = self.mpi_comm.gather(my_dataset)
original_dataset, self.communicator,
shuffle=shuffle, root=root)
sub_datasets = self.mpi_comm.gather(my_dataset, root=root)

if self.mpi_comm.rank == 0:
if self.mpi_comm.rank == root:
# Test the sizes
sub_sizes = [len(sub_dataset) for sub_dataset in sub_datasets]
self.assertEqual(len(set(sub_sizes)), 1)
Expand All @@ -36,12 +37,15 @@ def check_scatter_dataset(self, original_dataset):
def test_scatter_dataset(self):
n = self.communicator.size

self.check_scatter_dataset([])
self.check_scatter_dataset([0])
self.check_scatter_dataset(list(range(n)))
self.check_scatter_dataset(list(range(n * 5 - 1)))

self.check_scatter_dataset(np.array([]))
self.check_scatter_dataset(np.array([0]))
self.check_scatter_dataset(np.arange(n))
self.check_scatter_dataset(np.arange(n * 5 - 1))
for shuffle in [True, False]:
for root in range(self.communicator.size):
self.check_scatter_dataset([], shuffle, root)
self.check_scatter_dataset([0], shuffle, root)
self.check_scatter_dataset(list(range(n)), shuffle, root)
self.check_scatter_dataset(list(range(n * 5 - 1)),
shuffle, root)

self.check_scatter_dataset(np.array([]), shuffle, root)
self.check_scatter_dataset(np.array([0]), shuffle, root)
self.check_scatter_dataset(np.arange(n), shuffle, root)
self.check_scatter_dataset(np.arange(n * 5 - 1), shuffle, root)
4 changes: 2 additions & 2 deletions tests/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def test_mnist(self, display_log=True):
else:
train, test = None, None

train = chainermn.scatter_dataset(train, comm)
test = chainermn.scatter_dataset(test, comm)
train = chainermn.scatter_dataset(train, comm, shuffle=True)
test = chainermn.scatter_dataset(test, comm, shuffle=True)

train_iter = chainer.iterators.SerialIterator(train, batchsize)
test_iter = chainer.iterators.SerialIterator(test, batchsize,
Expand Down