Skip to content

Commit

Permalink
Pass pre-computed colors to MST (#154)
Browse files Browse the repository at this point in the history
closes #145

Authors:
  - Divye Gala (@divyegala)

Approvers:
  - Corey J. Nolet (@cjnolet)

URL: #154
  • Loading branch information
divyegala authored Mar 4, 2021
1 parent b055cf8 commit 5b2240c
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 75 deletions.
40 changes: 24 additions & 16 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,16 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_, const vertex_t v_,
const edge_t e_, vertex_t* color_, cudaStream_t stream_,
bool symmetrize_output_)
bool symmetrize_output_, bool initialize_colors_, int iterations_)
: handle(handle_),
offsets(offsets_),
indices(indices_),
weights(weights_),
altered_weights(e_),
v(v_),
e(e_),
color(color_),
color_index(v_),
color_index(color_),
color(v_),
next_color(v_),
min_edge_color(v_),
new_mst_edge(v_),
Expand All @@ -72,16 +72,21 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
mst_edge_count(1, 0),
prev_mst_edge_count(1, 0),
stream(stream_),
symmetrize_output(symmetrize_output_) {
symmetrize_output(symmetrize_output_),
initialize_colors(initialize_colors_),
iterations(iterations_) {
max_blocks = handle_.get_device_properties().maxGridSize[0];
max_threads = handle_.get_device_properties().maxThreadsPerBlock;
sm_count = handle_.get_device_properties().multiProcessorCount;

// Initially, color holds the vertex id as color, unless pre-computed colors are supplied by the caller
auto policy = rmm::exec_policy(stream);
thrust::sequence(policy->on(stream), color, color + v, 0);
thrust::sequence(policy->on(stream), color_index.begin(), color_index.end(),
0);
if (initialize_colors_) {
thrust::sequence(policy->on(stream), color.begin(), color.end(), 0);
thrust::sequence(policy->on(stream), color_index, color_index + v, 0);
} else {
raft::copy(color.data().get(), color_index, v, stream);
}
thrust::sequence(policy->on(stream), next_color.begin(), next_color.end(), 0);
}

Expand Down Expand Up @@ -113,7 +118,8 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
// Boruvka original formulation says "while more than 1 supervertex remains"
// Here we adjust it to support disconnected components (spanning forest)
// track completion with mst_edge_found status and v as upper bound
for (auto i = 0; i < v; i++) {
auto mst_iterations = iterations > 0 ? iterations : v;
for (auto i = 0; i < mst_iterations; i++) {
#ifdef MST_TIME
start = Clock::now();
#endif
Expand Down Expand Up @@ -184,6 +190,8 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
mst_result.dst.resize(mst_result.n_edges, stream);
mst_result.weights.resize(mst_result.n_edges, stream);

// raft::print_device_vector("Colors before sending: ", color_index, 7, std::cout);

return mst_result;
}

Expand Down Expand Up @@ -271,7 +279,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
rmm::device_vector<bool> done(1, false);

edge_t* new_mst_edge_ptr = new_mst_edge.data().get();
vertex_t* color_index_ptr = color_index.data().get();
vertex_t* color_ptr = color.data().get();
vertex_t* next_color_ptr = next_color.data().get();

bool* done_ptr = done.data().get();
Expand All @@ -281,16 +289,16 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
done[0] = true;

detail::min_pair_colors<<<min_pair_nblocks, min_pair_nthreads, 0, stream>>>(
v, indices, new_mst_edge_ptr, color, color_index_ptr, next_color_ptr);
v, indices, new_mst_edge_ptr, color_ptr, color_index, next_color_ptr);

detail::update_colors<<<min_pair_nblocks, min_pair_nthreads, 0, stream>>>(
v, color, color_index_ptr, next_color_ptr, done_ptr);
v, color_ptr, color_index, next_color_ptr, done_ptr);
i++;
}

