Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raft Handle Updates to cuGraph #1894

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ function(find_and_configure_raft)
BUILD_EXPORT_SET cugraph-exports
INSTALL_EXPORT_SET cugraph-exports
CPM_ARGS
GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git
GIT_TAG ${PKG_PINNED_TAG}
GIT_REPOSITORY https://github.com/divyegala/raft.git
divyegala marked this conversation as resolved.
Show resolved Hide resolved
GIT_TAG imp-21.10-handle_stream
SOURCE_SUBDIR cpp
OPTIONS "BUILD_TESTS OFF"
)
Expand Down
32 changes: 16 additions & 16 deletions cpp/include/cugraph/prims/copy_to_adj_matrix_row_col.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ void copy_to_matrix_major(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -128,9 +128,9 @@ void copy_to_matrix_major(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
} else {
Expand Down Expand Up @@ -175,9 +175,9 @@ void copy_to_matrix_major(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -266,9 +266,9 @@ void copy_to_matrix_major(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
} else {
Expand Down Expand Up @@ -310,9 +310,9 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -368,9 +368,9 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
} else {
Expand Down Expand Up @@ -415,9 +415,9 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -504,9 +504,9 @@ void copy_to_matrix_minor(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
} else {
Expand Down
16 changes: 8 additions & 8 deletions cpp/include/cugraph/prims/copy_v_transform_reduce_in_out_nbr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -608,9 +608,9 @@ void copy_v_transform_reduce_nbr(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand All @@ -627,9 +627,9 @@ void copy_v_transform_reduce_nbr(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand All @@ -650,9 +650,9 @@ void copy_v_transform_reduce_nbr(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand All @@ -674,9 +674,9 @@ void copy_v_transform_reduce_nbr(raft::handle_t const& handle,
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -278,7 +278,7 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
handle.get_stream());
}

handle.get_stream_view().synchronize(); // cuco::static_map currently does not take stream
handle.get_stream().synchronize(); // cuco::static_map currently does not take stream

kv_map_ptr.reset();

Expand All @@ -296,7 +296,7 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
thrust::make_tuple(map_keys.begin(), get_dataframe_buffer_begin(map_value_buffer)));
kv_map_ptr->insert(pair_first, pair_first + map_keys.size());
} else {
handle.get_stream_view().synchronize(); // cuco::static_map currently does not take stream
handle.get_stream().synchronize(); // cuco::static_map currently does not take stream

kv_map_ptr.reset();

Expand Down Expand Up @@ -328,9 +328,9 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down Expand Up @@ -554,9 +554,9 @@ void copy_v_transform_reduce_key_aggregated_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down
29 changes: 14 additions & 15 deletions cpp/include/cugraph/prims/update_frontier_v_push_if_out_nbr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,9 @@ typename GraphViewType::edge_type compute_num_out_nbrs_from_frontier(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down Expand Up @@ -652,8 +652,7 @@ typename GraphViewType::edge_type compute_num_out_nbrs_from_frontier(
auto& col_comm = handle.get_subcomm(cugraph::partition_2d::key_naming_t().col_name());
auto const col_comm_rank = col_comm.get_rank();

rmm::device_uvector<vertex_t> frontier_vertices(local_frontier_sizes[i],
handle.get_stream_view());
rmm::device_uvector<vertex_t> frontier_vertices(local_frontier_sizes[i], handle.get_stream());
device_bcast(col_comm,
local_frontier_vertex_first,
frontier_vertices.data(),
Expand Down Expand Up @@ -726,9 +725,9 @@ typename GraphViewType::edge_type compute_num_out_nbrs_from_frontier(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down Expand Up @@ -848,9 +847,9 @@ void update_frontier_v_push_if_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down Expand Up @@ -1105,9 +1104,9 @@ void update_frontier_v_push_if_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand All @@ -1133,9 +1132,9 @@ void update_frontier_v_push_if_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif

Expand Down Expand Up @@ -1166,7 +1165,7 @@ void update_frontier_v_push_if_out_nbr(
d_tx_buffer_last_boundaries.data(),
d_tx_buffer_last_boundaries.size(),
handle.get_stream());
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
std::vector<size_t> tx_counts(h_tx_buffer_last_boundaries.size());
std::adjacent_difference(
h_tx_buffer_last_boundaries.begin(), h_tx_buffer_last_boundaries.end(), tx_counts.begin());
Expand Down Expand Up @@ -1195,9 +1194,9 @@ void update_frontier_v_push_if_out_nbr(
#if 1
// FIXME: temporary hack till UCC is integrated into RAFT (so we can use UCC barrier with DASK
// and MPI barrier with MPI)
host_barrier(comm, handle.get_stream_view());
host_barrier(comm, handle.get_stream());
#else
handle.get_stream_view().synchronize();
handle.get_stream().synchronize();
comm.barrier(); // currently, this is ncclAllReduce
#endif
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cugraph/prims/vertex_frontier.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ class VertexFrontier {
h_indices.data(), d_indices.data(), d_indices.size(), handle_ptr_->get_stream());
raft::update_host(
h_counts.data(), d_counts.data(), d_counts.size(), handle_ptr_->get_stream());
handle_ptr_->get_stream_view().synchronize();
handle_ptr_->get_stream().synchronize();

size_t offset{0};
for (size_t i = 0; i < h_indices.size(); ++i) {
Expand Down
10 changes: 4 additions & 6 deletions cpp/src/community/legacy/ecg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,10 @@ class EcgLouvain : public cugraph::legacy::Louvain<graph_type> {

void initialize_dendrogram_level(vertex_t num_vertices) override
{
this->dendrogram_->add_level(0, num_vertices, this->handle_.get_stream_view());
this->dendrogram_->add_level(0, num_vertices, this->handle_.get_stream());

get_permutation_vector(num_vertices,
seed_,
this->dendrogram_->current_level_begin(),
this->handle_.get_stream_view());
get_permutation_vector(
num_vertices, seed_, this->dendrogram_->current_level_begin(), this->handle_.get_stream());
}

private:
Expand All @@ -147,7 +145,7 @@ void ecg(raft::handle_t const& handle,
"Invalid input argument: clustering is NULL, should be a device pointer to "
"memory for storing the result");

rmm::device_uvector<weight_t> ecg_weights_v(graph.number_of_edges, handle.get_stream_view());
rmm::device_uvector<weight_t> ecg_weights_v(graph.number_of_edges, handle.get_stream());

thrust::copy(handle.get_thrust_policy(),
graph.edge_data,
Expand Down
14 changes: 7 additions & 7 deletions cpp/src/community/legacy/egonet.cu
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ extract(raft::handle_t const& handle,
vertex_t radius)
{
auto v = csr_view.get_number_of_vertices();
auto user_stream_view = handle.get_stream_view();
auto user_stream_view = handle.get_stream();
rmm::device_vector<size_t> neighbors_offsets(n_subgraphs + 1);
rmm::device_vector<vertex_t> neighbors;

Expand All @@ -77,7 +77,7 @@ extract(raft::handle_t const& handle,
reached.reserve(n_subgraphs);
for (vertex_t i = 0; i < n_subgraphs; i++) {
// Allocations and operations are attached to the worker stream
rmm::device_uvector<vertex_t> local_reach(v, handle.get_internal_stream_view(i));
rmm::device_uvector<vertex_t> local_reach(v, handle.get_next_usable_stream(i));
reached.push_back(std::move(local_reach));
}

Expand All @@ -89,8 +89,8 @@ extract(raft::handle_t const& handle,

for (vertex_t i = 0; i < n_subgraphs; i++) {
// get light handle from worker pool
raft::handle_t light_handle(handle, i);
auto worker_stream_view = light_handle.get_stream_view();
raft::handle_t light_handle(handle.get_next_usable_stream(i));
auto worker_stream_view = light_handle.get_stream();

// BFS with cutoff
// consider adding a device API to BFS (i.e., accept source on the device)
Expand Down Expand Up @@ -132,7 +132,7 @@ extract(raft::handle_t const& handle,
}

// wait on everyone to identify their neighbors before proceeding to concatenation
handle.wait_on_internal_streams();
handle.sync_stream_pool();

// Construct neighbors offsets (just a scan on neighborhood vector sizes)
h_neighbors_offsets[0] = 0;
Expand All @@ -148,7 +148,7 @@ extract(raft::handle_t const& handle,

// Construct the neighbors list concurrently
for (vertex_t i = 0; i < n_subgraphs; i++) {
auto worker_stream_view = handle.get_internal_stream_view(i);
auto worker_stream_view = handle.get_next_usable_stream(i);
thrust::copy(rmm::exec_policy(worker_stream_view),
reached[i].begin(),
reached[i].end(),
Expand All @@ -160,7 +160,7 @@ extract(raft::handle_t const& handle,
}

// wait on everyone before proceeding to grouped extraction
handle.wait_on_internal_streams();
handle.sync_stream_pool();

#ifdef TIMING
hr_timer.stop();
Expand Down
Loading