From 4c0531def45b40431c56168989391c2985609f31 Mon Sep 17 00:00:00 2001 From: Chuck Hastings <45364586+ChuckHastings@users.noreply.github.com> Date: Tue, 24 May 2022 15:14:19 -0400 Subject: [PATCH] Fix uniform neighborhood sampling remove duplicates (#2301) @jnke2016 observed that the recent sampling code can return duplicates. After some discussion it was decided to remove the duplicates and return a count of how many duplicates were removed. The C++ implementation is updated in this PR to do that. The C API will remain the same but is ready to support it once we have stopped using the old implementation. Closes #2226 Authors: - Chuck Hastings (https://github.com/ChuckHastings) Approvers: - Seunghwa Kang (https://github.com/seunghwak) URL: https://github.com/rapidsai/cugraph/pull/2301 --- cpp/include/cugraph/algorithms.hpp | 31 +++-- cpp/src/c_api/uniform_neighbor_sampling.cpp | 11 +- cpp/src/sampling/detail/graph_functions.hpp | 10 ++ .../sampling/detail/sampling_utils_impl.cuh | 42 +++++++ cpp/src/sampling/detail/sampling_utils_sg.cu | 55 +++++++++ .../uniform_neighbor_sampling_impl.hpp | 27 +++-- .../sampling/uniform_neighbor_sampling_mg.cpp | 108 ++++++++++-------- .../sampling/uniform_neighbor_sampling_sg.cpp | 108 ++++++++++-------- .../sampling/mg_uniform_neighbor_sampling.cu | 2 +- .../sampling/sg_uniform_neighbor_sampling.cu | 2 +- 10 files changed, 272 insertions(+), 124 deletions(-) diff --git a/cpp/include/cugraph/algorithms.hpp b/cpp/include/cugraph/algorithms.hpp index eef2c1c5ed3..b5c62811bfd 100644 --- a/cpp/include/cugraph/algorithms.hpp +++ b/cpp/include/cugraph/algorithms.hpp @@ -1574,6 +1574,14 @@ uniform_nbr_sample(raft::handle_t const& handle, /** * @brief Uniform Neighborhood Sampling. * + * This function traverses from a set of starting vertices, traversing outgoing edges and + * randomly selects from these outgoing neighbors to extract a subgraph. + * + * Output from this function a set of tuples (src, dst, weight, count), identifying the randomly + * selected edges. src is the source vertex, dst is the destination vertex, weight is the weight + * of the edge and count identifies the number of times this edge was encountered during the + * sampling of this graph (so it is >= 1). + * * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. * @tparam edge_t Type of edge identifiers. Needs to be an integral type. * @tparam weight_t Type of edge weights. Needs to be a floating point type. @@ -1587,22 +1595,25 @@ uniform_nbr_sample(raft::handle_t const& handle, * @param with_replacement boolean flag specifying if random sampling is done with replacement * (true); or, without replacement (false); default = true; * @param seed A seed to initialize the random number generator - * @return tuple device vectors (vertex_t source_vertex, vertex_t destination_vertex, weight_t wgt) + * @return tuple device vectors (vertex_t source_vertex, vertex_t destination_vertex, weight_t + * weight, edge_t count) */ template -std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement = true, - uint64_t seed = 0); +std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement = true, + uint64_t seed = 0); /* * @brief Compute triangle counts. diff --git a/cpp/src/c_api/uniform_neighbor_sampling.cpp b/cpp/src/c_api/uniform_neighbor_sampling.cpp index fa27c28e8eb..612284c93c8 100644 --- a/cpp/src/c_api/uniform_neighbor_sampling.cpp +++ b/cpp/src/c_api/uniform_neighbor_sampling.cpp @@ -35,11 +35,13 @@ struct cugraph_sample_result_t { bool experimental_{true}; cugraph_type_erased_device_array_t* src_{nullptr}; cugraph_type_erased_device_array_t* dst_{nullptr}; - // FIXME: Will be deleted once experimental replaces curren + // FIXME: Will be deleted once experimental replaces current cugraph_type_erased_device_array_t* label_{nullptr}; cugraph_type_erased_device_array_t* index_{nullptr}; - // FIXME: Will be deleted once experimental replaces curren + // FIXME: Will be deleted once experimental replaces current cugraph_type_erased_host_array_t* count_{nullptr}; + // FIXME: Rename to count_ once experimental replaces current + cugraph_type_erased_device_array_t* experimental_count_{nullptr}; }; } // namespace c_api @@ -233,7 +235,7 @@ struct experimental_uniform_neighbor_sampling_functor : public cugraph::c_api::a graph_view.local_vertex_partition_range_last(), false); - auto&& [srcs, dsts, weights] = cugraph::uniform_nbr_sample( + auto&& [srcs, dsts, weights, counts] = cugraph::uniform_nbr_sample( handle_, graph_view, raft::device_span(start.data(), start.size()), @@ -262,7 +264,8 @@ struct experimental_uniform_neighbor_sampling_functor : public cugraph::c_api::a new cugraph::c_api::cugraph_type_erased_device_array_t(dsts, graph_->vertex_type_), nullptr, new cugraph::c_api::cugraph_type_erased_device_array_t(weights, graph_->weight_type_), - nullptr}; + nullptr, + new cugraph::c_api::cugraph_type_erased_device_array_t(counts, graph_->edge_type_)}; } } }; diff --git a/cpp/src/sampling/detail/graph_functions.hpp b/cpp/src/sampling/detail/graph_functions.hpp index d875958a6b9..f0b1580b88e 100644 --- a/cpp/src/sampling/detail/graph_functions.hpp +++ b/cpp/src/sampling/detail/graph_functions.hpp @@ -175,6 +175,16 @@ gather_one_hop_edgelist( GraphViewType const& graph_view, const rmm::device_uvector& active_majors); +template +std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +count_and_remove_duplicates(raft::handle_t const& handle, + rmm::device_uvector&& src, + rmm::device_uvector&& dst, + rmm::device_uvector&& wgt); + } // namespace detail } // namespace cugraph diff --git a/cpp/src/sampling/detail/sampling_utils_impl.cuh b/cpp/src/sampling/detail/sampling_utils_impl.cuh index 65bd3e660d6..2e4ced78897 100644 --- a/cpp/src/sampling/detail/sampling_utils_impl.cuh +++ b/cpp/src/sampling/detail/sampling_utils_impl.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -907,5 +908,46 @@ gather_one_hop_edgelist( return std::make_tuple(std::move(majors), std::move(minors), std::move(weights)); } +template +std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +count_and_remove_duplicates(raft::handle_t const& handle, + rmm::device_uvector&& src, + rmm::device_uvector&& dst, + rmm::device_uvector&& wgt) +{ + auto tuple_iter_begin = + thrust::make_zip_iterator(thrust::make_tuple(src.begin(), dst.begin(), wgt.begin())); + + thrust::sort(handle.get_thrust_policy(), tuple_iter_begin, tuple_iter_begin + src.size()); + + auto num_uniques = + thrust::count_if(handle.get_thrust_policy(), + thrust::make_counting_iterator(size_t{0}), + thrust::make_counting_iterator(src.size()), + detail::is_first_in_run_pair_t{src.data(), dst.data()}); + + rmm::device_uvector result_src(num_uniques, handle.get_stream()); + rmm::device_uvector result_dst(num_uniques, handle.get_stream()); + rmm::device_uvector result_wgt(num_uniques, handle.get_stream()); + rmm::device_uvector result_count(num_uniques, handle.get_stream()); + + rmm::device_uvector count(src.size(), handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), count.begin(), count.end(), edge_t{1}); + + thrust::reduce_by_key(handle.get_thrust_policy(), + tuple_iter_begin, + tuple_iter_begin + src.size(), + count.begin(), + thrust::make_zip_iterator(thrust::make_tuple( + result_src.begin(), result_dst.begin(), result_wgt.begin())), + result_count.begin()); + + return std::make_tuple( + std::move(result_src), std::move(result_dst), std::move(result_wgt), std::move(result_count)); +} + } // namespace detail } // namespace cugraph diff --git a/cpp/src/sampling/detail/sampling_utils_sg.cu b/cpp/src/sampling/detail/sampling_utils_sg.cu index 64778511391..52f2f9245b9 100644 --- a/cpp/src/sampling/detail/sampling_utils_sg.cu +++ b/cpp/src/sampling/detail/sampling_utils_sg.cu @@ -229,5 +229,60 @@ gather_one_hop_edgelist(raft::handle_t const& handle, graph_view_t const& graph_view, rmm::device_uvector const& active_majors); +// Only need to build once, not separately for SG/MG +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +count_and_remove_duplicates(raft::handle_t const& handle, + rmm::device_uvector&& src, + rmm::device_uvector&& dst, + rmm::device_uvector&& wgt); + +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +count_and_remove_duplicates(raft::handle_t const& handle, + rmm::device_uvector&& src, + rmm::device_uvector&& dst, + rmm::device_uvector&& wgt); + +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +count_and_remove_duplicates(raft::handle_t const& handle, + rmm::device_uvector&& src, + rmm::device_uvector&& dst, + rmm::device_uvector&& wgt); + +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +count_and_remove_duplicates(raft::handle_t const& handle, + rmm::device_uvector&& src, + rmm::device_uvector&& dst, + rmm::device_uvector&& wgt); + +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +count_and_remove_duplicates(raft::handle_t const& handle, + rmm::device_uvector&& src, + rmm::device_uvector&& dst, + rmm::device_uvector&& wgt); + +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +count_and_remove_duplicates(raft::handle_t const& handle, + rmm::device_uvector&& src, + rmm::device_uvector&& dst, + rmm::device_uvector&& wgt); + } // namespace detail } // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_impl.hpp b/cpp/src/sampling/uniform_neighbor_sampling_impl.hpp index 0847abf9556..a7a7fdc24d6 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_impl.hpp +++ b/cpp/src/sampling/uniform_neighbor_sampling_impl.hpp @@ -41,7 +41,8 @@ namespace detail { template std::tuple, rmm::device_uvector, - rmm::device_uvector> + rmm::device_uvector, + rmm::device_uvector> uniform_nbr_sample_impl( raft::handle_t const& handle, graph_view_t const& graph_view, @@ -150,8 +151,8 @@ uniform_nbr_sample_impl( ++level; } - return std::make_tuple( - std::move(d_result_src), std::move(d_result_dst), std::move(*d_result_indices)); + return count_and_remove_duplicates( + handle, std::move(d_result_src), std::move(d_result_dst), std::move(*d_result_indices)); } } // namespace detail @@ -160,15 +161,17 @@ template -std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample( - raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed) +std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample( + raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed) { rmm::device_uvector d_start_vs(starting_vertices.size(), handle.get_stream()); raft::copy( diff --git a/cpp/src/sampling/uniform_neighbor_sampling_mg.cpp b/cpp/src/sampling/uniform_neighbor_sampling_mg.cpp index a569703dbc1..ad8d1a9f0f1 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_mg.cpp +++ b/cpp/src/sampling/uniform_neighbor_sampling_mg.cpp @@ -20,58 +20,70 @@ namespace cugraph { -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); } // namespace cugraph diff --git a/cpp/src/sampling/uniform_neighbor_sampling_sg.cpp b/cpp/src/sampling/uniform_neighbor_sampling_sg.cpp index eb706271a57..ad4442ecae1 100644 --- a/cpp/src/sampling/uniform_neighbor_sampling_sg.cpp +++ b/cpp/src/sampling/uniform_neighbor_sampling_sg.cpp @@ -20,58 +20,70 @@ namespace cugraph { -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); -template std:: - tuple, rmm::device_uvector, rmm::device_uvector> - uniform_nbr_sample(raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span starting_vertices, - raft::host_span fan_out, - bool with_replacement, - uint64_t seed); +template std::tuple, + rmm::device_uvector, + rmm::device_uvector, + rmm::device_uvector> +uniform_nbr_sample(raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::device_span starting_vertices, + raft::host_span fan_out, + bool with_replacement, + uint64_t seed); } // namespace cugraph diff --git a/cpp/tests/sampling/mg_uniform_neighbor_sampling.cu b/cpp/tests/sampling/mg_uniform_neighbor_sampling.cu index a793348f2db..0388f3ab7ec 100644 --- a/cpp/tests/sampling/mg_uniform_neighbor_sampling.cu +++ b/cpp/tests/sampling/mg_uniform_neighbor_sampling.cu @@ -95,7 +95,7 @@ class Tests_MG_Nbr_Sampling std::vector h_fan_out{indices_per_source}; // depth = 1 - auto&& [d_src_out, d_dst_out, d_indices] = cugraph::uniform_nbr_sample( + auto&& [d_src_out, d_dst_out, d_indices, d_counts] = cugraph::uniform_nbr_sample( handle, mg_graph_view, raft::device_span(random_sources.data(), random_sources.size()), diff --git a/cpp/tests/sampling/sg_uniform_neighbor_sampling.cu b/cpp/tests/sampling/sg_uniform_neighbor_sampling.cu index cad89e51e4f..20197fe629b 100644 --- a/cpp/tests/sampling/sg_uniform_neighbor_sampling.cu +++ b/cpp/tests/sampling/sg_uniform_neighbor_sampling.cu @@ -74,7 +74,7 @@ class Tests_Uniform_Neighbor_Sampling std::vector h_fan_out{indices_per_source}; // depth = 1 - auto&& [d_src_out, d_dst_out, d_indices] = cugraph::uniform_nbr_sample( + auto&& [d_src_out, d_dst_out, d_indices, d_counts] = cugraph::uniform_nbr_sample( handle, graph_view, raft::device_span(random_sources.data(), random_sources.size()),