detail::
final_color_indices<<<min_pair_nblocks, min_pair_nthreads, 0, stream>>>(
v, color, color_index_ptr);
v, color_ptr, color_index);
#ifdef MST_TIME
std::cout << "Label prop iterations: " << i << std::endl;
#endif
Expand All @@ -304,14 +312,14 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_vertex() {

int n_threads = 32;

vertex_t* color_index_ptr = color_index.data().get();
vertex_t* color_ptr = color.data().get();
edge_t* new_mst_edge_ptr = new_mst_edge.data().get();
bool* mst_edge_ptr = mst_edge.data().get();
weight_t* min_edge_color_ptr = min_edge_color.data().get();
weight_t* altered_weights_ptr = altered_weights.data().get();

detail::kernel_min_edge_per_vertex<<<v, n_threads, 0, stream>>>(
offsets, indices, altered_weights_ptr, color, color_index_ptr,
offsets, indices, altered_weights_ptr, color_ptr, color_index,
new_mst_edge_ptr, mst_edge_ptr, min_edge_color_ptr, v);
}

Expand All @@ -324,7 +332,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
thrust::fill(temp_src.begin(), temp_src.end(),
std::numeric_limits<vertex_t>::max());

vertex_t* color_index_ptr = color_index.data().get();
vertex_t* color_ptr = color.data().get();
edge_t* new_mst_edge_ptr = new_mst_edge.data().get();
bool* mst_edge_ptr = mst_edge.data().get();
weight_t* min_edge_color_ptr = min_edge_color.data().get();
Expand All @@ -334,7 +342,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
weight_t* temp_weights_ptr = temp_weights.data().get();

detail::min_edge_per_supervertex<<<nblocks, nthreads, 0, stream>>>(
color, color_index_ptr, new_mst_edge_ptr, mst_edge_ptr, indices, weights,
color_ptr, color_index, new_mst_edge_ptr, mst_edge_ptr, indices, weights,
altered_weights_ptr, temp_src_ptr, temp_dst_ptr, temp_weights_ptr,
min_edge_color_ptr, v, symmetrize_output);

Expand Down
6 changes: 4 additions & 2 deletions cpp/include/raft/sparse/mst/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ template <typename vertex_t, typename edge_t, typename weight_t>
raft::Graph_COO<vertex_t, edge_t, weight_t> mst(
const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices,
weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color,
cudaStream_t stream, bool symmetrize_output = true) {
cudaStream_t stream, bool symmetrize_output = true,
bool initialize_colors = true, int iterations = 0) {
MST_solver<vertex_t, edge_t, weight_t> mst_solver(
handle, offsets, indices, weights, v, e, color, stream, symmetrize_output);
handle, offsets, indices, weights, v, e, color, stream, symmetrize_output,
initialize_colors, iterations);
return mst_solver.solve();
}

Expand Down
12 changes: 6 additions & 6 deletions cpp/include/raft/sparse/mst/mst_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class MST_solver {
MST_solver(const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_,
const vertex_t v_, const edge_t e_, vertex_t* color_,
cudaStream_t stream_, bool symmetrize_output_);
cudaStream_t stream_, bool symmetrize_output_,
bool initialize_colors_, int iterations_);

raft::Graph_COO<vertex_t, edge_t, weight_t> solve();

