
Add model-parallel MNIST example #98

Merged
merged 53 commits into from
Aug 4, 2017
Changes from 44 commits
53 commits
2097638
Model-parallel example.
levelfour Jul 10, 2017
ccaebe1
Use empty dataset.
levelfour Jul 14, 2017
1e74c5d
MultiNodeChain
levelfour Jul 14, 2017
f751651
Update docstring.
levelfour Jul 14, 2017
7cc8cb9
Rename train_mnist.py to train_mnist_data_parallel.py.
levelfour Jul 14, 2017
d2a24f0
Merge branch 'master' into model-parallel-mnist
levelfour Jul 14, 2017
d948187
Check if comm is ChainerMN communicator.
levelfour Jul 14, 2017
1c862a9
Add MultiNodeChainGroup.
levelfour Jul 14, 2017
7fcff10
Add tests for recv_retain and get_empty_dataset.
levelfour Jul 14, 2017
ff73511
Add test for MultiNodeChain.
levelfour Jul 14, 2017
1501b9d
Fix PEP8.
levelfour Jul 14, 2017
774aac2
Revise the design of MultiNodeChainGroup.
levelfour Jul 21, 2017
360d1af
Merge.
levelfour Jul 21, 2017
b221fb6
Extend MultiNodeChainGroup for reversing send and recv.
levelfour Jul 21, 2017
36d6df3
Branching send & recv.
levelfour Jul 21, 2017
b3e4802
Do not need division.
levelfour Jul 25, 2017
997f4ff
Refactoring.
levelfour Jul 25, 2017
e320ae8
Add test for branching model.
levelfour Jul 28, 2017
69beef0
Typo.
levelfour Jul 28, 2017
31d4abb
Update docs.
levelfour Jul 28, 2017
f34e3e5
Assertion.
levelfour Jul 28, 2017
3c5b957
Increase processes for Travis test.
levelfour Jul 28, 2017
a9a0bb3
Fix names.
levelfour Jul 29, 2017
66989b5
Fix for PEP8.
levelfour Jul 29, 2017
516d7ff
Fix for PEP8.
levelfour Jul 29, 2017
bc8ffe5
Fix multiple recv.
levelfour Jul 30, 2017
7904d23
Rename: merge -> pseudo_connect.
levelfour Jul 30, 2017
c01d22f
Fix test_branching_model to respond to multiple recv.
levelfour Jul 30, 2017
74e0ecb
Fix for PEP8.
levelfour Jul 30, 2017
94eb451
Deprecated argument backward_pointer for send.
levelfour Aug 2, 2017
a5599d5
Revert renaming.
levelfour Aug 2, 2017
d2a882f
README for MNIST example
levelfour Aug 2, 2017
ac27f1e
Merge remote-tracking branch 'upstream/master' into model-parallel-mnist
levelfour Aug 2, 2017
a64e5ba
Rename.
levelfour Aug 2, 2017
0a0b26a
Rename.
levelfour Aug 2, 2017
ecb4538
Merge remote-tracking branch 'upstream/master' into model-parallel-mnist
levelfour Aug 2, 2017
20814e5
Trivial fix.
levelfour Aug 2, 2017
4e35f6a
Make PseudoConnect takes variables as arguments.
levelfour Aug 2, 2017
8f04184
Rename: backward_pointer -> delegate_variable.
levelfour Aug 2, 2017
9215d81
Fix for PEP8.
levelfour Aug 2, 2017
c521364
Move pseudo_connect to a new file.
levelfour Aug 4, 2017
6b0dd6d
Add test for PseudoConnect.
levelfour Aug 4, 2017
f79059e
Fix TestPseudoConnect.
levelfour Aug 4, 2017
f4633a8
Fix TestPseudoConnect.
levelfour Aug 4, 2017
56d2375
Expose pseudo_connect.
levelfour Aug 4, 2017
20c09f8
Fix empty dataset.
levelfour Aug 4, 2017
35cd1d2
Fix for PEP8.
levelfour Aug 4, 2017
b4cb4a8
Add docs for pseudo_connect.
levelfour Aug 4, 2017
f6baf69
Fix a little bit.
levelfour Aug 4, 2017
b4ccdcd
Fix test_empty_dataset.
levelfour Aug 4, 2017
5a9cdd2
Add a little bit in docs of pseudo_connect.
levelfour Aug 4, 2017
592fe95
Fix comments.:
levelfour Aug 4, 2017
cd0d7d8
Fix.
levelfour Aug 4, 2017
2 changes: 1 addition & 1 deletion .travis.yml
@@ -52,7 +52,7 @@ script:
- flake8 --config=.flake8.cython .
- autopep8 -r . --global-config .pep8 | tee check_autopep8
- test ! -s check_autopep8
- for NP in 1 2; do PYTHONWARNINGS='ignore::FutureWarning,module::DeprecationWarning' mpiexec -n ${NP} nosetests -v -a '!nccl,!gpu'; done
- for NP in 1 2 3; do PYTHONWARNINGS='ignore::FutureWarning,module::DeprecationWarning' mpiexec -n ${NP} nosetests -v -a '!nccl,!gpu'; done
# - cd tests
# - PYTHONWARNINGS='ignore::FutureWarning,module::DeprecationWarning' nosetests -a '!gpu,!slow' --with-doctest chainer_tests
- if [[ $TRAVIS_OS_NAME == "linux" ]]; then
1 change: 1 addition & 0 deletions chainermn/__init__.py
@@ -2,6 +2,7 @@

from chainermn.communicators import create_communicator # NOQA
from chainermn.dataset import scatter_dataset # NOQA
from chainermn.link import MultiNodeChainList # NOQA
from chainermn.multi_node_evaluator import create_multi_node_evaluator # NOQA
from chainermn.multi_node_optimizer import create_multi_node_optimizer # NOQA

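The newly exported ``MultiNodeChainList`` is the user-facing entry point for building model-parallel models. As a rough illustration (not taken from this diff; the layer sizes, rank assignments, and the ``add_link(link, rank_in=..., rank_out=...)`` interface are assumed from ChainerMN's model-parallel documentation), a two-stage pipeline might be declared like this:

```python
import chainer.links as L
import chainermn


class Stage0(chainermn.MultiNodeChainList):
    # First pipeline stage: consumes the real minibatch and sends its
    # activation to rank 1 (rank_out=1); it receives nothing (rank_in=None).
    def __init__(self, comm):
        super(Stage0, self).__init__(comm=comm)
        self.add_link(L.Linear(784, 100), rank_in=None, rank_out=1)


class Stage1(chainermn.MultiNodeChainList):
    # Second stage: receives the activation from rank 0 (rank_in=0) and
    # produces the final output locally (rank_out=None).
    def __init__(self, comm):
        super(Stage1, self).__init__(comm=comm)
        self.add_link(L.Linear(100, 10), rank_in=0, rank_out=None)


comm = chainermn.create_communicator()
model = Stage0(comm) if comm.rank == 0 else Stage1(comm)
```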
1 change: 1 addition & 0 deletions chainermn/datasets/__init__.py
@@ -0,0 +1 @@
from chainermn.datasets.empty_dataset import create_empty_dataset # NOQA
21 changes: 21 additions & 0 deletions chainermn/datasets/empty_dataset.py
@@ -0,0 +1,21 @@
import chainer


def create_empty_dataset(dataset):
"""Creates an empty dataset for models with no inputs and outputs.

This function generates an empty dataset, i.e., ``__getitem__()`` only
returns an empty tuple. The resulting dataset is compatible in length
with the original one. Such datasets are used for models which neither
take any input nor return any output. We expect models whose ``forward()``
starts with ``chainermn.functions.recv()`` and ends with
``chainermn.functions.send()``.

Args:
dataset(chainer.datasets.TupleDataset): Dataset to convert.
Contributor
dataset does not need to be TupleDataset. Chainer accepts many kinds of datasets.

Returns:
~chainer.datasets.TransformDataset:
A dataset whose examples are all empty, with the same length as the original one.
"""
return chainer.datasets.TransformDataset(dataset, lambda data: ())
Contributor
Probably just [()] * len(dataset) is enough? (TransformDataset preserves dataset, so it will consume unnecessary memory.)
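As a hedged sketch of how this helper is meant to be used in the model-parallel MNIST example (the iterator setup and rank layout below are illustrative assumptions, not code from this PR):

```python
import chainer
import chainermn
from chainermn.datasets import create_empty_dataset

comm = chainermn.create_communicator()
train, _ = chainer.datasets.get_mnist()

if comm.rank == 0:
    # Rank 0 feeds real images into the first model component.
    train_iter = chainer.iterators.SerialIterator(train, batch_size=100)
else:
    # Other ranks take no direct input: their forward() starts with
    # chainermn.functions.recv(). They still need an iterator of the same
    # length to drive the training loop, hence the empty dataset.
    train = create_empty_dataset(train)
    train_iter = chainer.iterators.SerialIterator(train, batch_size=100)
```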

74 changes: 56 additions & 18 deletions chainermn/functions/point_to_point_communication.py
@@ -14,15 +14,23 @@ def __init__(self, comm, peer_rank, peer_tag):

def forward(self, inputs):
xp = cuda.get_array_module(*inputs)
x, = inputs
# Note: inputs[1] might contain delegate_variable.
x = inputs[0]
self.comm.send(x, self.peer_rank, self.peer_tag)
return xp.array([]),
# Return an empty variable, which serves as "delegate_variable."
return xp.array([], dtype=xp.float32),

def backward(self, inputs, grad_outputs):
xp = cuda.get_array_module(*inputs)
with cuda.get_device_from_array(*inputs):
gy = self.comm.recv(self.peer_rank, self.peer_tag)
return xp.array(gy),
if len(inputs) > 1:
# Dummy grad for delegate_variable.
# This grad will not be used, only for silencing type checker.
grad_delegate_variable = inputs[1]
return xp.array(gy), grad_delegate_variable
else:
return xp.array(gy),


class Recv(chainer.Function):
@@ -38,17 +46,25 @@ def __call__(self, *inputs):
def __call__(self, *inputs):
xp = cuda.get_array_module(*inputs)

if chainer.__version__.startswith('1.'):
# For backward compatibility.
dummy_var = chainer.Variable(xp.array([]), volatile='auto')
else:
# This variable is necessary to backprop correctly in Chainer v2.
# This trick relies on the fact chainer.Variable.requires_grad is
# True by default at Chainer v2.0.0.
dummy_var = chainer.Variable(xp.array([]))
if inputs == ():
# Expected to be invoked without any args in usual case.
if chainer.__version__.startswith('1.'):
# For backward compatibility.
dummy_var = chainer.Variable(
xp.array([], dtype=xp.float32),
volatile='auto')
else:
# This variable is necessary to backprop correctly
# in Chainer v2. This trick relies on the fact
# chainer.Variable.requires_grad is True by default
# in Chainer v2.0.0.
dummy_var = chainer.Variable(xp.array([], dtype=xp.float32))

return super(Recv, self).__call__(dummy_var)

ret = super(Recv, self).__call__(dummy_var)
return ret
else:
# Used for retaining computational graph.
return super(Recv, self).__call__(*inputs)

def forward(self, inputs):
x = self.comm.recv(self.peer_rank, self.peer_tag)
@@ -61,7 +77,7 @@ def backward(self, inputs, grad_outputs):
xp = cuda.get_array_module(*inputs)
gw, = grad_outputs
self.comm.send(gw, self.peer_rank, self.peer_tag)
dummy_var = xp.array([[]])
dummy_var = xp.array([[]], dtype=xp.float32)
return dummy_var


@@ -83,23 +99,33 @@ def send(x, communicator, rank, tag=0):
Returns:
~chainer.Variable:
A dummy variable with no actual data, only holding the
computational graph. If ``backward()`` is invoked by this dummy
variable, it will try to receive gradients from the target process.
computational graph. We call this ``delegate_variable``.
If ``backward()`` is invoked by delegate_variable,
it will try to receive gradients from the target process.

"""
chainer.utils.experimental('chainermn.functions.send')
return Send(communicator, peer_rank=rank, peer_tag=tag)(x)


def recv(communicator, rank, tag=0, device=-1):
def recv(communicator, rank, delegate_variable=None, tag=0, device=-1):
"""Receive elements from target process.

This function returns data received from target process. If ``backward()``
is invoked, it will try to send gradients to the target process.

.. note::
If you define a non-connected computational graph on one machine,
you have to use ``delegate_variable`` to specify the output of
the previous computational graph component.
Otherwise ``backward()`` does not work well.

Member
machine -> process

Args:
communicator (chainer.communicators.CommunicatorBase):
ChainerMN communicator.
rank (int): Target process specifier.
delegate_variable (chainer.Variable):
Pointer to the other non-connected component.
tag (int): Optional message ID (MPI feature).
device (int): Target device specifier.

@@ -109,4 +135,16 @@ def recv(communicator, rank, tag=0, device=-1):
by this variable, it will send gradients to the target process.

"""
return Recv(communicator, peer_rank=rank, peer_tag=tag, device=device)()
chainer.utils.experimental('chainermn.functions.recv')
if delegate_variable is None:
return Recv(
communicator,
peer_rank=rank,
peer_tag=tag,
device=device)()
else:
return Recv(
communicator,
peer_rank=rank,
peer_tag=tag,
device=device)(delegate_variable)
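To make the new ``delegate_variable`` argument concrete, here is a sketch of the call pattern on an intermediate process (the ranks, shapes, and surrounding model are illustrative assumptions; matching ``send``/``recv`` calls must exist on the peer processes for this to actually run):

```python
import chainer.functions as F
import chainer.links as L
import chainermn
import chainermn.functions

comm = chainermn.create_communicator()
layer = L.Linear(100, 100)  # stand-in for this process's model component

# Receive an activation from rank 0, transform it, and send it on to rank 2.
h = chainermn.functions.recv(comm, rank=0)
h = F.relu(layer(h))
delegate = chainermn.functions.send(h, comm, rank=2)

# A later recv() on the same process would otherwise start a second,
# disconnected graph. Passing the delegate variable chains the two
# components so backward() traverses them in the right order instead of
# deadlocking.
h2 = chainermn.functions.recv(comm, rank=0, delegate_variable=delegate)
```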
47 changes: 47 additions & 0 deletions chainermn/functions/pseudo_connect.py
@@ -0,0 +1,47 @@
import chainer
from chainer import cuda
import chainer.utils


class PseudoConnect(chainer.Function):
"""Connect a variable with delegating variable."""
Member
"Connect two variables with a delegating variable"
or
"Connect a variable to a delegating variable" ?


def forward(self, inputs):
# delegate_variable = inputs[0]
actual_variables = inputs[1:]
return actual_variables

def backward(self, inputs, grad_outputs):
delegate_variable = inputs[0]
# actual_variables = inputs[1:]
xp = cuda.get_array_module(*inputs)

# delegate_variable does not need backward gradients; instead we send
# back dummy grads to keep the shapes of grads consistent.
grad_delegate_variable = xp.zeros_like(delegate_variable)

# grad_outputs corresponds to grads of actual_variables.
return tuple([grad_delegate_variable] + list(grad_outputs))


def pseudo_connect(delegate_variable, *actual_variables):
"""Connect independent connected graph component.

In model-parallel framework, models sometimes have many non-connected
components. When some additional components follow model outputs,
outputs of the last component must be merged with model outputs.
Otherwise backprop does not work well, got stuck into dead lock.

Args:
delegate_variable (chainer.Variable):
Pointer to the previous non-connected graph component.
actual_variables (tuple of chainer.Variable):
Actual values which ``delegate_variable`` imitates.

Returns:
~chainer.Variable:
A variable with the given values combined with the delegating variable.
"""
chainer.utils\
.experimental('chainermn.functions.pseudo_connect.pseudo_connect')
return PseudoConnect()(delegate_variable, *actual_variables)
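A minimal sketch of the intended usage, assuming a process whose component both sends an intermediate activation to another rank and computes a local output (the inputs, target, and rank below are hypothetical placeholders, not code from this PR):

```python
import numpy as np
import chainer
import chainer.functions as F
import chainermn
import chainermn.functions

comm = chainermn.create_communicator()
h_mid = chainer.Variable(np.zeros((1, 10), dtype=np.float32))  # placeholder activation
h_out = chainer.Variable(np.zeros((1, 10), dtype=np.float32))  # placeholder activation
t = np.zeros((1,), dtype=np.int32)                             # placeholder labels

# send() returns a delegate variable that owns the "send" half of the graph.
delegate = chainermn.functions.send(h_mid, comm, rank=1)
y_local = F.relu(h_out)

# Merge the delegate variable into the local output so a single backward()
# call covers both graph components; otherwise the peer never receives its
# gradients and backprop can deadlock.
y = chainermn.functions.pseudo_connect(delegate, y_local)
loss = F.softmax_cross_entropy(y, t)
loss.backward()
```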