Fix broadcast root during the replication call #3655
Conversation
@@ -102,8 +104,20 @@ def create(cls, data, client=None, batch_enabled=False):
        else:
            raise TypeError("Graph data must be dask-cudf dataframe")

        broadcast_worker = None
        if batch_enabled:
            worker_ranks = client.run(_get_nvml_device_index)
I think this will fail on MNMG setups. The worker_rank here gets the nvml_device_index, which will be zero for two devices (the first device on each node). I think we should figure out a way to get this from raft instead.
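To make the concern concrete, here is a hypothetical illustration (the cluster layout and worker addresses are invented, not taken from the PR): on a two-node cluster with one GPU per node, every worker sees its local device as NVML index 0, so the mapping returned by client.run(_get_nvml_device_index) cannot single out a unique root.

# Hypothetical illustration of the MNMG concern (addresses are invented).
# client.run(fn) returns {worker_address: fn's return value} for every worker.
worker_ranks = {
    "tcp://10.0.0.1:40001": 0,  # only visible device on node 1 -> NVML index 0
    "tcp://10.0.0.2:40001": 0,  # only visible device on node 2 -> NVML index 0
}

# "The worker with index 0" now matches two workers, so the chosen broadcast
# root may not coincide with raft's rank-0 worker.
candidates = [addr for addr, idx in worker_ranks.items() if idx == 0]
assert len(candidates) == 2  # ambiguous on this hypothetical layout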
How about we directly fetch from the upstream raft_dask library to prevent breaking in the future too?

Suggested change:
-            worker_ranks = client.run(_get_nvml_device_index)
+            from raft_dask.common.comms import _func_worker_ranks
+            worker_ranks = _func_worker_ranks(client)
That PR is not merged yet. That is why I added a FIXME in line 286
> That PR is not merged yet. That is why I added a FIXME in line 286

No need to wait for your PR to be merged, because _func_worker_ranks already exists. Its implementation will just change once your PR is merged.
> I think this will fail on MNMG setups

I don't think this can fail, because the address of rank 0 in this PR always matches the raft one. In fact, client.run(_get_nvml_device_index) is deterministic and always returns the results in increasing order of network ID or IP address (the keys of the dictionary). On an MNMG run, even though there are multiple pairs with rank 0, this PR will always pick the one on the first node, which is consistent with the raft one. This is true because the raft PR only applies the rank offsets to the second node and above. Furthermore, I extensively tested this PR on Friday on both 2 and 4 nodes, and all runs passed.

> How about we directly fetch from the upstream raft_dask library to prevent breaking in the future too?

Right, and this is what I had in mind when I added the FIXME. This will avoid code duplication and prevent breaking changes.
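For readers following the thread, here is a minimal sketch of the root selection being described, assuming a Dask `client` and the PR's `_get_nvml_device_index` helper are already in scope; the actual cuGraph replication code may differ in detail.

# Sketch of the root-selection logic described above (illustrative only).
# `client` and `_get_nvml_device_index` are assumed to exist as in the PR.
worker_ranks = client.run(_get_nvml_device_index)  # {worker_address: index}

# Per the discussion above, client.run() returns workers in a deterministic,
# address-ordered mapping, so the first entry whose index is 0 is the worker
# on the first node, matching raft's rank-0 worker.
broadcast_worker = next(
    addr for addr, rank in worker_ranks.items() if rank == 0
)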
> On an MNMG run, even though there are multiple pairs with rank 0, this PR will always pick the one on the first node, which is consistent with the raft one

Does this always hold true? Maybe I am missing the logic here.
As we discussed, getting the ranks that were established after the comms initialization guarantees consistency.
def _get_nvml_device_index():
    """
    Return NVML device index based on environment variable
    'CUDA_VISIBLE_DEVICES'.
    """
    # FIXME: Leverage the one from raft instead.
    CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES")
    return nvml_device_index(0, CUDA_VISIBLE_DEVICES)
I think you can remove it now.

Suggested change: delete the _get_nvml_device_index helper above.
Ya, I was pushing this in another commit just now.
import os
from dask_cuda.utils import nvml_device_index
You can remove this now.
Suggested change: delete these two import lines.
Approved
LGTM
A raft PR (rapidsai/raft#1573) assigning deterministic ranks to dask workers was merged, breaking batch algorithms like batch_edge_betweenness_centrality by picking the wrong worker as the root for the broadcast operation.
This PR ensures that the worker with rank = 0 is the root of the broadcast operation.