Expand All @@ -51,8 +52,8 @@ class MST_solver {
private:
const raft::handle_t& handle;
cudaStream_t stream;

bool symmetrize_output;
bool symmetrize_output, initialize_colors;
int iterations;

//CSR
const edge_t* offsets;
Expand All @@ -65,7 +66,7 @@ class MST_solver {
vertex_t max_threads;
vertex_t sm_count;

vertex_t* color; // represent each supervertex as a color
vertex_t* color_index; // index of color that vertex points to (caller-provided buffer)
rmm::device_vector<weight_t>
min_edge_color; // minimum incident edge weight per color
rmm::device_vector<edge_t> new_mst_edge; // new minimum edge per vertex
Expand All @@ -77,8 +78,7 @@ class MST_solver {
rmm::device_vector<bool>
mst_edge; // mst output - true if the edge belongs in mst
rmm::device_vector<vertex_t> next_color; // next iteration color
rmm::device_vector<vertex_t>
color_index; // index of color that vertex points to
rmm::device_vector<vertex_t> color; // represent each supervertex as a color

// new src-dst pairs found per iteration
rmm::device_vector<vertex_t> temp_src;
Expand Down
156 changes: 105 additions & 51 deletions cpp/test/mst.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ struct CSRHost {
std::vector<weight_t> weights;
};

template <typename vertex_t, typename edge_t, typename weight_t>
struct MSTTestInput {
struct CSRHost<vertex_t, edge_t, weight_t> csr_h;
int iterations;
};

template <typename vertex_t, typename edge_t, typename weight_t>
struct CSRDevice {
rmm::device_buffer offsets;
Expand Down Expand Up @@ -110,7 +116,7 @@ weight_t prims(CSRHost<vertex_t, edge_t, weight_t> &csr_h) {

template <typename vertex_t, typename edge_t, typename weight_t>
class MSTTest
: public ::testing::TestWithParam<CSRHost<vertex_t, edge_t, weight_t>> {
: public ::testing::TestWithParam<MSTTestInput<vertex_t, edge_t, weight_t>> {
protected:
std::pair<raft::Graph_COO<vertex_t, edge_t, weight_t>,
raft::Graph_COO<vertex_t, edge_t, weight_t>>
Expand All @@ -130,77 +136,125 @@ class MSTTest

vertex_t *color_ptr = thrust::raw_pointer_cast(color.data());

MST_solver<vertex_t, edge_t, weight_t> symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true);
auto symmetric_result = symmetric_solver.solve();

MST_solver<vertex_t, edge_t, weight_t> non_symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
false);
auto non_symmetric_result = non_symmetric_solver.solve();

std::cout << "number_of_MST_edges: " << symmetric_result.n_edges
<< std::endl;
EXPECT_LE(symmetric_result.n_edges, 2 * v - 2);

return std::make_pair(std::move(symmetric_result),
std::move(non_symmetric_result));
if (iterations == 0) {
MST_solver<vertex_t, edge_t, weight_t> symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true, true, 0);
auto symmetric_result = symmetric_solver.solve();

MST_solver<vertex_t, edge_t, weight_t> non_symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
false, true, 0);
auto non_symmetric_result = non_symmetric_solver.solve();

EXPECT_LE(symmetric_result.n_edges, 2 * v - 2);
EXPECT_LE(non_symmetric_result.n_edges, v - 1);

return std::make_pair(std::move(symmetric_result),
std::move(non_symmetric_result));
} else {
MST_solver<vertex_t, edge_t, weight_t> intermediate_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true, true, iterations);
auto intermediate_result = intermediate_solver.solve();

MST_solver<vertex_t, edge_t, weight_t> symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true, false, 0);
auto symmetric_result = symmetric_solver.solve();

// symmetric_result.n_edges += intermediate_result.n_edges;
auto total_edge_size =
symmetric_result.n_edges + intermediate_result.n_edges;
symmetric_result.src.resize(total_edge_size, handle.get_stream());
symmetric_result.dst.resize(total_edge_size, handle.get_stream());
symmetric_result.weights.resize(total_edge_size, handle.get_stream());

raft::copy(symmetric_result.src.data() + symmetric_result.n_edges,
intermediate_result.src.data(), intermediate_result.n_edges,
handle.get_stream());
raft::copy(symmetric_result.dst.data() + symmetric_result.n_edges,
intermediate_result.dst.data(), intermediate_result.n_edges,
handle.get_stream());
raft::copy(symmetric_result.weights.data() + symmetric_result.n_edges,
intermediate_result.weights.data(),
intermediate_result.n_edges, handle.get_stream());
symmetric_result.n_edges = total_edge_size;

MST_solver<vertex_t, edge_t, weight_t> non_symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
false, true, 0);
auto non_symmetric_result = non_symmetric_solver.solve();

EXPECT_LE(symmetric_result.n_edges, 2 * v - 2);
EXPECT_LE(non_symmetric_result.n_edges, v - 1);

return std::make_pair(std::move(symmetric_result),
std::move(non_symmetric_result));
}
}

