diff --git a/cpp/include/raft/sparse/mst/detail/mst_kernels.cuh b/cpp/include/raft/sparse/mst/detail/mst_kernels.cuh index 895927d6e1..d2f86d6dc8 100644 --- a/cpp/include/raft/sparse/mst/detail/mst_kernels.cuh +++ b/cpp/include/raft/sparse/mst/detail/mst_kernels.cuh @@ -107,9 +107,9 @@ __global__ void min_edge_per_supervertex( const vertex_t* color, const vertex_t* color_index, edge_t* new_mst_edge, bool* mst_edge, const vertex_t* indices, const weight_t* weights, const weight_t* altered_weights, vertex_t* temp_src, vertex_t* temp_dst, - weight_t* temp_weights, const weight_t* min_edge_color, const vertex_t v) { + weight_t* temp_weights, const weight_t* min_edge_color, const vertex_t v, + bool symmetrize_output) { auto tid = get_1D_idx(); - if (tid < v) { vertex_t vertex_color_idx = color_index[tid]; vertex_t vertex_color = color[vertex_color_idx]; @@ -120,13 +120,38 @@ __global__ void min_edge_per_supervertex( // if yes, that is part of mst if (edge_idx != std::numeric_limits::max()) { weight_t vertex_weight = altered_weights[edge_idx]; + + bool add_edge = false; if (min_edge_color[vertex_color] == vertex_weight) { - temp_src[tid] = tid; - temp_dst[tid] = indices[edge_idx]; - temp_weights[tid] = weights[edge_idx]; + add_edge = true; + + auto dst = indices[edge_idx]; + if (!symmetrize_output) { + auto dst_edge_idx = new_mst_edge[dst]; + auto dst_color = color[color_index[dst]]; + + // vertices added each other + // only if destination has found an edge + // the edge points back to source + // the edge is minimum edge found for dst color + if (dst_edge_idx != std::numeric_limits::max() && + indices[dst_edge_idx] == tid && + min_edge_color[dst_color] == altered_weights[dst_edge_idx]) { + if (vertex_color > dst_color) { + add_edge = false; + } + } + } - mst_edge[edge_idx] = true; - } else { + if (add_edge) { + temp_src[tid] = tid; + temp_dst[tid] = dst; + temp_weights[tid] = weights[edge_idx]; + mst_edge[edge_idx] = true; + } + } + + if (!add_edge) { new_mst_edge[tid] = std::numeric_limits::max(); } } @@ -138,7 +163,7 @@ __global__ void add_reverse_edge(const edge_t* new_mst_edge, const vertex_t* indices, const weight_t* weights, vertex_t* temp_src, vertex_t* temp_dst, weight_t* temp_weights, - const vertex_t v) { + const vertex_t v, bool symmetrize_output) { auto tid = get_1D_idx(); if (tid < v) { @@ -155,12 +180,14 @@ __global__ void add_reverse_edge(const edge_t* new_mst_edge, reverse_needed = true; } else { // check what vertex the neighbor vertex picked - vertex_t neighbor_vertex_neighbor = indices[neighbor_edge_idx]; - - // if vertices did not pick each other - // add a reverse edge - if (tid != neighbor_vertex_neighbor) { - reverse_needed = true; + if (symmetrize_output) { + vertex_t neighbor_vertex_neighbor = indices[neighbor_edge_idx]; + + // if vertices did not pick each other + // add a reverse edge + if (tid != neighbor_vertex_neighbor) { + reverse_needed = true; + } } } diff --git a/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh b/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh index 2ae4d93113..d3a82ca711 100644 --- a/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh +++ b/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh @@ -39,11 +39,11 @@ typedef std::chrono::high_resolution_clock Clock; // curand generator uniform inline curandStatus_t curand_generate_uniformX(curandGenerator_t generator, - float* outputPtr, size_t n) { + float* outputPtr, size_t n) { return curandGenerateUniform(generator, outputPtr, n); } inline curandStatus_t curand_generate_uniformX(curandGenerator_t generator, - double* outputPtr, size_t n) { + double* outputPtr, size_t n) { return curandGenerateUniformDouble(generator, outputPtr, n); } @@ -51,7 +51,8 @@ template MST_solver::MST_solver( const raft::handle_t& handle_, const edge_t* offsets_, const vertex_t* indices_, const weight_t* weights_, const vertex_t v_, - const edge_t e_, vertex_t* color_, cudaStream_t stream_) + const edge_t e_, vertex_t* color_, cudaStream_t stream_, + bool symmetrize_output_) : handle(handle_), offsets(offsets_), indices(indices_), @@ -70,7 +71,8 @@ MST_solver::MST_solver( temp_weights(2 * v_), mst_edge_count(1, 0), prev_mst_edge_count(1, 0), - stream(stream_) { + stream(stream_), + symmetrize_output(symmetrize_output_) { max_blocks = handle_.get_device_properties().maxGridSize[0]; max_threads = handle_.get_device_properties().maxThreadsPerBlock; sm_count = handle_.get_device_properties().multiProcessorCount; @@ -262,9 +264,9 @@ void MST_solver::label_prop(vertex_t* mst_src, // update the colors of both ends its until there is no change in colors thrust::host_vector curr_mst_edge_count = mst_edge_count; - auto min_pair_nthreads = std::min(v, max_threads); - auto min_pair_nblocks = - std::min((v + min_pair_nthreads - 1) / min_pair_nthreads, max_blocks); + auto min_pair_nthreads = std::min(v, (vertex_t)max_threads); + auto min_pair_nblocks = std::min( + (v + min_pair_nthreads - 1) / min_pair_nthreads, (vertex_t)max_blocks); rmm::device_vector done(1, false); @@ -316,8 +318,8 @@ void MST_solver::min_edge_per_vertex() { // Finds the minimum edge from each supervertex to the lowest color template void MST_solver::min_edge_per_supervertex() { - int nthreads = std::min(v, max_threads); - int nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks); + auto nthreads = std::min(v, max_threads); + auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks); thrust::fill(temp_src.begin(), temp_src.end(), std::numeric_limits::max()); @@ -334,20 +336,23 @@ void MST_solver::min_edge_per_supervertex() { detail::min_edge_per_supervertex<<>>( color, color_index_ptr, new_mst_edge_ptr, mst_edge_ptr, indices, weights, altered_weights_ptr, temp_src_ptr, temp_dst_ptr, temp_weights_ptr, - min_edge_color_ptr, v); + min_edge_color_ptr, v, symmetrize_output); // the above kernel only adds directed mst edges in the case where // a pair of vertices don't pick the same min edge between them // so, now we add the reverse edge to make it undirected - detail::add_reverse_edge<<>>( - new_mst_edge_ptr, indices, weights, temp_src_ptr, temp_dst_ptr, - temp_weights_ptr, v); + if (symmetrize_output) { + detail::add_reverse_edge<<>>( + new_mst_edge_ptr, indices, weights, temp_src_ptr, temp_dst_ptr, + temp_weights_ptr, v, symmetrize_output); + } } template void MST_solver::check_termination() { - int nthreads = std::min(2 * v, max_threads); - int nblocks = std::min((2 * v + nthreads - 1) / nthreads, max_blocks); + vertex_t nthreads = std::min(2 * v, (vertex_t)max_threads); + vertex_t nblocks = + std::min((2 * v + nthreads - 1) / nthreads, (vertex_t)max_blocks); // count number of new mst edges edge_t* mst_edge_count_ptr = mst_edge_count.data().get(); diff --git a/cpp/include/raft/sparse/mst/mst.cuh b/cpp/include/raft/sparse/mst/mst.cuh index d9caca3ba4..453fa9f1c1 100644 --- a/cpp/include/raft/sparse/mst/mst.cuh +++ b/cpp/include/raft/sparse/mst/mst.cuh @@ -26,9 +26,9 @@ template raft::Graph_COO mst( const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices, weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color, - cudaStream_t stream) { + cudaStream_t stream, bool symmetrize_output = true) { MST_solver mst_solver( - handle, offsets, indices, weights, v, e, color, stream); + handle, offsets, indices, weights, v, e, color, stream, symmetrize_output); return mst_solver.solve(); } diff --git a/cpp/include/raft/sparse/mst/mst_solver.cuh b/cpp/include/raft/sparse/mst/mst_solver.cuh index d747a32eaf..dedbe06370 100644 --- a/cpp/include/raft/sparse/mst/mst_solver.cuh +++ b/cpp/include/raft/sparse/mst/mst_solver.cuh @@ -42,7 +42,7 @@ class MST_solver { MST_solver(const raft::handle_t& handle_, const edge_t* offsets_, const vertex_t* indices_, const weight_t* weights_, const vertex_t v_, const edge_t e_, vertex_t* color_, - cudaStream_t stream_); + cudaStream_t stream_, bool symmetrize_output_); raft::Graph_COO solve(); @@ -52,6 +52,8 @@ class MST_solver { const raft::handle_t& handle; cudaStream_t stream; + bool symmetrize_output; + //CSR const edge_t* offsets; const vertex_t* indices; @@ -59,9 +61,9 @@ class MST_solver { const vertex_t v; const edge_t e; - int max_blocks; - int max_threads; - int sm_count; + vertex_t max_blocks; + vertex_t max_threads; + vertex_t sm_count; vertex_t* color; // represent each supervertex as a color rmm::device_vector diff --git a/cpp/test/mst.cu b/cpp/test/mst.cu index 4005238812..733d56a7b1 100644 --- a/cpp/test/mst.cu +++ b/cpp/test/mst.cu @@ -112,7 +112,9 @@ template class MSTTest : public ::testing::TestWithParam> { protected: - raft::Graph_COO mst_sequential() { + std::pair, + raft::Graph_COO> + mst_gpu() { edge_t *offsets = static_cast(csr_d.offsets.data()); vertex_t *indices = static_cast(csr_d.indices.data()); weight_t *weights = static_cast(csr_d.weights.data()); @@ -128,21 +130,22 @@ class MSTTest vertex_t *color_ptr = thrust::raw_pointer_cast(color.data()); - MST_solver mst_solver( - handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream()); - auto result = mst_solver.solve(); - raft::print_device_vector("Final MST Src: ", result.src.data(), - result.n_edges, std::cout); - raft::print_device_vector("Final MST Dst: ", result.dst.data(), - result.n_edges, std::cout); - raft::print_device_vector("Final MST Weights: ", result.weights.data(), - result.n_edges, std::cout); - raft::print_device_vector("Final MST Colors: ", color_ptr, v, std::cout); - - std::cout << "number_of_MST_edges: " << result.n_edges << std::endl; - EXPECT_LE(result.n_edges, 2 * v - 2); - - return result; + MST_solver symmetric_solver( + handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), + true); + auto symmetric_result = symmetric_solver.solve(); + + MST_solver non_symmetric_solver( + handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), + false); + auto non_symmetric_result = non_symmetric_solver.solve(); + + std::cout << "number_of_MST_edges: " << symmetric_result.n_edges + << std::endl; + EXPECT_LE(symmetric_result.n_edges, 2 * v - 2); + + return std::make_pair(std::move(symmetric_result), + std::move(non_symmetric_result)); } void SetUp() override { @@ -215,17 +218,24 @@ const std::vector> csr_in5_h = { typedef MSTTest MSTTestSequential; TEST_P(MSTTestSequential, Sequential) { - auto gpu_result = mst_sequential(); + auto results_pair = mst_gpu(); + auto &symmetric_result = results_pair.first; + auto &non_symmetric_result = results_pair.second; // do assertions here // in this case, running sequential MST auto prims_result = prims(csr_h); - auto parallel_mst_result = - thrust::reduce(thrust::device, gpu_result.weights.data(), - gpu_result.weights.data() + gpu_result.n_edges); + auto symmetric_sum = + thrust::reduce(thrust::device, symmetric_result.weights.data(), + symmetric_result.weights.data() + symmetric_result.n_edges); + auto non_symmetric_sum = thrust::reduce( + thrust::device, non_symmetric_result.weights.data(), + non_symmetric_result.weights.data() + non_symmetric_result.n_edges); - ASSERT_TRUE(raft::match(2 * prims_result, parallel_mst_result, + ASSERT_TRUE(raft::match(2 * prims_result, symmetric_sum, + raft::CompareApprox(0.1))); + ASSERT_TRUE(raft::match(prims_result, non_symmetric_sum, raft::CompareApprox(0.1))); }