Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Loss of Precision in MST weight alteration #223

Merged
merged 6 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions cpp/include/raft/sparse/hierarchy/detail/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void connect_knn_graph(const raft::handle_t &handle, const value_t *X,

// On the second call, we hand the MST the original colors
// and the new set of edges and let it restart the optimization process
auto new_mst = raft::mst::mst<value_idx, value_idx, value_t>(
auto new_mst = raft::mst::mst<value_idx, value_idx, value_t, double>(
afender marked this conversation as resolved.
Show resolved Hide resolved
handle, indptr2.data(), connected_edges.cols(), connected_edges.vals(), m,
connected_edges.nnz, color, stream, false, false);

Expand Down Expand Up @@ -164,7 +164,7 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X,
rmm::device_uvector<value_idx> color(m, stream);

// We want to have MST initialize colors on first call.
auto mst_coo = raft::mst::mst<value_idx, value_idx, value_t>(
auto mst_coo = raft::mst::mst<value_idx, value_idx, value_t, double>(
handle, indptr, indices, pw_dists, (value_idx)m, nnz, color.data(), stream,
false, true);

Expand Down Expand Up @@ -201,12 +201,6 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X,
" or increase 'max_iter'",
max_iter);

RAFT_EXPECTS(mst_coo.n_edges == m - 1,
"n_edges should be %d but was %d. This"
"could be an indication of duplicate edges returned from the"
"MST or symmetrization stage.",
m - 1, mst_coo.n_edges);

sort_coo_by_data(mst_coo.src.data(), mst_coo.dst.data(),
mst_coo.weights.data(), mst_coo.n_edges, stream);

Expand Down
32 changes: 17 additions & 15 deletions cpp/include/raft/sparse/mst/detail/mst_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@ namespace raft {
namespace mst {
namespace detail {

template <typename vertex_t, typename edge_t, typename weight_t>
template <typename vertex_t, typename edge_t, typename alteration_t>
__global__ void kernel_min_edge_per_vertex(
const edge_t* offsets, const vertex_t* indices, const weight_t* weights,
const edge_t* offsets, const vertex_t* indices, const alteration_t* weights,
const vertex_t* color, const vertex_t* color_index, edge_t* new_mst_edge,
const bool* mst_edge, weight_t* min_edge_color, const vertex_t v) {
const bool* mst_edge, alteration_t* min_edge_color, const vertex_t v) {
edge_t tid = threadIdx.x + blockIdx.x * blockDim.x;

unsigned warp_id = tid / 32;
unsigned lane_id = tid % 32;

__shared__ edge_t min_edge_index[32];
__shared__ weight_t min_edge_weight[32];
__shared__ alteration_t min_edge_weight[32];
__shared__ vertex_t min_color[32];

min_edge_index[lane_id] = std::numeric_limits<edge_t>::max();
min_edge_weight[lane_id] = std::numeric_limits<weight_t>::max();
min_edge_weight[lane_id] = std::numeric_limits<alteration_t>::max();
min_color[lane_id] = std::numeric_limits<vertex_t>::max();

__syncthreads();
Expand All @@ -61,7 +61,7 @@ __global__ void kernel_min_edge_per_vertex(
// assuming one warp per row
// find min for each thread in warp
for (edge_t e = row_start + lane_id; e < row_end; e += 32) {
weight_t curr_edge_weight = weights[e];
alteration_t curr_edge_weight = weights[e];
vertex_t successor_color_idx = color_index[indices[e]];
vertex_t successor_color = color[successor_color_idx];

Expand Down Expand Up @@ -92,7 +92,7 @@ __global__ void kernel_min_edge_per_vertex(

// min edge may now be found in first thread
if (lane_id == 0) {
if (min_edge_weight[0] != std::numeric_limits<weight_t>::max()) {
if (min_edge_weight[0] != std::numeric_limits<alteration_t>::max()) {
new_mst_edge[warp_id] = min_edge_index[0];

// atomically set min edge per color
Expand All @@ -102,12 +102,13 @@ __global__ void kernel_min_edge_per_vertex(
}
}

template <typename vertex_t, typename edge_t, typename weight_t>
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
__global__ void min_edge_per_supervertex(
const vertex_t* color, const vertex_t* color_index, edge_t* new_mst_edge,
bool* mst_edge, const vertex_t* indices, const weight_t* weights,
const weight_t* altered_weights, vertex_t* temp_src, vertex_t* temp_dst,
weight_t* temp_weights, const weight_t* min_edge_color, const vertex_t v,
const alteration_t* altered_weights, vertex_t* temp_src, vertex_t* temp_dst,
weight_t* temp_weights, const alteration_t* min_edge_color, const vertex_t v,
bool symmetrize_output) {
auto tid = get_1D_idx<vertex_t>();
if (tid < v) {
Expand All @@ -119,7 +120,7 @@ __global__ void min_edge_per_supervertex(
// find minimum edge is same as minimum edge of whole supervertex
// if yes, that is part of mst
if (edge_idx != std::numeric_limits<edge_t>::max()) {
weight_t vertex_weight = altered_weights[edge_idx];
alteration_t vertex_weight = altered_weights[edge_idx];

bool add_edge = false;
if (min_edge_color[vertex_color] == vertex_weight) {
Expand Down Expand Up @@ -281,13 +282,14 @@ __global__ void final_color_indices(const vertex_t v, const vertex_t* color,

// Alterate the weights, make all undirected edge weight unique while keeping Wuv == Wvu
// Consider using curand device API instead of precomputed random_values array
template <typename vertex_t, typename edge_t, typename weight_t>
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
__global__ void alteration_kernel(const vertex_t v, const edge_t e,
const edge_t* offsets,
const vertex_t* indices,
const weight_t* weights, weight_t max,
weight_t* random_values,
weight_t* altered_weights) {
const weight_t* weights, alteration_t max,
alteration_t* random_values,
alteration_t* altered_weights) {
auto row = get_1D_idx<vertex_t>();
if (row < v) {
auto row_begin = offsets[row];
Expand Down
81 changes: 52 additions & 29 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
#include "mst_kernels.cuh"
#include "utils.cuh"

#include <raft/cudart_utils.h>
#include <rmm/device_buffer.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
Expand Down Expand Up @@ -50,8 +54,9 @@ inline curandStatus_t curand_generate_uniformX(curandGenerator_t generator,
return curandGenerateUniformDouble(generator, outputPtr, n);
}

template <typename vertex_t, typename edge_t, typename weight_t>
MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
MST_solver<vertex_t, edge_t, weight_t, alteration_t>::MST_solver(
const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_, const vertex_t v_,
const edge_t e_, vertex_t* color_, cudaStream_t stream_,
Expand Down Expand Up @@ -93,9 +98,10 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
thrust::sequence(policy, next_color.begin(), next_color.end(), 0);
}

template <typename vertex_t, typename edge_t, typename weight_t>
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
raft::Graph_COO<vertex_t, edge_t, weight_t>
MST_solver<vertex_t, edge_t, weight_t>::solve() {
MST_solver<vertex_t, edge_t, weight_t, alteration_t>::solve() {
RAFT_EXPECTS(v > 0, "0 vertices");
RAFT_EXPECTS(e > 0, "0 edges");
RAFT_EXPECTS(offsets != nullptr, "Null offsets.");
Expand All @@ -116,7 +122,9 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
timer0 = duration_us(stop - start);
#endif

Graph_COO<vertex_t, edge_t, weight_t> mst_result(2 * v - 2, stream);
auto max_mst_edges = symmetrize_output ? 2 * v - 2 : v - 1;

Graph_COO<vertex_t, edge_t, weight_t> mst_result(max_mst_edges, stream);

// Boruvka original formulation says "while more than 1 supervertex remains"
// Here we adjust it to support disconnected components (spanning forest)
Expand Down Expand Up @@ -152,7 +160,12 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
timer3 += duration_us(stop - start);
#endif

if (prev_mst_edge_count[0] == mst_edge_count[0]) {
auto curr_mst_edge_count = mst_edge_count[0];
RAFT_EXPECTS(curr_mst_edge_count <= max_mst_edges,
"Number of edges found by MST is invalid. This may be due to "
"loss in precision. Try increasing precision of weights.");

if (curr_mst_edge_count == prev_mst_edge_count[0]) {
#ifdef MST_TIME
std::cout << "Iterations: " << i << std::endl;
std::cout << timer0 << "," << timer1 << "," << timer2 << "," << timer3
Expand Down Expand Up @@ -210,8 +223,10 @@ struct alteration_functor {
};

// Compute the uper bound for the alteration
template <typename vertex_t, typename edge_t, typename weight_t>
weight_t MST_solver<vertex_t, edge_t, weight_t>::alteration_max() {
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
alteration_t
MST_solver<vertex_t, edge_t, weight_t, alteration_t>::alteration_max() {
auto policy = rmm::exec_policy(rmm::cuda_stream_view{stream});
rmm::device_vector<weight_t> tmp(e);
thrust::device_ptr<const weight_t> weights_ptr(weights);
Expand All @@ -231,21 +246,22 @@ weight_t MST_solver<vertex_t, edge_t, weight_t>::alteration_max() {
auto max =
thrust::transform_reduce(policy, begin, end, alteration_functor<weight_t>(),
init, thrust::minimum<weight_t>());
return max / static_cast<weight_t>(2);
return max / static_cast<alteration_t>(2);
}

// Compute the alteration to make all undirected edge weight unique
// Preserves weights order
template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::alteration() {
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
void MST_solver<vertex_t, edge_t, weight_t, alteration_t>::alteration() {
auto nthreads = std::min(v, max_threads);
auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);

// maximum alteration that does not change realtive weights order
weight_t max = alteration_max();
alteration_t max = alteration_max();

// pool of rand values
rmm::device_vector<weight_t> rand_values(v);
rmm::device_vector<alteration_t> rand_values(v);

// Random number generator
curandGenerator_t randGen;
Expand All @@ -267,9 +283,10 @@ void MST_solver<vertex_t, edge_t, weight_t>::alteration() {
}

// updates colors of vertices by propagating the lower color to the higher
template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
vertex_t* mst_dst) {
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
void MST_solver<vertex_t, edge_t, weight_t, alteration_t>::label_prop(
vertex_t* mst_src, vertex_t* mst_dst) {
// update the colors of both ends its until there is no change in colors
thrust::host_vector<edge_t> curr_mst_edge_count = mst_edge_count;

Expand Down Expand Up @@ -306,11 +323,13 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
}

// Finds the minimum edge from each vertex to the lowest color
template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_vertex() {
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
void MST_solver<vertex_t, edge_t, weight_t,
alteration_t>::min_edge_per_vertex() {
auto policy = rmm::exec_policy(rmm::cuda_stream_view{stream});
thrust::fill(policy, min_edge_color.begin(), min_edge_color.end(),
std::numeric_limits<weight_t>::max());
std::numeric_limits<alteration_t>::max());
thrust::fill(policy, new_mst_edge.begin(), new_mst_edge.end(),
std::numeric_limits<weight_t>::max());

Expand All @@ -319,17 +338,19 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_vertex() {
vertex_t* color_ptr = color.data().get();
edge_t* new_mst_edge_ptr = new_mst_edge.data().get();
bool* mst_edge_ptr = mst_edge.data().get();
weight_t* min_edge_color_ptr = min_edge_color.data().get();
weight_t* altered_weights_ptr = altered_weights.data().get();
alteration_t* min_edge_color_ptr = min_edge_color.data().get();
alteration_t* altered_weights_ptr = altered_weights.data().get();

detail::kernel_min_edge_per_vertex<<<v, n_threads, 0, stream>>>(
offsets, indices, altered_weights_ptr, color_ptr, color_index,
new_mst_edge_ptr, mst_edge_ptr, min_edge_color_ptr, v);
}

// Finds the minimum edge from each supervertex to the lowest color
template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
void MST_solver<vertex_t, edge_t, weight_t,
alteration_t>::min_edge_per_supervertex() {
auto nthreads = std::min(v, max_threads);
auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);

Expand All @@ -340,8 +361,8 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
vertex_t* color_ptr = color.data().get();
edge_t* new_mst_edge_ptr = new_mst_edge.data().get();
bool* mst_edge_ptr = mst_edge.data().get();
weight_t* min_edge_color_ptr = min_edge_color.data().get();
weight_t* altered_weights_ptr = altered_weights.data().get();
alteration_t* min_edge_color_ptr = min_edge_color.data().get();
alteration_t* altered_weights_ptr = altered_weights.data().get();
vertex_t* temp_src_ptr = temp_src.data().get();
vertex_t* temp_dst_ptr = temp_dst.data().get();
weight_t* temp_weights_ptr = temp_weights.data().get();
Expand All @@ -361,8 +382,9 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
}
}

template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::check_termination() {
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
void MST_solver<vertex_t, edge_t, weight_t, alteration_t>::check_termination() {
vertex_t nthreads = std::min(2 * v, (vertex_t)max_threads);
vertex_t nblocks =
std::min((2 * v + nthreads - 1) / nthreads, (vertex_t)max_blocks);
Expand All @@ -385,8 +407,9 @@ struct new_edges_functor {
}
};

template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::append_src_dst_pair(
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
void MST_solver<vertex_t, edge_t, weight_t, alteration_t>::append_src_dst_pair(
vertex_t* mst_src, vertex_t* mst_dst, weight_t* mst_weights) {
auto policy = rmm::exec_policy(rmm::cuda_stream_view{stream});

Expand Down
5 changes: 3 additions & 2 deletions cpp/include/raft/sparse/mst/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
namespace raft {
namespace mst {

template <typename vertex_t, typename edge_t, typename weight_t>
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t = weight_t>
raft::Graph_COO<vertex_t, edge_t, weight_t> mst(
const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices,
weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color,
cudaStream_t stream, bool symmetrize_output = true,
bool initialize_colors = true, int iterations = 0) {
MST_solver<vertex_t, edge_t, weight_t> mst_solver(
MST_solver<vertex_t, edge_t, weight_t, alteration_t> mst_solver(
handle, offsets, indices, weights, v, e, color, stream, symmetrize_output,
initialize_colors, iterations);
return mst_solver.solve();
Expand Down
12 changes: 7 additions & 5 deletions cpp/include/raft/sparse/mst/mst_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ struct Graph_COO {

namespace mst {

template <typename vertex_t, typename edge_t, typename weight_t>
template <typename vertex_t, typename edge_t, typename weight_t,
typename alteration_t>
class MST_solver {
public:
MST_solver(const raft::handle_t& handle_, const edge_t* offsets_,
Expand Down Expand Up @@ -67,10 +68,11 @@ class MST_solver {
vertex_t sm_count;

vertex_t* color_index; // represent each supervertex as a color
rmm::device_vector<weight_t>
rmm::device_vector<alteration_t>
min_edge_color; // minimum incident edge weight per color
rmm::device_vector<edge_t> new_mst_edge; // new minimum edge per vertex
rmm::device_vector<weight_t> altered_weights; // weights to be used for mst
rmm::device_vector<edge_t> new_mst_edge; // new minimum edge per vertex
rmm::device_vector<alteration_t>
altered_weights; // weights to be used for mst
rmm::device_vector<edge_t>
mst_edge_count; // total number of edges added after every iteration
rmm::device_vector<edge_t>
Expand All @@ -90,7 +92,7 @@ class MST_solver {
void min_edge_per_supervertex();
void check_termination();
void alteration();
weight_t alteration_max();
alteration_t alteration_max();
void append_src_dst_pair(vertex_t* mst_src, vertex_t* mst_dst,
weight_t* mst_weights);
};
Expand Down
Loading