void SetUp() override {
csr_h =
::testing::TestWithParam<CSRHost<vertex_t, edge_t, weight_t>>::GetParam();

csr_d.offsets = rmm::device_buffer(csr_h.offsets.data(),
csr_h.offsets.size() * sizeof(edge_t));
csr_d.indices = rmm::device_buffer(csr_h.indices.data(),
csr_h.indices.size() * sizeof(vertex_t));
csr_d.weights = rmm::device_buffer(csr_h.weights.data(),
csr_h.weights.size() * sizeof(weight_t));
mst_input = ::testing::TestWithParam<
MSTTestInput<vertex_t, edge_t, weight_t>>::GetParam();
iterations = mst_input.iterations;

csr_d.offsets =
rmm::device_buffer(mst_input.csr_h.offsets.data(),
mst_input.csr_h.offsets.size() * sizeof(edge_t));
csr_d.indices =
rmm::device_buffer(mst_input.csr_h.indices.data(),
mst_input.csr_h.indices.size() * sizeof(vertex_t));
csr_d.weights =
rmm::device_buffer(mst_input.csr_h.weights.data(),
mst_input.csr_h.weights.size() * sizeof(weight_t));
}

void TearDown() override {}

protected:
CSRHost<vertex_t, edge_t, weight_t> csr_h;
MSTTestInput<vertex_t, edge_t, weight_t> mst_input;
CSRDevice<vertex_t, edge_t, weight_t> csr_d;
rmm::device_vector<bool> mst_edge;
vertex_t v;
edge_t e;
int iterations;

raft::handle_t handle;
};

// connected components tests
// a full MST is produced
const std::vector<CSRHost<int, int, float>> csr_in_h = {
const std::vector<MSTTestInput<int, int, float>> csr_in_h = {
// single iteration
{{0, 3, 5, 7, 8}, {1, 2, 3, 0, 3, 0, 0, 1}, {2, 3, 4, 2, 1, 3, 4, 1}},
{{{0, 3, 5, 7, 8}, {1, 2, 3, 0, 3, 0, 0, 1}, {2, 3, 4, 2, 1, 3, 4, 1}}, 0},

// multiple iterations and cycles
{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{5.0f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 5.0f, 2.0f, 6.0f, 8.0f,
1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}},
{{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{5.0f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 5.0f, 2.0f, 6.0f, 8.0f,
1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}},
1},
// negative weights
{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{-5.0f, -9.0f, -1.0f, 4.0f, -8.0f, -7.0f, -5.0f, -2.0f, -6.0f, -8.0f,
-1.0f, -10.0f, -9.0f, -2.0f, -1.0f, -1.0f, -6.0f, 4.0f, -7.0f, -10.0f}},

// equal weights
{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.2, 0.2, 0.2,
0.1, 0.1, 0.1, 0.2, 0.1, 0.1, 0.2, 0.1, 0.2, 0.1}},

//self loop
{{0, 4, 6, 9, 12, 15, 17, 20},
{0, 4, 5, 6, 3, 6, 2, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{0.5f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 0.5f, 2.0f, 6.0f, 8.0f,
1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}},
};
{{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{-5.0f, -9.0f, -1.0f, 4.0f, -8.0f, -7.0f, -5.0f, -2.0f, -6.0f, -8.0f,
-1.0f, -10.0f, -9.0f, -2.0f, -1.0f, -1.0f, -6.0f, 4.0f, -7.0f, -10.0f}},
0},

// equal weights
{{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.2, 0.2, 0.2,
0.1, 0.1, 0.1, 0.2, 0.1, 0.1, 0.2, 0.1, 0.2, 0.1}},
0},

// self loop
{{{0, 4, 6, 9, 12, 15, 17, 20},
{0, 4, 5, 6, 3, 6, 2, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{0.5f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 0.5f, 2.0f, 6.0f, 8.0f,
1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}},
0}};

// disconnected
const std::vector<CSRHost<int, int, float>> csr_in4_h = {
Expand All @@ -224,7 +278,7 @@ TEST_P(MSTTestSequential, Sequential) {

// do assertions here
// in this case, running sequential MST
auto prims_result = prims(csr_h);
auto prims_result = prims(mst_input.csr_h);

auto symmetric_sum =
thrust::reduce(thrust::device, symmetric_result.weights.data(),
Expand Down

0 comments on commit 5b2240c

Please sign in to comment.