diff --git a/cpp/include/raft/sparse/hierarchy/detail/mst.cuh b/cpp/include/raft/sparse/hierarchy/detail/mst.cuh index 91c4e1642d..8ffcfe0f2b 100644 --- a/cpp/include/raft/sparse/hierarchy/detail/mst.cuh +++ b/cpp/include/raft/sparse/hierarchy/detail/mst.cuh @@ -116,7 +116,7 @@ void connect_knn_graph(const raft::handle_t &handle, const value_t *X, // On the second call, we hand the MST the original colors // and the new set of edges and let it restart the optimization process - auto new_mst = raft::mst::mst( + auto new_mst = raft::mst::mst( handle, indptr2.data(), connected_edges.cols(), connected_edges.vals(), m, connected_edges.nnz, color, stream, false, false); @@ -164,7 +164,7 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X, rmm::device_uvector color(m, stream); // We want to have MST initialize colors on first call. - auto mst_coo = raft::mst::mst( + auto mst_coo = raft::mst::mst( handle, indptr, indices, pw_dists, (value_idx)m, nnz, color.data(), stream, false, true); @@ -201,12 +201,6 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X, " or increase 'max_iter'", max_iter); - RAFT_EXPECTS(mst_coo.n_edges == m - 1, - "n_edges should be %d but was %d. This" - "could be an indication of duplicate edges returned from the" - "MST or symmetrization stage.", - m - 1, mst_coo.n_edges); - sort_coo_by_data(mst_coo.src.data(), mst_coo.dst.data(), mst_coo.weights.data(), mst_coo.n_edges, stream); diff --git a/cpp/include/raft/sparse/mst/detail/mst_kernels.cuh b/cpp/include/raft/sparse/mst/detail/mst_kernels.cuh index d2f86d6dc8..f0d30b0cb7 100644 --- a/cpp/include/raft/sparse/mst/detail/mst_kernels.cuh +++ b/cpp/include/raft/sparse/mst/detail/mst_kernels.cuh @@ -27,22 +27,22 @@ namespace raft { namespace mst { namespace detail { -template +template __global__ void kernel_min_edge_per_vertex( - const edge_t* offsets, const vertex_t* indices, const weight_t* weights, + const edge_t* offsets, const vertex_t* indices, const alteration_t* weights, const vertex_t* color, const vertex_t* color_index, edge_t* new_mst_edge, - const bool* mst_edge, weight_t* min_edge_color, const vertex_t v) { + const bool* mst_edge, alteration_t* min_edge_color, const vertex_t v) { edge_t tid = threadIdx.x + blockIdx.x * blockDim.x; unsigned warp_id = tid / 32; unsigned lane_id = tid % 32; __shared__ edge_t min_edge_index[32]; - __shared__ weight_t min_edge_weight[32]; + __shared__ alteration_t min_edge_weight[32]; __shared__ vertex_t min_color[32]; min_edge_index[lane_id] = std::numeric_limits::max(); - min_edge_weight[lane_id] = std::numeric_limits::max(); + min_edge_weight[lane_id] = std::numeric_limits::max(); min_color[lane_id] = std::numeric_limits::max(); __syncthreads(); @@ -61,7 +61,7 @@ __global__ void kernel_min_edge_per_vertex( // assuming one warp per row // find min for each thread in warp for (edge_t e = row_start + lane_id; e < row_end; e += 32) { - weight_t curr_edge_weight = weights[e]; + alteration_t curr_edge_weight = weights[e]; vertex_t successor_color_idx = color_index[indices[e]]; vertex_t successor_color = color[successor_color_idx]; @@ -92,7 +92,7 @@ __global__ void kernel_min_edge_per_vertex( // min edge may now be found in first thread if (lane_id == 0) { - if (min_edge_weight[0] != std::numeric_limits::max()) { + if (min_edge_weight[0] != std::numeric_limits::max()) { new_mst_edge[warp_id] = min_edge_index[0]; // atomically set min edge per color @@ -102,12 +102,13 @@ __global__ void kernel_min_edge_per_vertex( } } -template +template __global__ void min_edge_per_supervertex( const vertex_t* color, const vertex_t* color_index, edge_t* new_mst_edge, bool* mst_edge, const vertex_t* indices, const weight_t* weights, - const weight_t* altered_weights, vertex_t* temp_src, vertex_t* temp_dst, - weight_t* temp_weights, const weight_t* min_edge_color, const vertex_t v, + const alteration_t* altered_weights, vertex_t* temp_src, vertex_t* temp_dst, + weight_t* temp_weights, const alteration_t* min_edge_color, const vertex_t v, bool symmetrize_output) { auto tid = get_1D_idx(); if (tid < v) { @@ -119,7 +120,7 @@ __global__ void min_edge_per_supervertex( // find minimum edge is same as minimum edge of whole supervertex // if yes, that is part of mst if (edge_idx != std::numeric_limits::max()) { - weight_t vertex_weight = altered_weights[edge_idx]; + alteration_t vertex_weight = altered_weights[edge_idx]; bool add_edge = false; if (min_edge_color[vertex_color] == vertex_weight) { @@ -281,13 +282,14 @@ __global__ void final_color_indices(const vertex_t v, const vertex_t* color, // Alterate the weights, make all undirected edge weight unique while keeping Wuv == Wvu // Consider using curand device API instead of precomputed random_values array -template +template __global__ void alteration_kernel(const vertex_t v, const edge_t e, const edge_t* offsets, const vertex_t* indices, - const weight_t* weights, weight_t max, - weight_t* random_values, - weight_t* altered_weights) { + const weight_t* weights, alteration_t max, + alteration_t* random_values, + alteration_t* altered_weights) { auto row = get_1D_idx(); if (row < v) { auto row_begin = offsets[row]; diff --git a/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh b/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh index c12c4a8d02..c5ba4fcb4f 100644 --- a/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh +++ b/cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh @@ -22,6 +22,10 @@ #include "mst_kernels.cuh" #include "utils.cuh" +#include +#include +#include + #include #include #include @@ -50,8 +54,9 @@ inline curandStatus_t curand_generate_uniformX(curandGenerator_t generator, return curandGenerateUniformDouble(generator, outputPtr, n); } -template -MST_solver::MST_solver( +template +MST_solver::MST_solver( const raft::handle_t& handle_, const edge_t* offsets_, const vertex_t* indices_, const weight_t* weights_, const vertex_t v_, const edge_t e_, vertex_t* color_, cudaStream_t stream_, @@ -93,9 +98,10 @@ MST_solver::MST_solver( thrust::sequence(policy, next_color.begin(), next_color.end(), 0); } -template +template raft::Graph_COO -MST_solver::solve() { +MST_solver::solve() { RAFT_EXPECTS(v > 0, "0 vertices"); RAFT_EXPECTS(e > 0, "0 edges"); RAFT_EXPECTS(offsets != nullptr, "Null offsets."); @@ -116,7 +122,9 @@ MST_solver::solve() { timer0 = duration_us(stop - start); #endif - Graph_COO mst_result(2 * v - 2, stream); + auto max_mst_edges = symmetrize_output ? 2 * v - 2 : v - 1; + + Graph_COO mst_result(max_mst_edges, stream); // Boruvka original formulation says "while more than 1 supervertex remains" // Here we adjust it to support disconnected components (spanning forest) @@ -152,7 +160,12 @@ MST_solver::solve() { timer3 += duration_us(stop - start); #endif - if (prev_mst_edge_count[0] == mst_edge_count[0]) { + auto curr_mst_edge_count = mst_edge_count[0]; + RAFT_EXPECTS(curr_mst_edge_count <= max_mst_edges, + "Number of edges found by MST is invalid. This may be due to " + "loss in precision. Try increasing precision of weights."); + + if (curr_mst_edge_count == prev_mst_edge_count[0]) { #ifdef MST_TIME std::cout << "Iterations: " << i << std::endl; std::cout << timer0 << "," << timer1 << "," << timer2 << "," << timer3 @@ -210,8 +223,10 @@ struct alteration_functor { }; // Compute the uper bound for the alteration -template -weight_t MST_solver::alteration_max() { +template +alteration_t +MST_solver::alteration_max() { auto policy = rmm::exec_policy(rmm::cuda_stream_view{stream}); rmm::device_vector tmp(e); thrust::device_ptr weights_ptr(weights); @@ -231,21 +246,22 @@ weight_t MST_solver::alteration_max() { auto max = thrust::transform_reduce(policy, begin, end, alteration_functor(), init, thrust::minimum()); - return max / static_cast(2); + return max / static_cast(2); } // Compute the alteration to make all undirected edge weight unique // Preserves weights order -template -void MST_solver::alteration() { +template +void MST_solver::alteration() { auto nthreads = std::min(v, max_threads); auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks); // maximum alteration that does not change realtive weights order - weight_t max = alteration_max(); + alteration_t max = alteration_max(); // pool of rand values - rmm::device_vector rand_values(v); + rmm::device_vector rand_values(v); // Random number generator curandGenerator_t randGen; @@ -267,9 +283,10 @@ void MST_solver::alteration() { } // updates colors of vertices by propagating the lower color to the higher -template -void MST_solver::label_prop(vertex_t* mst_src, - vertex_t* mst_dst) { +template +void MST_solver::label_prop( + vertex_t* mst_src, vertex_t* mst_dst) { // update the colors of both ends its until there is no change in colors thrust::host_vector curr_mst_edge_count = mst_edge_count; @@ -306,11 +323,13 @@ void MST_solver::label_prop(vertex_t* mst_src, } // Finds the minimum edge from each vertex to the lowest color -template -void MST_solver::min_edge_per_vertex() { +template +void MST_solver::min_edge_per_vertex() { auto policy = rmm::exec_policy(rmm::cuda_stream_view{stream}); thrust::fill(policy, min_edge_color.begin(), min_edge_color.end(), - std::numeric_limits::max()); + std::numeric_limits::max()); thrust::fill(policy, new_mst_edge.begin(), new_mst_edge.end(), std::numeric_limits::max()); @@ -319,8 +338,8 @@ void MST_solver::min_edge_per_vertex() { vertex_t* color_ptr = color.data().get(); edge_t* new_mst_edge_ptr = new_mst_edge.data().get(); bool* mst_edge_ptr = mst_edge.data().get(); - weight_t* min_edge_color_ptr = min_edge_color.data().get(); - weight_t* altered_weights_ptr = altered_weights.data().get(); + alteration_t* min_edge_color_ptr = min_edge_color.data().get(); + alteration_t* altered_weights_ptr = altered_weights.data().get(); detail::kernel_min_edge_per_vertex<<>>( offsets, indices, altered_weights_ptr, color_ptr, color_index, @@ -328,8 +347,10 @@ void MST_solver::min_edge_per_vertex() { } // Finds the minimum edge from each supervertex to the lowest color -template -void MST_solver::min_edge_per_supervertex() { +template +void MST_solver::min_edge_per_supervertex() { auto nthreads = std::min(v, max_threads); auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks); @@ -340,8 +361,8 @@ void MST_solver::min_edge_per_supervertex() { vertex_t* color_ptr = color.data().get(); edge_t* new_mst_edge_ptr = new_mst_edge.data().get(); bool* mst_edge_ptr = mst_edge.data().get(); - weight_t* min_edge_color_ptr = min_edge_color.data().get(); - weight_t* altered_weights_ptr = altered_weights.data().get(); + alteration_t* min_edge_color_ptr = min_edge_color.data().get(); + alteration_t* altered_weights_ptr = altered_weights.data().get(); vertex_t* temp_src_ptr = temp_src.data().get(); vertex_t* temp_dst_ptr = temp_dst.data().get(); weight_t* temp_weights_ptr = temp_weights.data().get(); @@ -361,8 +382,9 @@ void MST_solver::min_edge_per_supervertex() { } } -template -void MST_solver::check_termination() { +template +void MST_solver::check_termination() { vertex_t nthreads = std::min(2 * v, (vertex_t)max_threads); vertex_t nblocks = std::min((2 * v + nthreads - 1) / nthreads, (vertex_t)max_blocks); @@ -385,8 +407,9 @@ struct new_edges_functor { } }; -template -void MST_solver::append_src_dst_pair( +template +void MST_solver::append_src_dst_pair( vertex_t* mst_src, vertex_t* mst_dst, weight_t* mst_weights) { auto policy = rmm::exec_policy(rmm::cuda_stream_view{stream}); diff --git a/cpp/include/raft/sparse/mst/mst.cuh b/cpp/include/raft/sparse/mst/mst.cuh index 4685431e7a..10c981445e 100644 --- a/cpp/include/raft/sparse/mst/mst.cuh +++ b/cpp/include/raft/sparse/mst/mst.cuh @@ -22,13 +22,14 @@ namespace raft { namespace mst { -template +template raft::Graph_COO mst( const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices, weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color, cudaStream_t stream, bool symmetrize_output = true, bool initialize_colors = true, int iterations = 0) { - MST_solver mst_solver( + MST_solver mst_solver( handle, offsets, indices, weights, v, e, color, stream, symmetrize_output, initialize_colors, iterations); return mst_solver.solve(); diff --git a/cpp/include/raft/sparse/mst/mst_solver.cuh b/cpp/include/raft/sparse/mst/mst_solver.cuh index 438dc79a49..833882ea0d 100644 --- a/cpp/include/raft/sparse/mst/mst_solver.cuh +++ b/cpp/include/raft/sparse/mst/mst_solver.cuh @@ -36,7 +36,8 @@ struct Graph_COO { namespace mst { -template +template class MST_solver { public: MST_solver(const raft::handle_t& handle_, const edge_t* offsets_, @@ -67,10 +68,11 @@ class MST_solver { vertex_t sm_count; vertex_t* color_index; // represent each supervertex as a color - rmm::device_vector + rmm::device_vector min_edge_color; // minimum incident edge weight per color - rmm::device_vector new_mst_edge; // new minimum edge per vertex - rmm::device_vector altered_weights; // weights to be used for mst + rmm::device_vector new_mst_edge; // new minimum edge per vertex + rmm::device_vector + altered_weights; // weights to be used for mst rmm::device_vector mst_edge_count; // total number of edges added after every iteration rmm::device_vector @@ -90,7 +92,7 @@ class MST_solver { void min_edge_per_supervertex(); void check_termination(); void alteration(); - weight_t alteration_max(); + alteration_t alteration_max(); void append_src_dst_pair(vertex_t* mst_src, vertex_t* mst_dst, weight_t* mst_weights); }; diff --git a/cpp/test/mst.cu b/cpp/test/mst.cu index d32c10d881..215c6f6548 100644 --- a/cpp/test/mst.cu +++ b/cpp/test/mst.cu @@ -136,12 +136,12 @@ class MSTTest vertex_t *color_ptr = thrust::raw_pointer_cast(color.data()); if (iterations == 0) { - MST_solver symmetric_solver( + MST_solver symmetric_solver( handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), true, true, 0); auto symmetric_result = symmetric_solver.solve(); - MST_solver non_symmetric_solver( + MST_solver non_symmetric_solver( handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), false, true, 0); auto non_symmetric_result = non_symmetric_solver.solve(); @@ -152,12 +152,12 @@ class MSTTest return std::make_pair(std::move(symmetric_result), std::move(non_symmetric_result)); } else { - MST_solver intermediate_solver( + MST_solver intermediate_solver( handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), true, true, iterations); auto intermediate_result = intermediate_solver.solve(); - MST_solver symmetric_solver( + MST_solver symmetric_solver( handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), true, false, 0); auto symmetric_result = symmetric_solver.solve(); @@ -180,7 +180,7 @@ class MSTTest intermediate_result.n_edges, handle.get_stream()); symmetric_result.n_edges = total_edge_size; - MST_solver non_symmetric_solver( + MST_solver non_symmetric_solver( handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(), false, true, 0); auto non_symmetric_result = non_symmetric_solver.solve();