Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raft Handle Updates to cuGraph #1894

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conda/recipes/pylibcugraph/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ requirements:
- ucx-py 0.23
- ucx-proc=*=gpu
- cudatoolkit {{ cuda_version }}.*
- rmm {{ minor_version }}.*
run:
- python x.x
- libcugraph={{ version }}
Expand Down
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ set(CUGRAPH_BRANCH_VERSION_raft "${CUGRAPH_VERSION_MAJOR}.${CUGRAPH_VERSION_MINO
# To use a different RAFT locally, set the CMake variable
# RPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${CUGRAPH_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${CUGRAPH_BRANCH_VERSION_raft}
FORK divyegala
PINNED_TAG imp-21.10-handle_stream
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am waiting for this PR to pull raft from the main repo to approve this.

)
32 changes: 16 additions & 16 deletions cpp/include/cugraph/prims/copy_to_adj_matrix_row_col.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ void copy_to_matrix_major(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -128,9 +128,9 @@ void copy_to_matrix_major(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
} else {
Expand Down Expand Up @@ -175,9 +175,9 @@ void copy_to_matrix_major(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -266,9 +266,9 @@ void copy_to_matrix_major(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
} else {
Expand Down Expand Up @@ -310,9 +310,9 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -368,9 +368,9 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
} else {
Expand Down Expand Up @@ -415,9 +415,9 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -504,9 +504,9 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
} else {
Expand Down
16 changes: 8 additions & 8 deletions cpp/include/cugraph/prims/copy_v_transform_reduce_in_out_nbr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -608,9 +608,9 @@ void copy_v_transform_reduce_nbr(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand All @@ -627,9 +627,9 @@ void copy_v_transform_reduce_nbr(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand All @@ -650,9 +650,9 @@ void copy_v_transform_reduce_nbr(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand All @@ -674,9 +674,9 @@ void copy_v_transform_reduce_nbr(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -278,7 +278,7 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
handle.get_stream());
}

handle.get_stream_view().synchronize(); // cuco::static_map currently does not take stream
handle.get_stream().synchronize(); // cuco::static_map currently does not take stream

kv_map_ptr.reset();

Expand All @@ -296,7 +296,7 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
thrust::make_tuple(map_keys.begin(), get_dataframe_buffer_begin(map_value_buffer)));
kv_map_ptr->insert(pair_first, pair_first + map_keys.size());
} else {
handle.get_stream_view().synchronize(); // cuco::static_map currently does not take stream
handle.get_stream().synchronize(); // cuco::static_map currently does not take stream

kv_map_ptr.reset();

Expand Down Expand Up @@ -328,9 +328,9 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down Expand Up @@ -554,9 +554,9 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down
29 changes: 14 additions & 15 deletions cpp/include/cugraph/prims/update_frontier_v_push_if_out_nbr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,9 @@ typename GraphViewType::edge_type compute_num_out_nbrs_from_frontier(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down Expand Up @@ -652,8 +652,7 @@ typename GraphViewType::edge_type compute_num_out_nbrs_from_frontier(
auto& col_comm = handle.get_subcomm(cugraph::partition_2d::key_naming_t().col_name());
auto const col_comm_rank = col_comm.get_rank();

rmm::device_uvector<vertex_t> frontier_vertices(local_frontier_sizes[i],
handle.get_stream_view());
rmm::device_uvector<vertex_t> frontier_vertices(local_frontier_sizes[i], handle.get_stream());
device_bcast(col_comm,
local_frontier_vertex_first,
frontier_vertices.data(),
Expand Down Expand Up @@ -726,9 +725,9 @@ typename GraphViewType::edge_type compute_num_out_nbrs_from_frontier(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down Expand Up @@ -848,9 +847,9 @@ void update_frontier_v_push_if_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down Expand Up @@ -1105,9 +1104,9 @@ void update_frontier_v_push_if_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand All @@ -1133,9 +1132,9 @@ void update_frontier_v_push_if_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -1166,7 +1165,7 @@ void update_frontier_v_push_if_out_nbr(
d_tx_buffer_last_boundaries.data(),
d_tx_buffer_last_boundaries.size(),
handle.get_stream());
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
std::vector<size_t> tx_counts(h_tx_buffer_last_boundaries.size());
std::adjacent_difference(
h_tx_buffer_last_boundaries.begin(), h_tx_buffer_last_boundaries.end(), tx_counts.begin());
Expand Down Expand Up @@ -1195,9 +1194,9 @@ void update_frontier_v_push_if_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cugraph/prims/vertex_frontier.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ class VertexFrontier {
h_indices.data(), d_indices.data(), d_indices.size(), handle_ptr_->get_stream());
raft::update_host(
h_counts.data(), d_counts.data(), d_counts.size(), handle_ptr_->get_stream());
handle_ptr_->get_stream_view().synchronize();
handle_ptr_->get_stream().synchronize();

size_t offset{0};
for (size_t i = 0; i < h_indices.size(); ++i) {
Expand Down
10 changes: 4 additions & 6 deletions cpp/src/community/legacy/ecg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,10 @@ class EcgLouvain : public cugraph::legacy::Louvain<graph_type> {

void initialize_dendrogram_level(vertex_t num_vertices) override
{
this->dendrogram_->add_level(0, num_vertices, this->handle_.get_stream_view());
this->dendrogram_->add_level(0, num_vertices, this->handle_.get_stream());

get_permutation_vector(num_vertices,
seed_,
this->dendrogram_->current_level_begin(),
this->handle_.get_stream_view());
get_permutation_vector(
num_vertices, seed_, this->dendrogram_->current_level_begin(), this->handle_.get_stream());
}

private:
Expand All @@ -147,7 +145,7 @@ void ecg(raft::handle_t const& handle,
"Invalid input argument: clustering is NULL, should be a device pointer to "
"memory for storing the result");

rmm::device_uvector<weight_t> ecg_weights_v(graph.number_of_edges, handle.get_stream_view());
rmm::device_uvector<weight_t> ecg_weights_v(graph.number_of_edges, handle.get_stream());

thrust::copy(handle.get_thrust_policy(),
graph.edge_data,
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/community/legacy/egonet.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ extract(raft::handle_t const& handle,
vertex_t radius)
{
auto v = csr_view.get_number_of_vertices();
auto user_stream_view = handle.get_stream_view();
auto user_stream_view = handle.get_stream();
rmm::device_vector<size_t> neighbors_offsets(n_subgraphs + 1);
rmm::device_vector<vertex_t> neighbors;

Expand All @@ -77,7 +77,7 @@ extract(raft::handle_t const& handle,
reached.reserve(n_subgraphs);
for (vertex_t i = 0; i < n_subgraphs; i++) {
// Allocations and operations are attached to the worker stream
rmm::device_uvector<vertex_t> local_reach(v, handle.get_internal_stream_view(i));
rmm::device_uvector<vertex_t> local_reach(v, handle.get_next_usable_stream(i));
reached.push_back(std::move(local_reach));
}

Expand All @@ -89,8 +89,8 @@ extract(raft::handle_t const& handle,

for (vertex_t i = 0; i < n_subgraphs; i++) {
// get light handle from worker pool
raft::handle_t light_handle(handle, i);
auto worker_stream_view = light_handle.get_stream_view();
raft::handle_t light_handle(handle.get_next_usable_stream(i));
auto worker_stream_view = light_handle.get_stream();

// BFS with cutoff
// consider adding a device API to BFS (ie. accept source on the device)
Expand Down Expand Up @@ -132,7 +132,7 @@ extract(raft::handle_t const& handle,
}

// wait on every one to identify their neighboors before proceeding to concatenation
handle.wait_on_internal_streams();
handle.sync_stream_pool();

// Construct neighboors offsets (just a scan on neighborhod vector sizes)
h_neighbors_offsets[0] = 0;
Expand All @@ -148,7 +148,7 @@ extract(raft::handle_t const& handle,

// Construct the neighboors list concurrently
for (vertex_t i = 0; i < n_subgraphs; i++) {
auto worker_stream_view = handle.get_internal_stream_view(i);
auto worker_stream_view = handle.get_next_usable_stream(i);
thrust::copy(rmm::exec_policy(worker_stream_view),
reached[i].begin(),
reached[i].end(),
Expand All @@ -160,7 +160,7 @@ extract(raft::handle_t const& handle,
}

// wait on every one before proceeding to grouped extraction
handle.wait_on_internal_streams();
handle.sync_stream_pool();

#ifdef TIMING
hr_timer.stop();
Expand Down
Loading