Support tuple data communication #139

Merged (16 commits, Dec 4, 2017)

72 changes: 53 additions & 19 deletions chainermn/communicators/_base.py
@@ -1,12 +1,33 @@
import mpi4py.MPI
import collections

import numpy

import chainer.cuda
import chainer.utils
from chainermn.communicators import _communication_utility
from chainermn.communicators import _memory_utility
from chainermn import nccl


class _MessageType(object):

def __init__(self, obj):
if isinstance(obj, numpy.ndarray) \
or chainer.cuda.get_array_module(obj) is not numpy:
self.is_tuple = False
self.narr = 1
self.ndims = [obj.ndim]
self.shapes = [obj.shape]
elif isinstance(obj, collections.Iterable):
self.is_tuple = True
self.narr = len(obj)
self.ndims = [x.ndim for x in obj]
self.shapes = [x.shape for x in obj]
else:
raise ValueError(
'Message object must be numpy/cupy array or tuple.')


class CommunicatorBase(object):

def __init__(self, mpi_comm):
@@ -41,7 +62,7 @@ def split(self, color, key):
"""
return self.__class__(mpi_comm=self.mpi_comm.Split(color, key))

def send(self, array, dest, tag):
def send(self, obj, dest, tag):
"""A primitive for inter-process transmitter.

        This method sends a numpy array to the target process.
@@ -51,23 +72,26 @@ def send(self, array, dest, tag):
chainer.Variable objects. Please be sure.

Args:
array: numpy or cupy array object.
obj: data to be sent (tuple, list or raw numpy/cupy array)
dest (int): Target process specifier.
tag (int): Message ID (MPI feature).

"""
chainer.utils.experimental(
'chainermn.communicators.CommunicatorBase.send')
assert array.dtype == numpy.float32
ndim = numpy.array([array.ndim], dtype=numpy.int32)
shape = numpy.array(array.shape, dtype=numpy.int32)
buf = _memory_utility.array_to_buffer_object(array)
self.mpi_comm.Send([ndim, mpi4py.MPI.INT], dest=dest, tag=tag)
self.mpi_comm.Send([shape, mpi4py.MPI.INT], dest=dest, tag=tag)

if chainer.cuda.get_array_module(array) is not numpy:
chainer.cuda.Stream.null.synchronize()
self.mpi_comm.Send(buf, dest=dest, tag=tag)
msgtype = _MessageType(obj)
self.mpi_comm.send(msgtype, dest=dest, tag=tag)

if not msgtype.is_tuple:
obj = [obj]

for array in obj:
if chainer.cuda.get_array_module(array) is not numpy:
chainer.cuda.Stream.null.synchronize()

buf = _memory_utility.array_to_buffer_object(array)
self.mpi_comm.Send(buf, dest=dest, tag=tag)

def recv(self, source, tag):
"""A primitive of inter-process receiver.
@@ -86,13 +110,23 @@ def recv(self, source, tag):

chainer.utils.experimental(
'chainermn.communicators.CommunicatorBase.recv')
ndim = numpy.empty(1, dtype=numpy.int32)
self.mpi_comm.Recv([ndim, mpi4py.MPI.INT], source=source, tag=tag)
shape = numpy.empty(ndim[0], dtype=numpy.int32)
self.mpi_comm.Recv([shape, mpi4py.MPI.INT], source=source, tag=tag)
buf = numpy.empty(shape.prod(), dtype=numpy.float32)
self.mpi_comm.Recv(buf, source=source, tag=tag)
return buf.reshape(shape)

msgtype = self.mpi_comm.recv(source=source, tag=tag)

if msgtype.is_tuple:
msg = []
for shape in msgtype.shapes:
buf = numpy.empty(numpy.prod(shape), dtype=numpy.float32)
self.mpi_comm.Recv(buf, source=source, tag=tag)
msg.append(buf.reshape(shape))
return tuple(msg)

else:
assert len(msgtype.shapes) == 1
shape = msgtype.shapes[0]
buf = numpy.empty(numpy.prod(shape), dtype=numpy.float32)
self.mpi_comm.Recv(buf, source=source, tag=tag)
return buf.reshape(shape)

def broadcast_data(self, model):
raise NotImplementedError()
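
A minimal sketch of how the extended ``send``/``recv`` pair might be used from two MPI processes; the ``naive`` communicator, the tag value, and the array contents are illustrative assumptions rather than part of this change, and at least two processes are assumed::

    import numpy as np
    import chainermn

    comm = chainermn.create_communicator('naive')

    if comm.rank == 0:
        # A tuple is announced first via a pickled _MessageType header,
        # then each float32 array is sent as a raw buffer.
        data = (np.arange(6, dtype=np.float32).reshape(2, 3),
                np.ones(4, dtype=np.float32))
        comm.send(data, dest=1, tag=0)
    elif comm.rank == 1:
        # recv reads the header and returns a tuple of numpy arrays
        # because is_tuple is True for tuple messages.
        a, b = comm.recv(source=0, tag=0)
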
68 changes: 45 additions & 23 deletions chainermn/functions/point_to_point_communication.py
@@ -1,3 +1,5 @@
import collections

import chainer
from chainer import cuda
import chainer.utils
@@ -20,23 +22,23 @@ def label(self):

def forward(self, inputs):
xp = cuda.get_array_module(*inputs)
# Note: inputs[1] might contain delegate_variable.
x = inputs[0]
self.comm.send(x, self.peer_rank, self.peer_tag)

if len(inputs) == 1:
inputs = inputs[0]

self.comm.send(inputs, self.peer_rank, self.peer_tag)

# Return an empty variable, which serves as "delegate_variable."
return xp.array([], dtype=xp.float32),

def backward(self, inputs, grad_outputs):
xp = cuda.get_array_module(*inputs)
with cuda.get_device_from_array(*inputs):
gy = self.comm.recv(self.peer_rank, self.peer_tag)
if len(inputs) > 1:
# Dummy grad for delegate_variable.
# This grad will not be used, only for silencing type checker.
grad_delegate_variable = inputs[1]
return xp.array(gy), grad_delegate_variable
grad = self.comm.recv(self.peer_rank, self.peer_tag)
if isinstance(grad, tuple):
return tuple([xp.array(gy) for gy in grad])
else:
return xp.array(gy),
return xp.array(grad),


class Recv(chainer.Function):
@@ -80,24 +82,28 @@ def label(self):
self.peer_rank)

def forward(self, inputs):
x = self.comm.recv(self.peer_rank, self.peer_tag)
data = self.comm.recv(self.peer_rank, self.peer_tag)

if not isinstance(data, tuple):
data = tuple([data])

if isinstance(self.device, int) and self.device >= 0:
return cuda.to_gpu(x, device=self.device),
return tuple([cuda.to_gpu(x, device=self.device) for x in data])
else:
return x,
return data

def backward(self, inputs, grad_outputs):
xp = cuda.get_array_module(*inputs)
gw, = grad_outputs
self.comm.send(gw, self.peer_rank, self.peer_tag)
self.comm.send(grad_outputs, self.peer_rank, self.peer_tag)

# dummy_var is needed to maintain Chainer's constraint.
if inputs == ():
dummy_var = xp.array([], dtype=xp.float32)
dummy_var = tuple([xp.array([], dtype=xp.float32)])
else:
var, = inputs
dummy_var = xp.zeros(var.shape, dtype=xp.float32)
dummy_var = tuple([xp.zeros(x.shape, dtype=xp.float32)
for x in inputs])

return dummy_var,
return dummy_var


def send(x, communicator, rank, tag=0):
@@ -129,12 +135,20 @@ def send(x, communicator, rank, tag=0):
'rank must be different from communicator rank, '
'otherwise deadlock occurs')

delegate_variable = Send(communicator, peer_rank=rank, peer_tag=tag)(x)
if isinstance(x, collections.Iterable):
delegate_variable = Send(
communicator, peer_rank=rank, peer_tag=tag)(*x)
else:
delegate_variable = Send(
communicator, peer_rank=rank, peer_tag=tag)(x)

delegate_variable.name = 'delegate_variable'
return delegate_variable


def recv(communicator, rank, delegate_variable=None, tag=0, device=-1):
def recv(
communicator, rank, delegate_variable=None, tag=0, device=-1,
force_tuple=False):
"""Receive elements from target process.

This function returns data received from target process. If ``backward()``
@@ -155,6 +169,9 @@ def recv(communicator, rank, delegate_variable=None, tag=0, device=-1):
Pointer to the other non-connected component.
tag (int): Optional message ID (MPI feature).
device (int): Target device specifier.
        force_tuple (bool): If ``False`` (the default), a Variable will be
            returned when the number of outputs is one. Otherwise, this
            method returns a tuple even when the number of outputs is one.

Returns:
~chainer.Variable:
@@ -170,15 +187,20 @@ def recv(communicator, rank, delegate_variable=None, tag=0, device=-1):
'otherwise deadlock occurs')

if delegate_variable is None:
return Recv(
res = Recv(
communicator,
peer_rank=rank,
peer_tag=tag,
device=device)()
else:
delegate_variable.name = 'delegate_variable'
return Recv(
res = Recv(
communicator,
peer_rank=rank,
peer_tag=tag,
device=device)(delegate_variable)

if force_tuple and not isinstance(res, tuple):
return tuple([res])
else:
return res
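
A short usage sketch of the function-level API with tuple support, assuming at least two MPI processes; the communicator, variable contents, and shapes below are placeholders, not part of this change::

    import chainer
    import chainermn
    import numpy as np

    comm = chainermn.create_communicator('naive')

    if comm.rank == 0:
        # Two Variables produced by some upstream computation (placeholders).
        y0 = chainer.Variable(np.ones((2, 3), dtype=np.float32))
        y1 = chainer.Variable(np.zeros((4,), dtype=np.float32))
        # Sending a tuple returns a single delegate variable;
        # phi.backward() would later pull gradients back from rank 1.
        phi = chainermn.functions.send((y0, y1), comm, rank=1)
    elif comm.rank == 1:
        # force_tuple=True guarantees a tuple even for a single element.
        ys = chainermn.functions.recv(comm, rank=0, force_tuple=True)
        y0, y1 = ys
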
37 changes: 35 additions & 2 deletions chainermn/links/multi_node_chain_list.py
@@ -19,7 +19,7 @@ class MultiNodeChainList(chainer.ChainList):
is invoked in forward computation according to the order they are added,
and in backward computation according to the reversed order.

.. admonition:: Example
.. admonition:: Example (basic usage)

        This is a simple example of a model that sends its outputs to the
        rank=1 machine::
@@ -50,7 +50,7 @@ def __init__(self, comm, n_in, n_hidden, n_out):
rank_in=None,
rank_out=1)

.. admonition:: Example
.. admonition:: Example (split MLP on 2 processes)

        This is another example, in which two models interact with each other::

@@ -99,6 +99,39 @@ def __init__(self, comm):
to ``Model1``, then ``MLP`` in ``Model1`` will receive it and send
its outputs to the second ``MLP`` in ``Model0``.

.. admonition:: Example (sending tuples)

        This is an example of sending a tuple::

import chainer
import chainer.functions as F
import chainermn

class NN0(chainer.Chain):
def __call__(self, x):
y0 = some_calculation_nn0_0(x)
                    y1 = some_calculation_nn0_1(x)
return y0, y1

class NN1(chainer.Chain):
def __call__(self, y):
y0, y1 = y # unpack tuple from NN0
return some_calculation_nn1(y0, y1)

class Model_on_Process_0(chainermn.MultiNodeChainList):
def __init__(self, comm):
super(Model_on_Process_0, self).__init__(comm=comm)
self.add_link(NN0(), rank_in=None, rank_out=1)

class Model_on_Process_1(chainermn.MultiNodeChainList):
def __init__(self, comm):
super(Model_on_Process_1, self).__init__(comm=comm)
self.add_link(NN1(), rank_in=0, rank_out=None)

        In this example, ``Model_on_Process_0`` sends a two-element tuple
        ``(y0, y1)`` (returned by ``NN0.__call__``) to ``Model_on_Process_1``,
        which can be unpacked as shown in ``NN1.__call__``.

Args:
comm (chainermn.communicators._base.CommunicatorBase):
ChainerMN communicator.
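
For illustration, the two halves of the tuple example above could be instantiated on their respective processes as follows; the ``naive`` communicator and the two-process setup are assumptions::

    comm = chainermn.create_communicator('naive')

    if comm.rank == 0:
        model = Model_on_Process_0(comm)
    elif comm.rank == 1:
        model = Model_on_Process_1(comm)
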
33 changes: 33 additions & 0 deletions tests/functions_tests/test_point_to_point_communication.py
@@ -1,4 +1,5 @@
import copy
import functools
import nose.plugins.skip
import unittest

@@ -130,3 +131,35 @@ def test_retain(self):
err = chainermn.functions.send(
y, self.communicator, self.rank_send)
err.backward()

def check_tuple_communication(self, length):
if self.communicator.rank == 0:
y = []
for i in range(length):
_y = self.f(self.model(self.x))
y.append(_y)
err = chainermn.functions.send(
y, self.communicator, self.rank_send)
err.backward()

elif self.communicator.rank == self.communicator.size - 1:
y = chainermn.functions.recv(
self.communicator, self.rank_recv, device=self.device,
force_tuple=True)
self.assertTrue(isinstance(y, tuple))
z = functools.reduce(lambda x, y: x + y, y)
err = self.evaluation(z, self.x)
err.backward()

else:
y = chainermn.functions.recv(
self.communicator, self.rank_recv, device=self.device)
err = chainermn.functions.send(
y, self.communicator, self.rank_send)
err.backward()

def test_tuple_communication1(self):
self.check_tuple_communication(1)

def test_tuple_communication2(self):
self.check_tuple_communication(2)
25 changes: 25 additions & 0 deletions tests/test_communicator.py
@@ -137,6 +137,31 @@ def test_send_and_recv3(self):
def test_send_and_recv4(self):
self.check_send_and_recv(50, 20, 5, 3)

def check_send_and_recv_tuple(self, data):
if self.communicator.size < 2:
raise nose.plugins.skip.SkipTest()

if self.communicator.rank > 0:
rank_prev = (self.communicator.rank - 1) % self.communicator.size
data_recv = self.communicator.recv(source=rank_prev, tag=0)
for array0, array1 in zip(data, data_recv):
chainer.testing.assert_allclose(array0, array1)

if self.communicator.rank < self.communicator.size - 1:
rank_next = (self.communicator.rank + 1) % self.communicator.size
self.communicator.send(data, dest=rank_next, tag=0)

def test_send_and_recv5(self):
data = [np.ones((50)).astype(np.float32)]
self.check_send_and_recv_tuple(data)

def test_send_and_recv6(self):
data = [
np.ones((50)).astype(np.float32),
np.ones((50, 20)).astype(np.float32),
np.ones((50, 20, 5)).astype(np.float32)]
self.check_send_and_recv_tuple(data)

def check_broadcast_data(self, model):
model.a.W.data[:] = self.communicator.rank
model.b.W.data[:] = self.communicator.rank + 1