MultiNodeChainList with self branching #102
Changes from 7 commits
@@ -77,8 +77,14 @@ def backward(self, inputs, grad_outputs):
         xp = cuda.get_array_module(*inputs)
         gw, = grad_outputs
         self.comm.send(gw, self.peer_rank, self.peer_tag)
-        dummy_var = xp.array([[]], dtype=xp.float32)
-        return dummy_var
+
+        if inputs == ():
+            dummy_var = xp.array([], dtype=xp.float32)
+        else:
+            var, = inputs
+            dummy_var = xp.zeros(var.shape, dtype=xp.float32)
+
+        return dummy_var,


 def send(x, communicator, rank, tag=0):
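The old code returned a fixed (1, 0)-shaped array regardless of the input; the new code returns an empty gradient when Send has no input, and otherwise a zero gradient matching the input's shape, so Chainer's gradient shape checks pass during backprop. A minimal NumPy-only sketch of that rule (the function name is hypothetical):

    import numpy as np

    def dummy_grad(inputs):
        # No inputs (Send is a graph leaf): return an empty float32 gradient.
        if inputs == ():
            return np.array([], dtype=np.float32),
        # Otherwise: a zero gradient with exactly the input's shape.
        var, = inputs
        return np.zeros(var.shape, dtype=np.float32),

    print(dummy_grad(()))                  # (array([], dtype=float32),)
    print(dummy_grad((np.ones((2, 3)),)))  # (zeros of shape (2, 3),)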
@@ -104,6 +110,7 @@ def send(x, communicator, rank, tag=0):

     """
     chainer.utils.experimental('chainermn.functions.send')
+    assert rank != communicator.rank
     return Send(communicator, peer_rank=rank, peer_tag=tag)(x)
@@ -136,6 +143,7 @@ def recv(communicator, rank, delegate_variable=None, tag=0, device=-1):

     """
     chainer.utils.experimental('chainermn.functions.recv')
+    assert rank != communicator.rank
Reviewer: assertion comment (see the related thread at the end of this diff)
     if delegate_variable is None:
         return Recv(
             communicator,
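For context, a hedged usage sketch of the two functions with the new guard. It assumes an MPI launch (e.g. mpiexec -n 2); 'naive' is chainermn's built-in CPU communicator, and the send/recv signatures follow the code above:

    import numpy as np
    import chainer
    import chainermn

    comm = chainermn.create_communicator('naive')

    if comm.rank == 0:
        x = chainer.Variable(np.arange(6, dtype=np.float32).reshape(2, 3))
        # rank must differ from comm.rank, otherwise the new assert fires;
        # self-communication now goes through MultiNodeChainList's queue.
        phi = chainermn.functions.send(x, comm, rank=1)
    elif comm.rank == 1:
        y = chainermn.functions.recv(comm, rank=0)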
@@ -1,3 +1,5 @@
+from six.moves import queue
+
 import chainer
 import chainermn
 import chainermn.communicators
@@ -126,14 +128,10 @@ def add_link(self, link, rank_in=None, rank_out=None):
         if isinstance(rank_out, int):
             rank_out = [rank_out]

-        assert rank_in is None or self._comm.rank not in rank_in,\
-            "cannot specify self rank for rank_in"
-        assert rank_out is None or self._comm.rank not in rank_out,\
-            "cannot specify self rank for rank_out"
-
         self._rank_inouts.append((rank_in, rank_out))

     def __call__(self, *inputs):
+        comm_queue = queue.Queue()
Reviewer: How about checking …
Author: Thank you. I fixed it.
+
         y = None
         delegate_variable = None
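With the add_link asserts removed, a component may route its output back to the local rank. A hedged sketch of a self-branching model definition on rank 0 (the Sub chain is hypothetical; the constructor and add_link signatures follow the code above):

    import chainer
    import chainer.functions as F
    import chainer.links as L
    import chainermn

    class Sub(chainer.Chain):
        # Tiny hypothetical component.
        def __init__(self):
            super(Sub, self).__init__()
            with self.init_scope():
                self.l1 = L.Linear(None, 10)

        def __call__(self, *xs):
            return F.relu(self.l1(sum(xs)))

    class Rank0Model(chainermn.MultiNodeChainList):
        def __init__(self, comm):
            super(Rank0Model, self).__init__(comm=comm)
            # The first component sends its output to rank 1 and also back
            # to rank 0 itself, which the removed asserts used to reject.
            self.add_link(Sub(), rank_in=None, rank_out=[0, 1])
            # The second component consumes the self-branch together with
            # the result computed on rank 1.
            self.add_link(Sub(), rank_in=[0, 1], rank_out=None)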
@@ -157,23 +155,30 @@ def __call__(self, *inputs):

             # Preprocess: receiving inputs from the other machines.
             xs = []
             for _rank_in in rank_in:
-                _x = chainermn.functions.recv(
-                    self._comm,
-                    rank=_rank_in,
-                    delegate_variable=delegate_variable,
-                    device=self._device_id)
+                if _rank_in == self._comm.rank:
+                    # Receive inputs from itself.
+                    if delegate_variable is None:
+                        _x = comm_queue.get()
+                    else:
+                        _x = chainermn.functions.pseudo_connect(
+                            delegate_variable,
+                            comm_queue.get())
+                else:
+                    _x = chainermn.functions.recv(
+                        self._comm,
+                        rank=_rank_in,
+                        delegate_variable=delegate_variable,
+                        device=self._device_id)
+
+                xs.append(_x)

                 # Guarantee the backward path to the previous graph
                 # component to be executed in the last to avoid dead-lock.
                 if delegate_variable is not None \
                         and _x.creator is not None:
                     _x.creator.rank = -1

-                xs.append(_x)
                 delegate_variable = _x

             # Prevent "double-backwarding," i.e., backprop
             # the same edge more than twice.
Reviewer: More than once?
Author: Exactly. Don't we say "more than twice" in this case?
Reviewer: For a ratio there is no big difference between "> 2x" and ">= 2x", so both can be translated into the Japanese word "2倍以上". But in this case, I don't think that applies.
Author: Thank you for the suggestion. Fixed it.
Reviewer: 👍
             delegate_variable = None

             # Actual forward.
             x = f(*tuple(xs))
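Within a single process the queue simply hands an earlier component's output to a later one in FIFO order, with no MPI involved. A minimal standalone sketch of that mechanism:

    from six.moves import queue

    comm_queue = queue.Queue()

    # An earlier component "sends" its output to the local rank.
    comm_queue.put('activation from component A')

    # A later component "receives" it instead of calling recv().
    x = comm_queue.get()
    print(x)  # -> activation from component A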
@@ -186,17 +191,26 @@ def __call__(self, *inputs):

             else:  # Send outputs to the other machines.
                 for i_comp, _rank_out in enumerate(rank_out):
-                    if i_comp == 0:
+                    if _rank_out == self._comm.rank:
+                        # Send outputs to itself.
+                        if delegate_variable is not None:
+                            x = chainermn.functions.pseudo_connect(
+                                delegate_variable,
+                                x)
+                        comm_queue.put(x)
+                        delegate_variable = x
+                    elif i_comp == 0:
                         delegate_variable = chainermn.functions.send(
                             x, self._comm,
                             rank=_rank_out)
                     else:
                         # If the model has multiple targets for send,
                         # we must guarantee backwards of each send to be
                         # called in the reversed order.
-                        x = chainermn.functions.pseudo_connect(
-                            delegate_variable,
-                            x)
+                        if delegate_variable is not None:
+                            x = chainermn.functions.pseudo_connect(
+                                delegate_variable,
+                                x)
                         delegate_variable = chainermn.functions.send(
                             x, self._comm,
                             rank=_rank_out)
Reviewer: Maybe we can add an assertion comment like "Cannot send to the local process itself" or something.
Reviewer: (Or should it be an internal error and should not happen?)
Author: Fixed (use ValueError instead).
Reviewer: 👍