-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add `shuffle` option to `chainermn.scatter_dataset`
#92
Changes from 7 commits
860725c
d04bb23
a7ed202
609fa2d
9660899
a13db87
74e6b09
d3233f8
246482a
9830897
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
import chainer.datasets | ||
import numpy | ||
import warnings | ||
|
||
|
||
def scatter_dataset(dataset, comm): | ||
def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): | ||
"""Scatter the given dataset to the workers in the communicator. | ||
|
||
The dataset of worker 0 (i.e., the worker whose ``comm.rank`` is 0) is | ||
|
@@ -15,6 +16,8 @@ def scatter_dataset(dataset, comm): | |
dataset: A dataset (e.g., ``list``, ``numpy.ndarray``, | ||
``chainer.datasets.TupleDataset``, ...). | ||
comm: ChainerMN communicator or MPI4py communicator. | ||
shuffle: Shuffle the dataset before being scattered. | ||
root: The root process of the scatter operation. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add explanation.
|
||
|
||
Returns: | ||
Scattered dataset. | ||
|
@@ -24,24 +27,33 @@ def scatter_dataset(dataset, comm): | |
comm = comm.mpi_comm | ||
assert hasattr(comm, 'send') | ||
assert hasattr(comm, 'recv') | ||
assert 0 <= root and root < comm.size | ||
|
||
# We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug. | ||
# For large datasets, when using `mpi_comm.scatter`, it causes MemoryError. | ||
if comm.rank == 0: | ||
# import sys | ||
# sys.stderr.write("scatter_dataset(): root={}".format(root)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please remove unused code 😉 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Removed! |
||
if comm.rank == root: | ||
mine = None | ||
n_total_samples = len(dataset) | ||
n_sub_samples = (n_total_samples + comm.size - 1) // comm.size | ||
|
||
if shuffle: | ||
order = numpy.random.RandomState(seed).permutation(n_total_samples) | ||
else: | ||
order = numpy.arange(n_total_samples) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reflected the comment. Thanks. |
||
|
||
for i in range(comm.size): | ||
b = n_total_samples * i // comm.size | ||
e = b + n_sub_samples | ||
subds = chainer.datasets.SubDataset(dataset, b, e) | ||
if i == 0: | ||
subds = chainer.datasets.SubDataset(dataset, b, e, order) | ||
if i == root: | ||
mine = subds | ||
else: | ||
comm.send(subds, dest=i) | ||
return mine | ||
else: | ||
return comm.recv(source=0) | ||
return comm.recv(source=root) | ||
|
||
|
||
def get_n_iterations_for_one_epoch(dataset, local_batch_size, comm): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add explanation of new arguments to the docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed!