Add model-parallel MNIST example #98
Conversation
README.md
Outdated
@@ -38,7 +38,7 @@ Please refer to the [installation guide](https://chainermn.readthedocs.io/en/lat
 You can invoke MNIST example with four workers by the following command:

 ```
-mpiexec -n 4 python examples/mnist/train_mnist.py
+mpiexec -n 4 python examples/mnist/train_mnist_data_parallel.py
I'm afraid we should not change this line. The data-parallel model is an advanced feature, and users should first see the basic MNIST example.
I thought it would be confusing if `train_mnist.py` and `train_mnist_model_parallel.py` existed at the same time, so I renamed the original `train_mnist.py` to `train_mnist_data_parallel.py` to make it explicit that this example is data-parallel.
chainermn/link.py
Outdated
import chainermn


class SimpleModelInst(chainer.Chain):
Can we think of a better name? `Inst` sounds like an instance of a class.
Fixed. Thank you.
#!/usr/bin/env python
# coding: utf-8

import argparse
According to PEP 8, imports should be grouped in the following order:

1. standard library imports
2. related third party imports
3. local application/library specific imports

You should put a blank line between each group of imports.
https://www.python.org/dev/peps/pep-0008/#imports

So we need a blank line after `argparse`.
I am not sure whether we need a blank line between the Chainer and ChainerMN imports.
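For reference, a minimal sketch of the grouping being discussed (the exact imports in the example script may differ; whether `chainer` and `chainermn` belong in the same group is the open question above):

```python
import argparse  # standard library

import chainer  # third-party library

import chainermn  # ChainerMN; possibly the same group as chainer (see above)
```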
@@ -0,0 +1,29 @@
+import numpy as np
`unittest` should come first because it is a standard library module, while `numpy` is a third-party library.
@@ -6,6 +6,7 @@
 import chainer.testing.attr
 import chainermn
 import chainermn.functions
+import copy
PEP 8: `copy` is a standard library import, so it should come before the third-party imports.
tests/test_link.py
Outdated
        err = model()
        err.backward()

    def test_cross_model(self):
`crossing_model`?
Fixed.
chainermn/link.py
Outdated
if backward_pointer is not None and _x.creator is not None:
    _x.creator.rank = -1

x = _x if x is None else x + _x
I don't think it is a reasonable design to sum up the received arrays when receiving from multiple workers. It would be more natural and useful to hand them to `f` as separate arguments.
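A schematic illustration of the two designs (the names here are placeholders, not the actual code in `link.py`):

```python
import numpy as np


def f(*inputs):
    # Placeholder for the wrapped chain's forward computation.
    return sum(inputs)


x_from_rank1 = np.ones(3, dtype=np.float32)
x_from_rank2 = np.full(3, 2.0, dtype=np.float32)

# Current design (as reviewed): arrays received from multiple workers are
# summed into a single input before the call.
y_summed = f(x_from_rank1 + x_from_rank2)

# Suggested design: pass the received arrays as separate arguments and let
# the wrapped chain decide how to combine them.
y_separate = f(x_from_rank1, x_from_rank2)
```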
chainermn/datasets/empty_dataset.py
Outdated
import chainer


def get_empty_dataset(dataset):
How about renaming it to `create_empty_dataset`, for consistency with `create_multi_node_optimizer` and `create_multi_node_evaluator`?
Fixed.
chainermn/link.py
Outdated
import chainermn.functions.point_to_point_communication


class MultiNodeChainGroup(chainer.ChainList):
Do you have any reason for this name? I don't have a strong opinion, but how about renaming it to `MultiNodeChainList`, which is consistent with `chainer.ChainList`?
Fixed.
chainermn/link.py
Outdated
import chainermn


class SimpleModelInst(chainer.Chain):
I didn't understand what `Inst` means. How about `SimpleModelSub` or something like that?
        self.add_link(MLP0b(comm), rank_in=1, rank_out=None)


class MLP1inst(chainer.Chain):
Ditto for `inst`.
Fixed.
chainermn/datasets/empty_dataset.py
Outdated
    ``chainermn.functions.send()``.

    Args:
        dataset(chainer.datasets.TupleDataset): Dataset to convert.
`dataset` does not need to be a `TupleDataset`; Chainer accepts many kinds of datasets.
chainermn/datasets/empty_dataset.py
Outdated
        ~chainer.datasets.TransformDataset:
            Dataset consists of only patterns in the original one.
    """
    return chainer.datasets.TransformDataset(dataset, lambda data: ())
Probably just `[()] * len(dataset)` is enough? (`TransformDataset` keeps a reference to `dataset`, so it will consume unnecessary memory.)
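A minimal sketch of the suggested alternative (using the `create_empty_dataset` name from the rename discussion above; the actual implementation in the PR may differ):

```python
def create_empty_dataset(dataset):
    # A plain list of empty tuples has the same length as `dataset` but,
    # unlike a TransformDataset wrapper, does not keep a reference to the
    # original dataset.
    return [()] * len(dataset)
```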
class PseudoConnect(chainer.Function):
    """Connect a variable with delegating variable."""
"Connect two variables with a delegating variable" or "Connect a variable to a delegating variable"?

    """Receive elements from target process.

    This function returns data received from target process. If ``backward()``
    is invoked, it will try to send gradients to the target process.

    .. note::
        If you define non-connected computational graph on one machine,
machine -> process
@@ -100,3 +102,31 @@ def test_communication(self):
         err = chainermn.functions.send(
             y, self.communicator, self.rank_send)
         err.backward()
+
+    def test_retain(self):
+        if self.communicator.rank == 0:
Does this test work if more than 2 processes are invoked? Should it be skipped if `communicator.size > 2`?
This test also works with more than 2 processes. It emulates `test_cycle_model` in `test_link.py`. FYI, we run the tests with 3 processes in the latest commit.
OK!
LGTM
LGTM!
This PR adds some user-friendly interfaces for implementing model-parallel neural nets, along with a model-parallel MNIST example.
New Features
MultiNodeChainGroup (`chainermn.MultiNodeChainGroup`)
This variant of `chainer.ChainList` represents the multiple connected components of the entire computational graph. In multi-node computation, computational graphs often become non-connected. Each connected component is represented by a `chainer.Chain`, and `MultiNodeChainGroup` combines them. A rough usage sketch follows.
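This sketch assumes the `add_link(..., rank_in=..., rank_out=...)` interface that appears in the diffs above and that the group is constructed with the communicator; the class and layer names are illustrative, not the actual example code:

```python
import chainer
import chainer.functions as F
import chainer.links as L
import chainermn


class MLP0Sub(chainer.Chain):
    # First connected component: runs on rank 0, its output is sent to rank 1.
    def __init__(self, n_units):
        super(MLP0Sub, self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(784, n_units)

    def __call__(self, x):
        return F.relu(self.l1(x))


class MLP1Sub(chainer.Chain):
    # Second connected component: runs on rank 1 on the received activations.
    def __init__(self, n_units, n_out):
        super(MLP1Sub, self).__init__()
        with self.init_scope():
            self.l2 = L.Linear(n_units, n_out)

    def __call__(self, h):
        return self.l2(h)


# Each rank registers its own component; rank_in/rank_out describe where
# activations are received from and where they are sent to.
comm = chainermn.create_communicator()
model = chainermn.MultiNodeChainGroup(comm)
if comm.rank == 0:
    model.add_link(MLP0Sub(1000), rank_in=None, rank_out=1)
else:
    model.add_link(MLP1Sub(1000, 10), rank_in=0, rank_out=None)
```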
New utility function `chainermn.functions.pseudo_connect`
This function is used when we want a "delegate variable" to imitate another variable. This kind of pathological situation arises in multi-node environments. Please see the documentation of this function for details.
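A hedged illustration of the intended use (the exact signature is defined in this PR; `h`, `local_tail`, `compute_loss`, and `rank_next` are placeholders, and this assumes `pseudo_connect(delegate_variable, x)` returns a variable whose backward also triggers the delegate variable's backward):

```python
# Schematic only: this process sends an intermediate result to another rank
# and also continues a local computation.
phi = chainermn.functions.send(h, comm, rank_next)  # delegate variable, carries no data
y = local_tail(h)                                   # local branch of the graph
y = chainermn.functions.pseudo_connect(phi, y)      # tie both backward paths together
loss = compute_loss(y)
loss.backward()                                     # covers the local and remote paths
```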
Empty dataset (`chainermn.datasets.get_empty_dataset`)
It is used for models that have no actual inputs and instead receive their inputs from another process.
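A hedged usage sketch (the iterator setup follows standard Chainer conventions; whether the final name is `get_empty_dataset` or `create_empty_dataset` is discussed above):

```python
import chainer
import chainermn

comm = chainermn.create_communicator()

# Only rank 0 feeds real data into the pipeline; the other ranks receive
# their inputs via chainermn.functions.recv(), so their local dataset only
# needs to provide the right number of (empty) examples to drive the iterator.
train, test = chainer.datasets.get_mnist()
if comm.rank != 0:
    train = chainermn.datasets.get_empty_dataset(train)

train_iter = chainer.iterators.SerialIterator(train, batch_size=100)
```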