-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add shuffle
option to chainermn.scatter_dataset
#92
Conversation
shuffle
option to chainermn.scatter_dataset
shuffle
option to chainermn.scatter_dataset
Could you also add |
import warnings | ||
|
||
|
||
def scatter_dataset(dataset, comm): | ||
def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add explanation of new arguments to the docstring
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed!
chainermn/dataset.py
Outdated
|
||
# We cannot use `mpi_comm.scatter`. This is due to MPI4py's bug. | ||
# For large datasets, when using `mpi_comm.scatter`, it causes MemoryError. | ||
if comm.rank == 0: | ||
# import sys | ||
# sys.stderr.write("scatter_dataset(): root={}".format(root)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove unused code 😉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed!
chainermn/dataset.py
Outdated
if shuffle: | ||
order = numpy.random.RandomState(seed).permutation(n_total_samples) | ||
else: | ||
order = numpy.arange(n_total_samples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
order = None
will also work, and it would probably bring (slight) memory/network usage improvement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reflected the comment. Thanks.
chainermn/dataset.py
Outdated
@@ -15,6 +16,8 @@ def scatter_dataset(dataset, comm): | |||
dataset: A dataset (e.g., ``list``, ``numpy.ndarray``, | |||
``chainer.datasets.TupleDataset``, ...). | |||
comm: ChainerMN communicator or MPI4py communicator. | |||
shuffle: Shuffle the dataset before being scattered. | |||
root: The root process of the scatter operation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
Could you add type specification to these two arguments? i.e.,
shuffle (bool)
androot (int)
.dataset
andcomm
accepts various types, and thus we omitted the type specification (following the convention of Chainer, e.g., https://github.com/chainer/chainer/blob/master/chainer/iterators/serial_iterator.py#L26 ) -
Add explanation of
seed
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add explanation.
- I checked Chainer's dataset.py and the sentence of
shuffle
is improved. - Added type specifications to
shuffle
androot
. - Added exaplanation of
seed
Thanks!
Could you also add |
I added |
LGTM! I will merge after travis tests are passed. |
This PR provides
shuffle
optional argument tochainermn.scatter_dataset
.Additionally, it adds
root
option, corresponding to MPI_Scatter'sroot
argument.