Skip to content

Commit

Permalink
Call New replicate_edgelist Function (#4441)
Browse files Browse the repository at this point in the history
Closes #4440

This PR updates `enable_batch` to use the updated implementation for `replicate_edgelist`.

Authors:
  - Ralph Liu (https://github.com/nv-rliu)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)
  - Joseph Nke (https://github.com/jnke2016)

URL: #4441
  • Loading branch information
nv-rliu authored May 24, 2024
1 parent e6c842f commit 6c333d3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.

from cugraph.structure import graph_primtypes_wrapper
from cugraph.structure.replicate_edgelist import replicate_cudf_dataframe
from cugraph.structure.symmetrize import symmetrize
from cugraph.structure.number_map import NumberMap
import cugraph.dask.common.mg_utils as mg_utils
Expand Down Expand Up @@ -680,16 +681,12 @@ def enable_batch(self):

def _replicate_edgelist(self):
client = mg_utils.get_client()
comms = Comms.get_comms()

# FIXME: There might be a better way to control it
if client is None:
return
work_futures = replication.replicate_cudf_dataframe(
self.edgelist.edgelist_df, client=client, comms=comms
)

self.batch_edgelists = work_futures
self.batch_edgelists = replicate_cudf_dataframe(self.edgelist.edgelist_df)

def _replicate_adjlist(self):
client = mg_utils.get_client()
Expand Down
8 changes: 2 additions & 6 deletions python/cugraph/cugraph/structure/replicate_edgelist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -269,7 +269,6 @@ def replicate_cudf_dataframe(cudf_dataframe):
)

_client = default_client()

if not isinstance(cudf_dataframe, dask_cudf.DataFrame):
if isinstance(cudf_dataframe, cudf.DataFrame):
df = dask_cudf.from_cudf(
Expand All @@ -287,10 +286,7 @@ def replicate_cudf_dataframe(cudf_dataframe):
df = get_persisted_df_worker_map(df, _client)

ddf = _mg_call_plc_replicate(
_client,
Comms.get_session_id(),
df,
"dataframe",
_client, Comms.get_session_id(), df, "dataframe", cudf_dataframe.columns
)

return ddf
Expand Down

0 comments on commit 6c333d3

Please sign in to comment.