Support all-to-all #135

Merged
11 commits merged on Dec 8, 2017
62 changes: 62 additions & 0 deletions chainermn/communicators/_base.py
@@ -1,5 +1,6 @@
import collections

import mpi4py
import numpy

import chainer.cuda
@@ -9,6 +10,11 @@
from chainermn import nccl


def _cnt_to_dsp(cnt):
"""Utility to convert length array to cumulative array."""
return [0] + numpy.cumsum(cnt)[:-1].tolist()


class _MessageType(object):

def __init__(self, obj):
@@ -128,6 +134,62 @@ def recv(self, source, tag):
self.mpi_comm.Recv(buf, source=source, tag=tag)
return buf.reshape(shape)

def alltoall(self, xs):
"""A primitive of inter-process all-to-all function.

This method tries to invoke all-to-all communication within the
communicator. All processes in the communicator are expected to
invoke ``alltoall()``. This method relies on mpi4py fast communication
optimized for numpy arrays, as well as ``send()`` and ``recv()``.

Args:
xs (tuple of numpy.ndarray)

Returns:
ys (tuple of numpy.ndarray):
Received arrays. The length of tuple equals to
the communicator size.
"""
chainer.utils.experimental(
'chainermn.communicators.CommunicatorBase.all_to_all')

if len(xs) != self.size:
raise ValueError(
'The length of xs must be the same as the communicator size.')

# Exchange the number of dimensions (ndim) of each array.
sndims = numpy.array([x.ndim for x in xs], dtype=numpy.int32)
rndims = numpy.empty(self.size, dtype=numpy.int32)
self.mpi_comm.Alltoall(
[sndims, mpi4py.MPI.INT],
[rndims, mpi4py.MPI.INT])

# Exchange the shapes of the arrays.
sshapes = numpy.hstack([x.shape for x in xs]).astype(numpy.int32)
rshapes = numpy.empty(sum(rndims), dtype=numpy.int32)
self.mpi_comm.Alltoallv(
[sshapes, (sndims, _cnt_to_dsp(sndims)), mpi4py.MPI.INT],
[rshapes, (rndims, _cnt_to_dsp(rndims)), mpi4py.MPI.INT])
shapes = [rshapes[i:i + l]
for i, l in zip(_cnt_to_dsp(rndims), rndims)]

# Exchange the array bodies, flattened into float32 buffers.
slens = [numpy.prod(x.shape) for x in xs]
xp = chainer.cuda.get_array_module(xs[0])
sbuf = xp.hstack([x.reshape(-1) for x in xs])
rlens = [numpy.prod(s) for s in shapes]
rbuf = numpy.empty(sum(rlens), dtype=numpy.float32)
if xp is not numpy:
sbuf = _memory_utility.array_to_buffer_object(sbuf)[0]
chainer.cuda.Stream.null.synchronize()
self.mpi_comm.Alltoallv(
[sbuf, (slens, _cnt_to_dsp(slens)), mpi4py.MPI.FLOAT],
[rbuf, (rlens, _cnt_to_dsp(rlens)), mpi4py.MPI.FLOAT])
ys = [rbuf[i:i + l].reshape(s)
for i, l, s in zip(_cnt_to_dsp(rlens), rlens, shapes)]

return tuple(ys)

def broadcast_data(self, model):
raise NotImplementedError()

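For reference, a minimal usage sketch of the new alltoall primitive (not part of the diff). The 'naive' communicator, the mpiexec launch, and the (2, 3) float32 arrays are illustrative assumptions; float32 matches the MPI.FLOAT buffers used in the implementation above.

# usage_sketch.py (hypothetical): run with e.g. `mpiexec -n 4 python usage_sketch.py`
import numpy

import chainermn

comm = chainermn.create_communicator('naive')

# Rank r prepares one float32 array per destination rank; the tuple
# length must equal comm.size, as enforced by alltoall().
xs = tuple(numpy.full((2, 3), comm.rank, dtype=numpy.float32)
           for _ in range(comm.size))

# ys[j] is the array that rank j sent to this rank.
ys = comm.alltoall(xs)
assert len(ys) == comm.size
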
1 change: 1 addition & 0 deletions chainermn/functions/__init__.py
@@ -1,3 +1,4 @@
from chainermn.functions.collective_communication import all_to_all # NOQA
from chainermn.functions.point_to_point_communication import recv # NOQA
from chainermn.functions.point_to_point_communication import send # NOQA
from chainermn.functions.pseudo_connect import pseudo_connect # NOQA
63 changes: 63 additions & 0 deletions chainermn/functions/collective_communication.py
@@ -0,0 +1,63 @@
import chainer
from chainer import cuda


class Alltoall(chainer.Function):
"""Collective all-to-all communication."""

def __init__(self, comm, device):
chainer.utils.experimental('chainermn.functions.Alltoall')
self.comm = comm
self.device = device

def forward(self, inputs):
if len(inputs) != self.comm.size:
raise ValueError(
'The length of inputs must be the same as the communicator size.')

xs = tuple([x for x in inputs])
ys = self.comm.alltoall(xs)

if isinstance(self.device, int) and self.device >= 0:
ys = tuple([cuda.to_gpu(y, device=self.device) for y in ys])

return ys

def backward(self, inputs, grad_outputs):
assert self.comm.size == len(grad_outputs)

xp = cuda.get_array_module(*inputs)
with cuda.get_device_from_array(*inputs):
gys = tuple([gy for gy in grad_outputs])
gx = self.comm.alltoall(gys)
gx = [xp.array(_gx) for _gx in gx]
return tuple(gx)


def all_to_all(comm, xs, device=-1):
"""Differentiable all-to-all communication between workers.

This function invokes all-to-all communication among the processes
specified by the communicator. Backward is invoked just like for
ordinary chainer functions, routing the output gradients back through
another all-to-all exchange.
Unlike point-to-point communication such as ``chainermn.functions.send``
and ``chainermn.functions.recv``, users do not need to care about
delegate variables, since ``backward()`` is not invoked until
all gradients from the output side have arrived.
Please refer to ``chainermn.functions.pseudo_connect`` for details on
delegate variables.

Args:
comm: ChainerMN communicator.
xs (list of chainer.Variable): Variables to send.
device (int): Target device specifier.

Returns:
ys (tuple of chainer.Variable): Received variables.
"""

if len(xs) != comm.size:
raise ValueError('The length of xs must be the same as the communicator size.')

return Alltoall(comm, device)(*xs)
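
A minimal sketch of calling the differentiable wrapper from user code (again not part of the diff); the 'naive' communicator and the (2, 2) arrays are assumptions, and the loss construction mirrors the test added below.

import chainer
import chainer.functions as F
import numpy

import chainermn
import chainermn.functions

comm = chainermn.create_communicator('naive')

# One Variable per destination rank; len(xs) must equal comm.size.
xs = [chainer.Variable(numpy.full((2, 2), comm.rank, dtype=numpy.float32))
      for _ in range(comm.size)]

ys = chainermn.functions.all_to_all(comm, xs)

# Gradients flow back to xs through another all-to-all in Alltoall.backward().
loss = F.sum(ys[0])
for y in ys[1:]:
    loss += F.sum(y)
loss.backward()
assert xs[0].grad is not None
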
1 change: 1 addition & 0 deletions docs/source/reference/index.rst
@@ -38,3 +38,4 @@ Functions
.. autofunction:: chainermn.functions.send
.. autofunction:: chainermn.functions.recv
.. autofunction:: chainermn.functions.pseudo_connect
.. autofunction:: chainermn.functions.all_to_all
62 changes: 62 additions & 0 deletions tests/functions_tests/test_collective_communication.py
@@ -0,0 +1,62 @@
import unittest

import nose.plugins.skip

import chainer
import chainer.testing
import chainer.testing.attr
import numpy

import chainermn
import chainermn.functions


@chainer.testing.parameterize(
{'gpu': True},
{'gpu': False},
)
class TestCollectiveCommunication(unittest.TestCase):

def setUp(self):
if self.gpu:
self.communicator = chainermn.create_communicator('hierarchical')
device = self.communicator.intra_rank
chainer.cuda.get_device(device).use()
else:
self.communicator = chainermn.create_communicator('naive')
device = -1

if self.communicator.size < 2:
raise nose.plugins.skip.SkipTest()

self.device = device

def check_all_to_all(self, xs):
ys = chainermn.functions.all_to_all(self.communicator, xs, self.device)

y = chainer.functions.sum(ys[0])
for _y in ys[1:]:
y += chainer.functions.sum(_y)

y.backward()

self.assertIsNotNone(xs[0].grad)

def test_all_to_all_cpu(self):
data = [
chainer.Variable(numpy.zeros(
(self.communicator.rank, i), dtype=numpy.float32))
for i in range(self.communicator.size)]
self.check_all_to_all(data)

@chainer.testing.attr.gpu
def test_all_to_all_gpu(self):
if not self.gpu:
raise nose.plugins.skip.SkipTest()
chainer.cuda.get_device_from_id(self.device).use()
data = [
chainer.Variable(numpy.zeros(
(self.communicator.rank, i), dtype=numpy.float32))
for i in range(self.communicator.size)]
for x in data:
x.to_gpu()
self.check_all_to_all(data)
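
For reference, a small shape-bookkeeping sketch (not part of the test) of what the data above exchange when run on two processes: rank r sends an array of shape (r, j) to rank j, so rank j receives an array of shape (r, j) from each rank r.

# Shapes exchanged by the test data above for size == 2 (illustrative only).
size = 2
for rank in range(size):
    sends = [(rank, i) for i in range(size)]  # shapes rank `rank` sends out
    recvs = [(r, rank) for r in range(size)]  # shapes rank `rank` gets back
    print(rank, sends, recvs)
# 0 [(0, 0), (0, 1)] [(0, 0), (1, 0)]
# 1 [(1, 0), (1, 1)] [(0, 1), (1, 1)]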