diff --git a/chainermn/functions/point_to_point_communication.py b/chainermn/functions/point_to_point_communication.py
index 24c1a3f3..854fc868 100644
--- a/chainermn/functions/point_to_point_communication.py
+++ b/chainermn/functions/point_to_point_communication.py
@@ -1,5 +1,3 @@
-import collections
-
 import chainer
 from chainer import cuda
 import chainer.utils
@@ -23,22 +21,27 @@ def label(self):
     def forward(self, inputs):
         xp = cuda.get_array_module(*inputs)
 
-        if len(inputs) == 1:
-            inputs = inputs[0]
+        # The last input is dummy variable, to retain gradient computation
+        # of this function.
+        xs = inputs[:-1]
+
+        if len(xs) == 1:
+            xs = xs[0]
 
-        self.comm.send(inputs, self.peer_rank, self.peer_tag)
+        self.comm.send(xs, self.peer_rank, self.peer_tag)
 
         # Return an empty variable, which serves as "delegate_variable."
         return xp.array([], dtype=xp.float32),
 
     def backward(self, inputs, grad_outputs):
         xp = cuda.get_array_module(*inputs)
+        dummy_grad = xp.array([], dtype=xp.float32)
         with cuda.get_device_from_array(*inputs):
             grad = self.comm.recv(self.peer_rank, self.peer_tag)
             if isinstance(grad, tuple):
-                return tuple([xp.array(gy) for gy in grad])
+                return tuple([xp.array(gy) for gy in grad] + [dummy_grad])
             else:
-                return xp.array(grad),
+                return xp.array(grad), dummy_grad
 
 
 class Recv(chainer.Function):
@@ -135,12 +138,20 @@ def send(x, communicator, rank, tag=0):
             'rank must be different from communicator rank, '
             'otherwise deadlock occurs')
 
-    if isinstance(x, collections.Iterable):
+    xp = cuda.get_array_module(*x)
+
+    # Dummy variable to retain gradient computation of send,
+    # otherwise the corresponding recv will cause deadlock in backward
+    # in the case where all inputs for this function does not require_grad.
+    dummy_var = chainer.Variable(xp.array([], dtype=xp.float32))
+
+    if isinstance(x, list) or isinstance(x, tuple):
+        inputs = x + type(x)([dummy_var])
         delegate_variable = Send(
-            communicator, peer_rank=rank, peer_tag=tag)(*x)
+            communicator, peer_rank=rank, peer_tag=tag)(*inputs)
     else:
         delegate_variable = Send(
-            communicator, peer_rank=rank, peer_tag=tag)(x)
+            communicator, peer_rank=rank, peer_tag=tag)(x, dummy_var)
 
     delegate_variable.name = 'delegate_variable'
     return delegate_variable
diff --git a/tests/chainermn_tests/functions_tests/test_point_to_point_communication.py b/tests/chainermn_tests/functions_tests/test_point_to_point_communication.py
index 15b57474..b31f2210 100644
--- a/tests/chainermn_tests/functions_tests/test_point_to_point_communication.py
+++ b/tests/chainermn_tests/functions_tests/test_point_to_point_communication.py
@@ -1,8 +1,10 @@
 import copy
 import functools
+import unittest
 
 import chainer
 import chainer.testing
+import chainer.testing.attr
 import numpy
 import pytest
 
@@ -10,9 +12,9 @@
 import chainermn.functions
 
 
-class PointToPointCommunication(object):
+class TestPointToPointCommunication(unittest.TestCase):
 
-    def __init__(self, gpu):
+    def setup(self, gpu):
         self.gpu = gpu
         if self.gpu:
             self.communicator = chainermn.create_communicator('hierarchical')
@@ -55,7 +57,7 @@ def _init_w(self, l):
         return 1.0 * numpy.arange(100).reshape(10, 10).astype(numpy.float32) \
             / ((l + 1) * 100)
 
-    def test_communication(self):
+    def check_communication(self):
         if self.communicator.rank == 0:
             # Input process.
             y = self.f(self.model(self.x))
@@ -99,7 +101,16 @@ def test_communication(self):
                 y, self.communicator, self.rank_send)
             err.backward()
 
-    def test_retain(self):
+    def test_communication_cpu(self):
+        self.setup(False)
+        self.check_communication()
+
+    @chainer.testing.attr.gpu
+    def test_communication_gpu(self):
+        self.setup(True)
+        self.check_communication()
+
+    def check_retain(self):
         if self.communicator.rank == 0:
             # Starting process.
             t = copy.copy(self.x)
@@ -127,6 +138,15 @@ def test_retain(self):
                 y, self.communicator, self.rank_send)
             err.backward()
 
+    def test_retain_cpu(self):
+        self.setup(False)
+        self.check_retain()
+
+    @chainer.testing.attr.gpu
+    def test_retain_gpu(self):
+        self.setup(True)
+        self.check_retain()
+
     def check_tuple_communication(self, length):
         if self.communicator.rank == 0:
             y = []
@@ -153,25 +173,66 @@ def check_tuple_communication(self, length):
                 y, self.communicator, self.rank_send)
             err.backward()
 
-    def test_tuple_communication1(self):
+    def test_tuple_communication1_cpu(self):
+        self.setup(False)
         self.check_tuple_communication(1)
 
-    def test_tuple_communication2(self):
+    def test_tuple_communication2_cpu(self):
+        self.setup(False)
         self.check_tuple_communication(2)
 
+    @chainer.testing.attr.gpu
+    def test_tuple_communication1_gpu(self):
+        self.setup(True)
+        self.check_tuple_communication(1)
+
+    @chainer.testing.attr.gpu
+    def test_tuple_communication2_gpu(self):
+        self.setup(True)
+        self.check_tuple_communication(2)
+
+
+class TestNonVariableInput(unittest.TestCase):
+
+    def setUp(self):
+        self.communicator = chainermn.create_communicator('naive')
+
+        if self.communicator.size < 2:
+            pytest.skip("This test is for multinode")
+
+        self.rank_send = (self.communicator.rank + 1) % self.communicator.size
+        self.rank_recv = (self.communicator.rank - 1) % self.communicator.size
 
-def test_cpu():
-    p2pcom = PointToPointCommunication(False)
-    p2pcom.test_communication()
-    p2pcom.test_retain()
-    p2pcom.test_tuple_communication1()
-    p2pcom.test_tuple_communication2()
+    def test_non_variable_send(self):
+        """Checks if backward will be called even if inputs are not Variable.
 
+        This test confirms whether deadlock occurs when numpy/cupy array is
+        given as an input of send.
+        In this case, the input will be converted to chainer Variable without
+        ``requires_grad``, thus ``backward`` will not be called without any
+        modification.
+        """
+        if self.communicator.rank == 0:
+            x = numpy.ones((1, 10)).astype(numpy.float32)
+            phi = chainermn.functions.send(
+                x, self.communicator, rank=self.rank_send)
+            x = chainermn.functions.pseudo_connect(phi, x)
+            y = chainer.functions.sum(x)
+            t = numpy.array(0).astype(numpy.float32)
+            z = chainer.functions.mean_squared_error(y, t)
+            z.backward()
 
-@chainer.testing.attr.gpu
-def test_gpu():
-    p2pcom = PointToPointCommunication(True)
-    p2pcom.test_communication()
-    p2pcom.test_retain()
-    p2pcom.test_tuple_communication1()
-    p2pcom.test_tuple_communication2()
+        elif self.communicator.rank == self.communicator.size - 1:
+            x = chainermn.functions.recv(
+                self.communicator, rank=self.rank_recv)
+            y = chainer.functions.sum(x)
+            t = numpy.array(0).astype(numpy.float32)
+            z = chainer.functions.mean_squared_error(y, t)
+            z.backward()
+
+        else:
+            x = chainermn.functions.recv(
+                self.communicator, rank=self.rank_recv)
+            phi = chainermn.functions.send(
+                x, self.communicator, rank=self.rank_send)
+            phi.backward()
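
Note (not part of the patch): a minimal usage sketch of the scenario the new TestNonVariableInput test exercises, assuming an MPI launch with at least two processes. Rank 0 sends a raw NumPy array, so none of send's real inputs require gradients; the dummy variable introduced by this patch keeps Send.backward reachable, and the peer's Recv.backward no longer deadlocks waiting for a gradient.

import chainer
import chainermn
import chainermn.functions
import numpy

comm = chainermn.create_communicator('naive')

if comm.rank == 0:
    # A plain numpy array: it is wrapped into a Variable without requires_grad.
    x = numpy.ones((1, 10), dtype=numpy.float32)
    phi = chainermn.functions.send(x, comm, rank=1)
    # pseudo_connect ties the delegate variable into the local graph.
    y = chainermn.functions.pseudo_connect(phi, x)
    loss = chainer.functions.sum(y)
    loss.backward()  # reaches Send.backward, which receives the peer's gradient
elif comm.rank == 1:
    x = chainermn.functions.recv(comm, rank=0)
    loss = chainer.functions.sum(x)
    loss.backward()  # Recv.backward sends the gradient back to rank 0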