From 23e70a6f2c81416736ec4d9107088c2b5a4788cb Mon Sep 17 00:00:00 2001 From: Naim <110031745+naimnv@users.noreply.github.com> Date: Tue, 12 Dec 2023 19:42:58 +0100 Subject: [PATCH] Test select_random_vertices for all possible values of flags (#4042) Test `select_random_vertices` for all possible values of flags. Authors: - Naim (https://github.com/naimnv) Approvers: - Seunghwa Kang (https://github.com/seunghwak) - Chuck Hastings (https://github.com/ChuckHastings) URL: https://github.com/rapidsai/cugraph/pull/4042 --- .../mg_select_random_vertices_test.cpp | 155 ++++++++++-------- 1 file changed, 86 insertions(+), 69 deletions(-) diff --git a/cpp/tests/structure/mg_select_random_vertices_test.cpp b/cpp/tests/structure/mg_select_random_vertices_test.cpp index e49e1ebcb99..8392a6831ca 100644 --- a/cpp/tests/structure/mg_select_random_vertices_test.cpp +++ b/cpp/tests/structure/mg_select_random_vertices_test.cpp @@ -79,6 +79,8 @@ class Tests_MGSelectRandomVertices // std::vector with_replacement_flags = {true, false}; + std::vector sort_vertices_flags = {true, false}; + { // Generate distributed vertex set to sample from std::srand((unsigned)std::chrono::duration_cast( @@ -107,80 +109,95 @@ class Tests_MGSelectRandomVertices ? select_random_vertices_usecase.select_count : std::rand() % (num_of_elements_in_given_set + 1); - for (int idx = 0; idx < with_replacement_flags.size(); idx++) { - bool with_replacement = with_replacement_flags[idx]; - auto d_sampled_vertices = - cugraph::select_random_vertices(*handle_, - mg_graph_view, - std::make_optional(raft::device_span{ - d_given_set.data(), d_given_set.size()}), - rng_state, - select_count, - with_replacement, - true); - - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - - auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices); - - if (select_random_vertices_usecase.check_correctness) { - if (!with_replacement) { - std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end()); - - auto nr_duplicates = - std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()), - h_sampled_vertices.end()); - - ASSERT_EQ(nr_duplicates, 0); + for (int i = 0; i < with_replacement_flags.size(); i++) { + for (int j = 0; j < sort_vertices_flags.size(); j++) { + bool with_replacement = with_replacement_flags[i]; + bool sort_vertices = sort_vertices_flags[j]; + + auto d_sampled_vertices = + cugraph::select_random_vertices(*handle_, + mg_graph_view, + std::make_optional(raft::device_span{ + d_given_set.data(), d_given_set.size()}), + rng_state, + select_count, + with_replacement, + sort_vertices); + + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices); + + if (select_random_vertices_usecase.check_correctness) { + if (!with_replacement) { + std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end()); + + auto nr_duplicates = + std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()), + h_sampled_vertices.end()); + + ASSERT_EQ(nr_duplicates, 0); + } + + std::sort(h_given_set.begin(), h_given_set.end()); + if (sort_vertices) { + assert(std::is_sorted(h_sampled_vertices.begin(), h_sampled_vertices.end())); + } else { + std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end()); + } + std::for_each( + h_sampled_vertices.begin(), h_sampled_vertices.end(), [&h_given_set](vertex_t v) { + ASSERT_TRUE(std::binary_search(h_given_set.begin(), h_given_set.end(), v)); + }); } - - std::sort(h_given_set.begin(), h_given_set.end()); - std::for_each( - h_sampled_vertices.begin(), h_sampled_vertices.end(), [&h_given_set](vertex_t v) { - ASSERT_TRUE(std::binary_search(h_given_set.begin(), h_given_set.end(), v)); - }); } } - } - - // - // Test sampling from [0, V) - // - - for (int idx = 0; idx < with_replacement_flags.size(); idx++) { - bool with_replacement = false; - auto d_sampled_vertices = cugraph::select_random_vertices( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - rng_state, - select_random_vertices_usecase.select_count, - with_replacement, - true); - - RAFT_CUDA_TRY(cudaDeviceSynchronize()); - - auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices); - if (select_random_vertices_usecase.check_correctness) { - if (!with_replacement) { - std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end()); - - auto nr_duplicates = - std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()), - h_sampled_vertices.end()); - - ASSERT_EQ(nr_duplicates, 0); + // + // Test sampling from [0, V) + // + + for (int i = 0; i < with_replacement_flags.size(); i++) { + for (int j = 0; j < sort_vertices_flags.size(); j++) { + bool with_replacement = with_replacement_flags[i]; + bool sort_vertices = sort_vertices_flags[j]; + + auto d_sampled_vertices = cugraph::select_random_vertices( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + rng_state, + select_random_vertices_usecase.select_count, + with_replacement, + sort_vertices); + + RAFT_CUDA_TRY(cudaDeviceSynchronize()); + + auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices); + + if (select_random_vertices_usecase.check_correctness) { + if (!with_replacement) { + std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end()); + + auto nr_duplicates = + std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()), + h_sampled_vertices.end()); + + ASSERT_EQ(nr_duplicates, 0); + } + if (sort_vertices) { + assert(std::is_sorted(h_sampled_vertices.begin(), h_sampled_vertices.end())); + } + + auto vertex_first = mg_graph_view.local_vertex_partition_range_first(); + auto vertex_last = mg_graph_view.local_vertex_partition_range_last(); + std::for_each(h_sampled_vertices.begin(), + h_sampled_vertices.end(), + [vertex_first, vertex_last](vertex_t v) { + ASSERT_TRUE((v >= vertex_first) && (v < vertex_last)); + }); + } } - - auto vertex_first = mg_graph_view.local_vertex_partition_range_first(); - auto vertex_last = mg_graph_view.local_vertex_partition_range_last(); - - std::for_each(h_sampled_vertices.begin(), - h_sampled_vertices.end(), - [vertex_first, vertex_last](vertex_t v) { - ASSERT_TRUE((v >= vertex_first) && (v < vertex_last)); - }); } } }