// Patch overview: adds caller-controlled color initialization and an iteration
// cap to the RAFT sparse MST solver. Four files are touched.
//
// cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh, mst_solver.cuh
//   * The MST_solver constructor gains two trailing parameters,
//     `bool initialize_colors_` and `int iterations_`, stored in the new
//     members `initialize_colors` and `iterations`.
//   * The `color` / `color_index` members trade storage: the caller-supplied
//     `vertex_t* color_` now initializes `color_index`, and `color` becomes the
//     solver-owned device vector constructed with `v_` elements.
//   * With `initialize_colors_` set, both arrays are filled with 0..v-1 via
//     thrust::sequence; otherwise `color` is filled by copying `v` entries out
//     of the caller's buffer with raft::copy.
//   * solve() uses `iterations` as the loop's upper bound when it is > 0 and
//     falls back to `v` otherwise.
//   * Kernel launches keep their argument positions; the first color argument
//     is now `color.data().get()` and the second is the raw `color_index`
//     pointer.
//
// cpp/include/raft/sparse/mst/mst.cuh
//   * mst() accepts and forwards `initialize_colors` (default true) and
//     `iterations` (default 0), so existing call sites keep compiling.
//
// cpp/test/mst.cu
//   * Test parameters are wrapped in MSTTestInput {csr_h, iterations}; only the
//     "multiple iterations and cycles" case sets iterations to 1.
//   * For iterations != 0 the fixture runs a capped solve, then a second solve
//     built with initialize_colors = false on the same color buffer, and
//     appends the first result's src/dst/weights after the second's before
//     updating n_edges.
//   * Both branches now also assert non_symmetric_result.n_edges <= v - 1, and
//     the "number_of_MST_edges" print is removed.
//   * The sequential reference (prims) is fed `mst_input.csr_h`.
//
// NOTE(review): in mst_solver.cuh the trailing comments on `color_index`
//   ("represent each supervertex as a color") and `color` ("index of color
//   that vertex points to") look like they stayed with the old declarations
//   instead of following the renamed identifiers -- confirm against the kernel
//   parameter roles.
// NOTE(review): the added lines include commented-out code
//   (`// raft::print_device_vector(...)`, `// symmetric_result.n_edges += ...`)
//   and doubled comment markers (`// // equal weights`, `// //self loop`);
//   presumably debugging leftovers -- confirm before merging.
// NOTE(review): every solver in the test shares one `color_ptr`; the final
//   non-symmetric solver is built with initialize_colors = true, so it appears
//   to overwrite the caller buffer rather than reuse it -- verify this is the
//   intended coverage.
// NOTE(review): this copy of the patch appears to have lost its line breaks and
//   the contents of angle brackets (e.g. `<<>>`, `std::numeric_limits::max()`),
//   so it is unlikely to apply as-is -- TODO confirm against the original
//   commit.
diff --git a/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh b/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh index d3a82ca711..0235994feb 100644 --- a/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh +++ b/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh @@ -52,7 +52,7 @@ MST_solver::MST_solver( const raft::handle_t& handle_, const edge_t* offsets_, const vertex_t* indices_, const weight_t* weights_, const vertex_t v_, const edge_t e_, vertex_t* color_, cudaStream_t stream_, - bool symmetrize_output_) + bool symmetrize_output_, bool initialize_colors_, int iterations_) : handle(handle_), offsets(offsets_), indices(indices_), @@ -60,8 +60,8 @@ MST_solver::MST_solver( altered_weights(e_), v(v_), e(e_), - color(color_), - color_index(v_), + color_index(color_), + color(v_), next_color(v_), min_edge_color(v_), new_mst_edge(v_), @@ -72,16 +72,21 @@ MST_solver::MST_solver( mst_edge_count(1, 0), prev_mst_edge_count(1, 0), stream(stream_), - symmetrize_output(symmetrize_output_) { + symmetrize_output(symmetrize_output_), + initialize_colors(initialize_colors_), + iterations(iterations_) { max_blocks = handle_.get_device_properties().maxGridSize[0]; max_threads = handle_.get_device_properties().maxThreadsPerBlock; sm_count = handle_.get_device_properties().multiProcessorCount; //Initially, color holds the vertex id as color auto policy = rmm::exec_policy(stream); - thrust::sequence(policy->on(stream), color, color + v, 0); - thrust::sequence(policy->on(stream), color_index.begin(), color_index.end(), - 0); + if (initialize_colors_) { + thrust::sequence(policy->on(stream), color.begin(), color.end(), 0); + thrust::sequence(policy->on(stream), color_index, color_index + v, 0); + } else { + raft::copy(color.data().get(), color_index, v, stream); + } thrust::sequence(policy->on(stream), next_color.begin(), next_color.end(), 0); } @@ -113,7 +118,8 @@ MST_solver::solve() { // Boruvka original formulation says "while more than 1 supervertex remains" // Here 
we adjust it to support disconnected components (spanning forest) // track completion with mst_edge_found status and v as upper bound - for (auto i = 0; i < v; i++) { + auto mst_iterations = iterations > 0 ? iterations : v; + for (auto i = 0; i < mst_iterations; i++) { #ifdef MST_TIME start = Clock::now(); #endif @@ -184,6 +190,8 @@ MST_solver::solve() { mst_result.dst.resize(mst_result.n_edges, stream); mst_result.weights.resize(mst_result.n_edges, stream); + // raft::print_device_vector("Colors before sending: ", color_index, 7, std::cout); + return mst_result; } @@ -271,7 +279,7 @@ void MST_solver::label_prop(vertex_t* mst_src, rmm::device_vector done(1, false); edge_t* new_mst_edge_ptr = new_mst_edge.data().get(); - vertex_t* color_index_ptr = color_index.data().get(); + vertex_t* color_ptr = color.data().get(); vertex_t* next_color_ptr = next_color.data().get(); bool* done_ptr = done.data().get(); @@ -281,16 +289,16 @@ void MST_solver::label_prop(vertex_t* mst_src, done[0] = true; detail::min_pair_colors<<>>( - v, indices, new_mst_edge_ptr, color, color_index_ptr, next_color_ptr); + v, indices, new_mst_edge_ptr, color_ptr, color_index, next_color_ptr); detail::update_colors<<>>( - v, color, color_index_ptr, next_color_ptr, done_ptr); + v, color_ptr, color_index, next_color_ptr, done_ptr); i++; } detail:: final_color_indices<<>>( - v, color, color_index_ptr); + v, color_ptr, color_index); #ifdef MST_TIME std::cout << "Label prop iterations: " << i << std::endl; #endif @@ -304,14 +312,14 @@ void MST_solver::min_edge_per_vertex() { int n_threads = 32; - vertex_t* color_index_ptr = color_index.data().get(); + vertex_t* color_ptr = color.data().get(); edge_t* new_mst_edge_ptr = new_mst_edge.data().get(); bool* mst_edge_ptr = mst_edge.data().get(); weight_t* min_edge_color_ptr = min_edge_color.data().get(); weight_t* altered_weights_ptr = altered_weights.data().get(); detail::kernel_min_edge_per_vertex<<>>( - offsets, indices, altered_weights_ptr, color, 
color_index_ptr, + offsets, indices, altered_weights_ptr, color_ptr, color_index, new_mst_edge_ptr, mst_edge_ptr, min_edge_color_ptr, v); } @@ -324,7 +332,7 @@ void MST_solver::min_edge_per_supervertex() { thrust::fill(temp_src.begin(), temp_src.end(), std::numeric_limits::max()); - vertex_t* color_index_ptr = color_index.data().get(); + vertex_t* color_ptr = color.data().get(); edge_t* new_mst_edge_ptr = new_mst_edge.data().get(); bool* mst_edge_ptr = mst_edge.data().get(); weight_t* min_edge_color_ptr = min_edge_color.data().get(); @@ -334,7 +342,7 @@ void MST_solver::min_edge_per_supervertex() { weight_t* temp_weights_ptr = temp_weights.data().get(); detail::min_edge_per_supervertex<<>>( - color, color_index_ptr, new_mst_edge_ptr, mst_edge_ptr, indices, weights, + color_ptr, color_index, new_mst_edge_ptr, mst_edge_ptr, indices, weights, altered_weights_ptr, temp_src_ptr, temp_dst_ptr, temp_weights_ptr, min_edge_color_ptr, v, symmetrize_output); diff --git a/cpp/include/raft/sparse/mst/mst.cuh b/cpp/include/raft/sparse/mst/mst.cuh index 453fa9f1c1..4685431e7a 100644 --- a/cpp/include/raft/sparse/mst/mst.cuh +++ b/cpp/include/raft/sparse/mst/mst.cuh @@ -26,9 +26,11 @@ template raft::Graph_COO mst( const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices, weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color, - cudaStream_t stream, bool symmetrize_output = true) { + cudaStream_t stream, bool symmetrize_output = true, + bool initialize_colors = true, int iterations = 0) { MST_solver mst_solver( - handle, offsets, indices, weights, v, e, color, stream, symmetrize_output); + handle, offsets, indices, weights, v, e, color, stream, symmetrize_output, + initialize_colors, iterations); return mst_solver.solve(); } diff --git a/cpp/include/raft/sparse/mst/mst_solver.cuh b/cpp/include/raft/sparse/mst/mst_solver.cuh index dedbe06370..af8604cdea 100644 --- a/cpp/include/raft/sparse/mst/mst_solver.cuh +++ 
b/cpp/include/raft/sparse/mst/mst_solver.cuh @@ -42,7 +42,8 @@ class MST_solver { MST_solver(const raft::handle_t& handle_, const edge_t* offsets_, const vertex_t* indices_, const weight_t* weights_, const vertex_t v_, const edge_t e_, vertex_t* color_, - cudaStream_t stream_, bool symmetrize_output_); + cudaStream_t stream_, bool symmetrize_output_, + bool initialize_colors_, int iterations_); raft::Graph_COO solve(); @@ -51,8 +52,8 @@ class MST_solver { private: const raft::handle_t& handle; cudaStream_t stream; - - bool symmetrize_output; + bool symmetrize_output, initialize_colors; + int iterations; //CSR const edge_t* offsets; @@ -65,7 +66,7 @@ class MST_solver { vertex_t max_threads; vertex_t sm_count; - vertex_t* color; // represent each supervertex as a color + vertex_t* color_index; // represent each supervertex as a color rmm::device_vector min_edge_color; // minimum incident edge weight per color rmm::device_vector new_mst_edge; // new minimum edge per vertex @@ -77,8 +78,7 @@ class MST_solver { rmm::device_vector mst_edge; // mst output - true if the edge belongs in mst rmm::device_vector next_color; // next iteration color - rmm::device_vector - color_index; // index of color that vertex points to + rmm::device_vector color; // index of color that vertex points to // new src-dst pairs found per iteration rmm::device_vector temp_src; diff --git a/cpp/test/mst.cu b/cpp/test/mst.cu index 733d56a7b1..3dbbdc40c9 100644 --- a/cpp/test/mst.cu +++ b/cpp/test/mst.cu @@ -36,6 +36,12 @@ struct CSRHost { std::vector weights; }; +template +struct MSTTestInput { + struct CSRHost csr_h; + int iterations; +}; + template struct CSRDevice { rmm::device_buffer offsets; @@ -110,7 +116,7 @@ weight_t prims(CSRHost &csr_h) { template class MSTTest - : public ::testing::TestWithParam> { + : public ::testing::TestWithParam> { protected: std::pair, raft::Graph_COO> @@ -130,77 +136,125 @@ class MSTTest vertex_t *color_ptr = thrust::raw_pointer_cast(color.data()); - MST_solver 
symmetric_solver( - handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), - true); - auto symmetric_result = symmetric_solver.solve(); - - MST_solver non_symmetric_solver( - handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), - false); - auto non_symmetric_result = non_symmetric_solver.solve(); - - std::cout << "number_of_MST_edges: " << symmetric_result.n_edges - << std::endl; - EXPECT_LE(symmetric_result.n_edges, 2 * v - 2); - - return std::make_pair(std::move(symmetric_result), - std::move(non_symmetric_result)); + if (iterations == 0) { + MST_solver symmetric_solver( + handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), + true, true, 0); + auto symmetric_result = symmetric_solver.solve(); + + MST_solver non_symmetric_solver( + handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), + false, true, 0); + auto non_symmetric_result = non_symmetric_solver.solve(); + + EXPECT_LE(symmetric_result.n_edges, 2 * v - 2); + EXPECT_LE(non_symmetric_result.n_edges, v - 1); + + return std::make_pair(std::move(symmetric_result), + std::move(non_symmetric_result)); + } else { + MST_solver intermediate_solver( + handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), + true, true, iterations); + auto intermediate_result = intermediate_solver.solve(); + + MST_solver symmetric_solver( + handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), + true, false, 0); + auto symmetric_result = symmetric_solver.solve(); + + // symmetric_result.n_edges += intermediate_result.n_edges; + auto total_edge_size = + symmetric_result.n_edges + intermediate_result.n_edges; + symmetric_result.src.resize(total_edge_size, handle.get_stream()); + symmetric_result.dst.resize(total_edge_size, handle.get_stream()); + symmetric_result.weights.resize(total_edge_size, handle.get_stream()); + + raft::copy(symmetric_result.src.data() + symmetric_result.n_edges, + intermediate_result.src.data(), 
intermediate_result.n_edges, + handle.get_stream()); + raft::copy(symmetric_result.dst.data() + symmetric_result.n_edges, + intermediate_result.dst.data(), intermediate_result.n_edges, + handle.get_stream()); + raft::copy(symmetric_result.weights.data() + symmetric_result.n_edges, + intermediate_result.weights.data(), + intermediate_result.n_edges, handle.get_stream()); + symmetric_result.n_edges = total_edge_size; + + MST_solver non_symmetric_solver( + handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), + false, true, 0); + auto non_symmetric_result = non_symmetric_solver.solve(); + + EXPECT_LE(symmetric_result.n_edges, 2 * v - 2); + EXPECT_LE(non_symmetric_result.n_edges, v - 1); + + return std::make_pair(std::move(symmetric_result), + std::move(non_symmetric_result)); + } } void SetUp() override { - csr_h = - ::testing::TestWithParam>::GetParam(); - - csr_d.offsets = rmm::device_buffer(csr_h.offsets.data(), - csr_h.offsets.size() * sizeof(edge_t)); - csr_d.indices = rmm::device_buffer(csr_h.indices.data(), - csr_h.indices.size() * sizeof(vertex_t)); - csr_d.weights = rmm::device_buffer(csr_h.weights.data(), - csr_h.weights.size() * sizeof(weight_t)); + mst_input = ::testing::TestWithParam< + MSTTestInput>::GetParam(); + iterations = mst_input.iterations; + + csr_d.offsets = + rmm::device_buffer(mst_input.csr_h.offsets.data(), + mst_input.csr_h.offsets.size() * sizeof(edge_t)); + csr_d.indices = + rmm::device_buffer(mst_input.csr_h.indices.data(), + mst_input.csr_h.indices.size() * sizeof(vertex_t)); + csr_d.weights = + rmm::device_buffer(mst_input.csr_h.weights.data(), + mst_input.csr_h.weights.size() * sizeof(weight_t)); } void TearDown() override {} protected: - CSRHost csr_h; + MSTTestInput mst_input; CSRDevice csr_d; rmm::device_vector mst_edge; vertex_t v; edge_t e; + int iterations; raft::handle_t handle; }; // connected components tests // a full MST is produced -const std::vector> csr_in_h = { +const std::vector> csr_in_h = { // 
single iteration - {{0, 3, 5, 7, 8}, {1, 2, 3, 0, 3, 0, 0, 1}, {2, 3, 4, 2, 1, 3, 4, 1}}, + {{{0, 3, 5, 7, 8}, {1, 2, 3, 0, 3, 0, 0, 1}, {2, 3, 4, 2, 1, 3, 4, 1}}, 0}, // multiple iterations and cycles - {{0, 4, 6, 9, 12, 15, 17, 20}, - {2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3}, - {5.0f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 5.0f, 2.0f, 6.0f, 8.0f, - 1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}}, + {{{0, 4, 6, 9, 12, 15, 17, 20}, + {2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3}, + {5.0f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 5.0f, 2.0f, 6.0f, 8.0f, + 1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}}, + 1}, // negative weights - {{0, 4, 6, 9, 12, 15, 17, 20}, - {2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3}, - {-5.0f, -9.0f, -1.0f, 4.0f, -8.0f, -7.0f, -5.0f, -2.0f, -6.0f, -8.0f, - -1.0f, -10.0f, -9.0f, -2.0f, -1.0f, -1.0f, -6.0f, 4.0f, -7.0f, -10.0f}}, - - // equal weights - {{0, 4, 6, 9, 12, 15, 17, 20}, - {2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3}, - {0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.2, 0.2, 0.2, - 0.1, 0.1, 0.1, 0.2, 0.1, 0.1, 0.2, 0.1, 0.2, 0.1}}, - - //self loop - {{0, 4, 6, 9, 12, 15, 17, 20}, - {0, 4, 5, 6, 3, 6, 2, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3}, - {0.5f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 0.5f, 2.0f, 6.0f, 8.0f, - 1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}}, -}; + {{{0, 4, 6, 9, 12, 15, 17, 20}, + {2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3}, + {-5.0f, -9.0f, -1.0f, 4.0f, -8.0f, -7.0f, -5.0f, -2.0f, -6.0f, -8.0f, + -1.0f, -10.0f, -9.0f, -2.0f, -1.0f, -1.0f, -6.0f, 4.0f, -7.0f, -10.0f}}, + 0}, + + // // equal weights + {{{0, 4, 6, 9, 12, 15, 17, 20}, + {2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3}, + {0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.2, 0.2, 0.2, + 0.1, 0.1, 0.1, 0.2, 0.1, 0.1, 0.2, 0.1, 0.2, 0.1}}, + 0}, + + // //self loop + {{{0, 4, 6, 9, 12, 15, 17, 20}, + {0, 4, 5, 6, 3, 6, 2, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3}, 
+ {0.5f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 0.5f, 2.0f, 6.0f, 8.0f, + 1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}}, + 0}}; // disconnected const std::vector> csr_in4_h = { @@ -224,7 +278,7 @@ TEST_P(MSTTestSequential, Sequential) { // do assertions here // in this case, running sequential MST - auto prims_result = prims(csr_h); + auto prims_result = prims(mst_input.csr_h); auto symmetric_sum = thrust::reduce(thrust::device, symmetric_result.weights.data(),