Skip to content

Commit

Permalink
Test select_random_vertices for all possible values of flags (#4042)
Browse files Browse the repository at this point in the history
Test `select_random_vertices` for all possible values of flags.

Authors:
  - Naim (https://github.com/naimnv)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)
  - Chuck Hastings (https://github.com/ChuckHastings)

URL: #4042
  • Loading branch information
naimnv authored Dec 12, 2023
1 parent 7fe7bea commit 23e70a6
Showing 1 changed file with 86 additions and 69 deletions.
155 changes: 86 additions & 69 deletions cpp/tests/structure/mg_select_random_vertices_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class Tests_MGSelectRandomVertices
//

std::vector<bool> with_replacement_flags = {true, false};
std::vector<bool> sort_vertices_flags = {true, false};

{
// Generate distributed vertex set to sample from
std::srand((unsigned)std::chrono::duration_cast<std::chrono::milliseconds>(
Expand Down Expand Up @@ -107,80 +109,95 @@ class Tests_MGSelectRandomVertices
? select_random_vertices_usecase.select_count
: std::rand() % (num_of_elements_in_given_set + 1);

for (int idx = 0; idx < with_replacement_flags.size(); idx++) {
bool with_replacement = with_replacement_flags[idx];
auto d_sampled_vertices =
cugraph::select_random_vertices(*handle_,
mg_graph_view,
std::make_optional(raft::device_span<vertex_t const>{
d_given_set.data(), d_given_set.size()}),
rng_state,
select_count,
with_replacement,
true);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
for (int i = 0; i < with_replacement_flags.size(); i++) {
for (int j = 0; j < sort_vertices_flags.size(); j++) {
bool with_replacement = with_replacement_flags[i];
bool sort_vertices = sort_vertices_flags[j];

auto d_sampled_vertices =
cugraph::select_random_vertices(*handle_,
mg_graph_view,
std::make_optional(raft::device_span<vertex_t const>{
d_given_set.data(), d_given_set.size()}),
rng_state,
select_count,
with_replacement,
sort_vertices);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
}

std::sort(h_given_set.begin(), h_given_set.end());
if (sort_vertices) {
assert(std::is_sorted(h_sampled_vertices.begin(), h_sampled_vertices.end()));
} else {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());
}
std::for_each(
h_sampled_vertices.begin(), h_sampled_vertices.end(), [&h_given_set](vertex_t v) {
ASSERT_TRUE(std::binary_search(h_given_set.begin(), h_given_set.end(), v));
});
}

std::sort(h_given_set.begin(), h_given_set.end());
std::for_each(
h_sampled_vertices.begin(), h_sampled_vertices.end(), [&h_given_set](vertex_t v) {
ASSERT_TRUE(std::binary_search(h_given_set.begin(), h_given_set.end(), v));
});
}
}
}

//
// Test sampling from [0, V)
//

for (int idx = 0; idx < with_replacement_flags.size(); idx++) {
bool with_replacement = false;
auto d_sampled_vertices = cugraph::select_random_vertices(
*handle_,
mg_graph_view,
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
rng_state,
select_random_vertices_usecase.select_count,
with_replacement,
true);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
//
// Test sampling from [0, V)
//

for (int i = 0; i < with_replacement_flags.size(); i++) {
for (int j = 0; j < sort_vertices_flags.size(); j++) {
bool with_replacement = with_replacement_flags[i];
bool sort_vertices = sort_vertices_flags[j];

auto d_sampled_vertices = cugraph::select_random_vertices(
*handle_,
mg_graph_view,
std::optional<raft::device_span<vertex_t const>>{std::nullopt},
rng_state,
select_random_vertices_usecase.select_count,
with_replacement,
sort_vertices);

RAFT_CUDA_TRY(cudaDeviceSynchronize());

auto h_sampled_vertices = cugraph::test::to_host(*handle_, d_sampled_vertices);

if (select_random_vertices_usecase.check_correctness) {
if (!with_replacement) {
std::sort(h_sampled_vertices.begin(), h_sampled_vertices.end());

auto nr_duplicates =
std::distance(std::unique(h_sampled_vertices.begin(), h_sampled_vertices.end()),
h_sampled_vertices.end());

ASSERT_EQ(nr_duplicates, 0);
}
if (sort_vertices) {
assert(std::is_sorted(h_sampled_vertices.begin(), h_sampled_vertices.end()));
}

auto vertex_first = mg_graph_view.local_vertex_partition_range_first();
auto vertex_last = mg_graph_view.local_vertex_partition_range_last();
std::for_each(h_sampled_vertices.begin(),
h_sampled_vertices.end(),
[vertex_first, vertex_last](vertex_t v) {
ASSERT_TRUE((v >= vertex_first) && (v < vertex_last));
});
}
}

auto vertex_first = mg_graph_view.local_vertex_partition_range_first();
auto vertex_last = mg_graph_view.local_vertex_partition_range_last();

std::for_each(h_sampled_vertices.begin(),
h_sampled_vertices.end(),
[vertex_first, vertex_last](vertex_t v) {
ASSERT_TRUE((v >= vertex_first) && (v < vertex_last));
});
}
}
}
Expand Down

0 comments on commit 23e70a6

Please sign in to comment.