
Fix broadcast root during the replication call #3655

Merged

Conversation

jnke2016 (Contributor)

A Raft PR, rapidsai/raft#1573, assigning deterministic ranks to Dask workers was merged, breaking batch algorithms like batch_edge_betweenness_centrality by picking the wrong worker as the root for the broadcast operation.
This PR ensures that the worker with rank = 0 is the root of the broadcast operation.
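
For illustration, a minimal sketch of what the fix amounts to, with hypothetical worker addresses rather than the exact cuGraph code:

# Pick the broadcast root from the rank mapping rather than from worker-list
# order, so the worker holding rank 0 is always chosen (addresses are made up).
worker_ranks = {"tcp://node-a:40001": 1, "tcp://node-a:40002": 0}
broadcast_worker = next(addr for addr, rank in worker_ranks.items() if rank == 0)
assert broadcast_worker == "tcp://node-a:40002"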

@jnke2016 jnke2016 requested a review from a team as a code owner June 10, 2023 00:39
@rlratzel rlratzel added bug Something isn't working non-breaking Non-breaking change labels Jun 10, 2023
@rlratzel rlratzel requested a review from VibhuJawa June 10, 2023 03:51
@@ -102,8 +104,20 @@ def create(cls, data, client=None, batch_enabled=False):
        else:
            raise TypeError("Graph data must be dask-cudf dataframe")

        broadcast_worker = None
        if batch_enabled:
            worker_ranks = client.run(_get_nvml_device_index)
Member

I think this will fail on MNMG setups. The worker_rank here gets the nvml_device_index, which will be zero for two devices (the first device on both nodes).

I think we should figure out a way to get this from raft instead.
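
For concreteness, a hypothetical two-node, two-GPU-per-node layout (addresses invented) shows why the per-node NVML index is ambiguous:

# Each node sets its own CUDA_VISIBLE_DEVICES, so the first GPU on every node
# reports NVML index 0 and more than one worker looks like "rank 0".
nvml_indices = {
    "tcp://node-a:40001": 0,  # first GPU on node-a
    "tcp://node-a:40002": 1,
    "tcp://node-b:40001": 0,  # first GPU on node-b also reports 0
    "tcp://node-b:40002": 1,
}
rank_zero_candidates = [w for w, i in nvml_indices.items() if i == 0]
assert len(rank_zero_candidates) == 2  # two workers claim index 0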

Member

How about we directly fetch from the upstream raft_dask library to prevent breaking in the future too?

Suggested change
-            worker_ranks = client.run(_get_nvml_device_index)
+            from raft_dask.common.comms import _func_worker_ranks
+            worker_ranks = _func_worker_ranks(client)

Contributor Author

That PR is not merged yet. That is why I added a FIXME on line 286.

@jnke2016 (Contributor, Author) Jun 12, 2023

> That PR is not merged yet. That is why I added a FIXME on line 286.

No need to wait for your PR to be merged because _func_worker_ranks already exists. Its implementation will just change once your PR is merged.

> I think this will fail on MNMG setups

I don't think this can fail because the address of rank 0 in this PR always matches the raft one. In fact, client.run(_get_nvml_device_index) is deterministic and always returns the result in order of increasing network ID or IP address (the keys of the dictionary). On an MNMG run, even though there are multiple pairs with rank 0, this PR will always pick the one on the first node, which is consistent with the raft one. This is true because the raft PR only applies the rank offsets to the second node and above. Furthermore, I extensively tested this PR on Friday on both 2 and 4 nodes and all runs passed.

> How about we directly fetch from the upstream raft_dask library to prevent breaking in the future too?

Right, and this is what I had in mind when I added the FIXME. This will avoid code duplication and prevent breaking changes.
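
For illustration, a minimal sketch of that argument with hypothetical addresses (the raft PR offsets ranks only on the second node and above):

# What client.run(_get_nvml_device_index) might return on a 2-node, 2-GPU run,
# keyed by worker address in increasing address order (addresses are made up).
local_indices = {
    "tcp://10.0.0.1:40001": 0,  # node 1, first GPU  -> raft rank 0
    "tcp://10.0.0.1:40002": 1,  # node 1, second GPU -> raft rank 1
    "tcp://10.0.0.2:40001": 0,  # node 2, first GPU  -> raft rank 2 (offset applied)
    "tcp://10.0.0.2:40002": 1,  # node 2, second GPU -> raft rank 3
}
# Taking the first worker whose local index is 0 therefore matches raft's rank 0.
broadcast_worker = next(w for w, i in sorted(local_indices.items()) if i == 0)
assert broadcast_worker == "tcp://10.0.0.1:40001"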

@VibhuJawa (Member) Jun 12, 2023

> On an MNMG run, even though there are multiple pairs with rank 0, this PR will always pick the one on the first node, which is consistent with the raft one

Does this always hold true? Maybe I am missing the logic here.

Contributor Author

As we discussed, getting the ranks that were established after the comms initialization guarantees consistency.
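
For illustration, a hedged sketch of that idea using the standard cuGraph MG setup; the rank lookup via _func_worker_ranks mirrors the suggestion above and is an assumption, not necessarily the merged code:

from dask.distributed import Client
from dask_cuda import LocalCUDACluster
import cugraph.dask.comms.comms as Comms
from raft_dask.common.comms import _func_worker_ranks  # helper suggested above

# Stand up a CUDA cluster and initialize the raft comms; the ranks fetched
# afterwards are, by construction, the ones the communicator itself uses.
cluster = LocalCUDACluster()
client = Client(cluster)
Comms.initialize(p2p=True)

worker_ranks = _func_worker_ranks(client)  # {worker_address: raft rank}
# The broadcast root is the worker holding rank 0 in that same mapping.
broadcast_worker = min(worker_ranks, key=worker_ranks.get)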

@BradReesWork BradReesWork added this to the 23.06 milestone Jun 12, 2023
Comment on lines 280 to 287
def _get_nvml_device_index():
    """
    Return NVML device index based on environment variable
    'CUDA_VISIBLE_DEVICES'.
    """
    # FIXME: Leverage the one from raft instead.
    CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES")
    return nvml_device_index(0, CUDA_VISIBLE_DEVICES)
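
For reference, a hedged example of what this helper resolves to when a worker is pinned to GPUs 2 and 3, assuming dask_cuda's nvml_device_index maps the i-th visible device to its NVML index:

from dask_cuda.utils import nvml_device_index

# With CUDA_VISIBLE_DEVICES="2,3", the first visible device is NVML device 2.
device_index = nvml_device_index(0, "2,3")  # expected to be 2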
Member

I think you can remove it now

Suggested change
-def _get_nvml_device_index():
-    """
-    Return NVML device index based on environment variable
-    'CUDA_VISIBLE_DEVICES'.
-    """
-    # FIXME: Leverage the one from raft instead.
-    CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES")
-    return nvml_device_index(0, CUDA_VISIBLE_DEVICES)

Contributor Author

Ya, I am pushing this now in another commit.

Comment on lines 20 to 21
import os
from dask_cuda.utils import nvml_device_index
Member

You can remove this now.

Suggested change
-import os
-from dask_cuda.utils import nvml_device_index

Member

Approved

@VibhuJawa (Member) left a comment

LGTM

@raydouglass raydouglass merged commit 653bbd5 into rapidsai:branch-23.06 Jun 13, 2023
Labels
bug (Something isn't working), non-breaking (Non-breaking change)
5 participants