Add model-parallel MNIST example #98

@@ -0,0 +1 @@
from chainermn.datasets.empty_dataset import create_empty_dataset  # NOQA

@@ -0,0 +1,21 @@
import chainer


def create_empty_dataset(dataset):
    """Creates an empty dataset for models with no inputs and outputs.

    This function generates an empty dataset, i.e., ``__getitem__()`` only
    returns ``None``. The generated dataset is compatible with the original
    one. Such datasets are used for models which take no inputs and return
    no outputs; for example, models whose ``forward()`` starts with
    ``chainermn.functions.recv()`` and ends with
    ``chainermn.functions.send()``.

    Args:
        dataset(chainer.datasets.TupleDataset): Dataset to convert.

    Returns:
        ~chainer.datasets.TransformDataset:
            Dataset consisting only of empty examples.
    """
    return chainer.datasets.TransformDataset(dataset, lambda data: ())

Review comment: Probably just …
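
For context, a minimal usage sketch (not part of this diff; the MNIST loader, SerialIterator, and batch size are standard Chainer pieces chosen only for illustration): on a worker whose model neither consumes inputs nor produces outputs, the original dataset is wrapped so that the iterator still yields the right number of batches per epoch.

import chainer
import chainermn.datasets

# Original dataset; only its length matters on this worker.
train, _ = chainer.datasets.get_mnist()

# Wrap it so each example becomes empty while the length stays the same.
train = chainermn.datasets.create_empty_dataset(train)
train_iter = chainer.iterators.SerialIterator(train, batch_size=100)
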

@@ -14,15 +14,23 @@ def __init__(self, comm, peer_rank, peer_tag):

     def forward(self, inputs):
         xp = cuda.get_array_module(*inputs)
-        x, = inputs
+        # Note: inputs[1] might contain delegate_variable.
+        x = inputs[0]
         self.comm.send(x, self.peer_rank, self.peer_tag)
-        return xp.array([]),
+        # Return an empty variable, which serves as "delegate_variable."
+        return xp.array([], dtype=xp.float32),

     def backward(self, inputs, grad_outputs):
         xp = cuda.get_array_module(*inputs)
         with cuda.get_device_from_array(*inputs):
             gy = self.comm.recv(self.peer_rank, self.peer_tag)
-            return xp.array(gy),
+            if len(inputs) > 1:
+                # Dummy grad for delegate_variable.
+                # This grad will not be used, only for silencing type checker.
+                grad_delegate_variable = inputs[1]
+                return xp.array(gy), grad_delegate_variable
+            else:
+                return xp.array(gy),


 class Recv(chainer.Function):

@@ -38,17 +46,25 @@ def __init__(self, comm, peer_rank, peer_tag, device=-1):
     def __call__(self, *inputs):
         xp = cuda.get_array_module(*inputs)

-        if chainer.__version__.startswith('1.'):
-            # For backward compatibility.
-            dummy_var = chainer.Variable(xp.array([]), volatile='auto')
-        else:
-            # This variable is necessary to backprop correctly in Chainer v2.
-            # This trick relies on the fact chainer.Variable.requires_grad is
-            # True by default at Chainer v2.0.0.
-            dummy_var = chainer.Variable(xp.array([]))
+        if inputs == ():
+            # Expected to be invoked without any args in usual case.
+            if chainer.__version__.startswith('1.'):
+                # For backward compatibility.
+                dummy_var = chainer.Variable(
+                    xp.array([], dtype=xp.float32),
+                    volatile='auto')
+            else:
+                # This variable is necessary to backprop correctly
+                # in Chainer v2. This trick relies on the fact
+                # chainer.Variable.requires_grad is True by default
+                # in Chainer v2.0.0.
+                dummy_var = chainer.Variable(xp.array([], dtype=xp.float32))

-        return super(Recv, self).__call__(dummy_var)
+            ret = super(Recv, self).__call__(dummy_var)
+            return ret
+        else:
+            # Used for retaining computational graph.
+            return super(Recv, self).__call__(*inputs)

     def forward(self, inputs):
         x = self.comm.recv(self.peer_rank, self.peer_tag)

@@ -61,7 +77,7 @@ def backward(self, inputs, grad_outputs):
         xp = cuda.get_array_module(*inputs)
         gw, = grad_outputs
         self.comm.send(gw, self.peer_rank, self.peer_tag)
-        dummy_var = xp.array([[]])
+        dummy_var = xp.array([[]], dtype=xp.float32)
         return dummy_var

@@ -83,23 +99,33 @@ def send(x, communicator, rank, tag=0):
     Returns:
         ~chainer.Variable:
             A dummy variable with no actual data, only holding the
-            computational graph. If ``backward()`` is invoked by this dummy
-            variable, it will try to receive gradients from the target process.
+            computational graph. We call this ``delegate_variable``.
+            If ``backward()`` is invoked by delegate_variable,
+            it will try to receive gradients from the target process.

     """
+    chainer.utils.experimental('chainermn.functions.send')
     return Send(communicator, peer_rank=rank, peer_tag=tag)(x)


-def recv(communicator, rank, tag=0, device=-1):
+def recv(communicator, rank, delegate_variable=None, tag=0, device=-1):
     """Receive elements from target process.

     This function returns data received from target process. If ``backward()``
     is invoked, it will try to send gradients to the target process.

+    .. note::
+        If you define a non-connected computational graph on one machine,
+        you have to use ``delegate_variable`` to specify the output of the
+        previous computational graph component.
+        Otherwise ``backward()`` does not work well.

Review comment: machine -> process

     Args:
         communicator (chainer.communicators.CommunicatorBase):
             ChainerMN communicator.
         rank (int): Target process specifier.
+        delegate_variable (chainer.Variable):
+            Pointer to the other non-connected component.
         tag (int): Optional message ID (MPI feature).
         device (int): Target device specifier.

@@ -109,4 +135,16 @@ def recv(communicator, rank, tag=0, device=-1):
             by this variable, it will send gradients to the target process.

     """
-    return Recv(communicator, peer_rank=rank, peer_tag=tag, device=device)()
+    chainer.utils.experimental('chainermn.functions.recv')
+    if delegate_variable is None:
+        return Recv(
+            communicator,
+            peer_rank=rank,
+            peer_tag=tag,
+            device=device)()
+    else:
+        return Recv(
+            communicator,
+            peer_rank=rank,
+            peer_tag=tag,
+            device=device)(delegate_variable)
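
To show how the new delegate_variable argument is meant to be used, here is a minimal sketch (not taken from this PR; MLPHead, MLPTail, and the communicator wiring are hypothetical): rank 0 owns the first and last pieces of the model, rank 1 owns the middle. The delegate variable returned by send() is passed to recv() so that the two otherwise disconnected graph components on rank 0 are traversed by a single backward().

import chainer
import chainer.functions as F
import chainermn.functions


class Rank0Model(chainer.Chain):
    # MLPHead and MLPTail are placeholder chains defined elsewhere;
    # the middle part of the model lives on rank 1.
    def __init__(self, comm):
        super(Rank0Model, self).__init__(head=MLPHead(), tail=MLPTail())
        self.comm = comm

    def __call__(self, x, t):
        h = self.head(x)
        # Ship activations to rank 1 and keep the returned delegate_variable.
        phi = chainermn.functions.send(h, self.comm, rank=1)
        # Receive rank 1's output; passing phi connects the graph so that
        # backward() reaches send() and exchanges gradients.
        h = chainermn.functions.recv(self.comm, rank=1, delegate_variable=phi)
        return F.softmax_cross_entropy(self.tail(h), t)
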

@@ -0,0 +1,47 @@
import chainer
from chainer import cuda
import chainer.utils


class PseudoConnect(chainer.Function):
    """Connect a variable with a delegating variable."""

Review comment: "Connect two variables with a delegating variable"

    def forward(self, inputs):
        # delegate_variable = inputs[0]
        actual_variables = inputs[1:]
        return actual_variables

    def backward(self, inputs, grad_outputs):
        delegate_variable = inputs[0]
        # actual_variables = inputs[1:]
        xp = cuda.get_array_module(*inputs)

        # delegate_variable does not need backward gradients; instead we send
        # back dummy grads to keep the shapes of the gradients consistent.
        grad_delegate_variable = xp.zeros_like(delegate_variable)

        # grad_outputs corresponds to grads of actual_variables.
        return tuple([grad_delegate_variable] + list(grad_outputs))


def pseudo_connect(delegate_variable, *actual_variables):
    """Connect independent connected graph components.

    In the model-parallel framework, models sometimes have many non-connected
    components. When some additional components follow the model outputs,
    the outputs of the last component must be merged with the model outputs.
    Otherwise backprop does not work well and may get stuck in a deadlock.

    Args:
        delegate_variable (chainer.Variable):
            Pointer to the previous non-connected graph component.
        actual_variables (tuple of chainer.Variable):
            Actual values which ``delegate_variable`` imitates.

    Returns:
        ~chainer.Variable:
            A variable with the given values combined with the delegating
            variable.
    """
    chainer.utils\
        .experimental('chainermn.functions.pseudo_connect.pseudo_connect')
    return PseudoConnect()(delegate_variable, *actual_variables)
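
A minimal sketch of the intended use of pseudo_connect (assumptions: it is exported from chainermn.functions like the other functions in this PR; branch_a, branch_b, and the communicator are hypothetical). When one forward pass ends one component with send() and a separate local component produces the loss, merging the delegate variable into the loss lets a single backward() cover both components.

import chainer.functions as F
import chainermn.functions


def forward(model, comm, x, t):
    # Component A: compute part of the activations and ship them to rank 1.
    h = model.branch_a(x)
    phi = chainermn.functions.send(h, comm, rank=1)

    # Component B: an independent local branch that produces the loss.
    loss = F.softmax_cross_entropy(model.branch_b(x), t)

    # Merge phi into the loss; loss.backward() then also triggers the
    # gradient exchange of component A (phi itself only gets a dummy grad).
    return chainermn.functions.pseudo_connect(phi, loss)
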

Review comment: dataset does not need to be TupleDataset. Chainer accepts many kinds of datasets.

Review comment: FYI: https://github.com/chainer/chainer/blob/v2.0.2/chainer/iterators/serial_iterator.py#L26