Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass pre-computed colors to MST #154

Merged
merged 4 commits into from
Mar 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,16 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_, const vertex_t v_,
const edge_t e_, vertex_t* color_, cudaStream_t stream_,
bool symmetrize_output_)
bool symmetrize_output_, bool initialize_colors_, int iterations_)
: handle(handle_),
offsets(offsets_),
indices(indices_),
weights(weights_),
altered_weights(e_),
v(v_),
e(e_),
color(color_),
color_index(v_),
color_index(color_),
color(v_),
next_color(v_),
min_edge_color(v_),
new_mst_edge(v_),
Expand All @@ -72,16 +72,21 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
mst_edge_count(1, 0),
prev_mst_edge_count(1, 0),
stream(stream_),
symmetrize_output(symmetrize_output_) {
symmetrize_output(symmetrize_output_),
initialize_colors(initialize_colors_),
iterations(iterations_) {
max_blocks = handle_.get_device_properties().maxGridSize[0];
max_threads = handle_.get_device_properties().maxThreadsPerBlock;
sm_count = handle_.get_device_properties().multiProcessorCount;

//Initially, color holds the vertex id as color
auto policy = rmm::exec_policy(stream);
thrust::sequence(policy->on(stream), color, color + v, 0);
thrust::sequence(policy->on(stream), color_index.begin(), color_index.end(),
0);
if (initialize_colors_) {
thrust::sequence(policy->on(stream), color.begin(), color.end(), 0);
thrust::sequence(policy->on(stream), color_index, color_index + v, 0);
} else {
raft::copy(color.data().get(), color_index, v, stream);
}
thrust::sequence(policy->on(stream), next_color.begin(), next_color.end(), 0);
}

Expand Down Expand Up @@ -113,7 +118,8 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
// Boruvka original formulation says "while more than 1 supervertex remains"
// Here we adjust it to support disconnected components (spanning forest)
// track completion with mst_edge_found status and v as upper bound
for (auto i = 0; i < v; i++) {
auto mst_iterations = iterations > 0 ? iterations : v;
for (auto i = 0; i < mst_iterations; i++) {
#ifdef MST_TIME
start = Clock::now();
#endif
Expand Down Expand Up @@ -184,6 +190,8 @@ MST_solver<vertex_t, edge_t, weight_t>::solve() {
mst_result.dst.resize(mst_result.n_edges, stream);
mst_result.weights.resize(mst_result.n_edges, stream);

// raft::print_device_vector("Colors before sending: ", color_index, 7, std::cout);

return mst_result;
}

Expand Down Expand Up @@ -271,7 +279,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
rmm::device_vector<bool> done(1, false);

edge_t* new_mst_edge_ptr = new_mst_edge.data().get();
vertex_t* color_index_ptr = color_index.data().get();
vertex_t* color_ptr = color.data().get();
vertex_t* next_color_ptr = next_color.data().get();

bool* done_ptr = done.data().get();
Expand All @@ -281,16 +289,16 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
done[0] = true;

detail::min_pair_colors<<<min_pair_nblocks, min_pair_nthreads, 0, stream>>>(
v, indices, new_mst_edge_ptr, color, color_index_ptr, next_color_ptr);
v, indices, new_mst_edge_ptr, color_ptr, color_index, next_color_ptr);

detail::update_colors<<<min_pair_nblocks, min_pair_nthreads, 0, stream>>>(
v, color, color_index_ptr, next_color_ptr, done_ptr);
v, color_ptr, color_index, next_color_ptr, done_ptr);
i++;
}

detail::
final_color_indices<<<min_pair_nblocks, min_pair_nthreads, 0, stream>>>(
v, color, color_index_ptr);
v, color_ptr, color_index);
#ifdef MST_TIME
std::cout << "Label prop iterations: " << i << std::endl;
#endif
Expand All @@ -304,14 +312,14 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_vertex() {

int n_threads = 32;

vertex_t* color_index_ptr = color_index.data().get();
vertex_t* color_ptr = color.data().get();
edge_t* new_mst_edge_ptr = new_mst_edge.data().get();
bool* mst_edge_ptr = mst_edge.data().get();
weight_t* min_edge_color_ptr = min_edge_color.data().get();
weight_t* altered_weights_ptr = altered_weights.data().get();

detail::kernel_min_edge_per_vertex<<<v, n_threads, 0, stream>>>(
offsets, indices, altered_weights_ptr, color, color_index_ptr,
offsets, indices, altered_weights_ptr, color_ptr, color_index,
new_mst_edge_ptr, mst_edge_ptr, min_edge_color_ptr, v);
}

Expand All @@ -324,7 +332,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
thrust::fill(temp_src.begin(), temp_src.end(),
std::numeric_limits<vertex_t>::max());

vertex_t* color_index_ptr = color_index.data().get();
vertex_t* color_ptr = color.data().get();
edge_t* new_mst_edge_ptr = new_mst_edge.data().get();
bool* mst_edge_ptr = mst_edge.data().get();
weight_t* min_edge_color_ptr = min_edge_color.data().get();
Expand All @@ -334,7 +342,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
weight_t* temp_weights_ptr = temp_weights.data().get();

detail::min_edge_per_supervertex<<<nblocks, nthreads, 0, stream>>>(
color, color_index_ptr, new_mst_edge_ptr, mst_edge_ptr, indices, weights,
color_ptr, color_index, new_mst_edge_ptr, mst_edge_ptr, indices, weights,
altered_weights_ptr, temp_src_ptr, temp_dst_ptr, temp_weights_ptr,
min_edge_color_ptr, v, symmetrize_output);

Expand Down
6 changes: 4 additions & 2 deletions cpp/include/raft/sparse/mst/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ template <typename vertex_t, typename edge_t, typename weight_t>
raft::Graph_COO<vertex_t, edge_t, weight_t> mst(
const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices,
weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color,
cudaStream_t stream, bool symmetrize_output = true) {
cudaStream_t stream, bool symmetrize_output = true,
bool initialize_colors = true, int iterations = 0) {
MST_solver<vertex_t, edge_t, weight_t> mst_solver(
handle, offsets, indices, weights, v, e, color, stream, symmetrize_output);
handle, offsets, indices, weights, v, e, color, stream, symmetrize_output,
initialize_colors, iterations);
return mst_solver.solve();
}

Expand Down
12 changes: 6 additions & 6 deletions cpp/include/raft/sparse/mst/mst_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class MST_solver {
MST_solver(const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_,
const vertex_t v_, const edge_t e_, vertex_t* color_,
cudaStream_t stream_, bool symmetrize_output_);
cudaStream_t stream_, bool symmetrize_output_,
bool initialize_colors_, int iterations_);

raft::Graph_COO<vertex_t, edge_t, weight_t> solve();

Expand All @@ -51,8 +52,8 @@ class MST_solver {
private:
const raft::handle_t& handle;
cudaStream_t stream;

bool symmetrize_output;
bool symmetrize_output, initialize_colors;
int iterations;

//CSR
const edge_t* offsets;
Expand All @@ -65,7 +66,7 @@ class MST_solver {
vertex_t max_threads;
vertex_t sm_count;

vertex_t* color; // represent each supervertex as a color
vertex_t* color_index; // represent each supervertex as a color
rmm::device_vector<weight_t>
min_edge_color; // minimum incident edge weight per color
rmm::device_vector<edge_t> new_mst_edge; // new minimum edge per vertex
Expand All @@ -77,8 +78,7 @@ class MST_solver {
rmm::device_vector<bool>
mst_edge; // mst output - true if the edge belongs in mst
rmm::device_vector<vertex_t> next_color; // next iteration color
rmm::device_vector<vertex_t>
color_index; // index of color that vertex points to
rmm::device_vector<vertex_t> color; // index of color that vertex points to

// new src-dst pairs found per iteration
rmm::device_vector<vertex_t> temp_src;
Expand Down
156 changes: 105 additions & 51 deletions cpp/test/mst.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ struct CSRHost {
std::vector<weight_t> weights;
};

template <typename vertex_t, typename edge_t, typename weight_t>
struct MSTTestInput {
struct CSRHost<vertex_t, edge_t, weight_t> csr_h;
int iterations;
};

template <typename vertex_t, typename edge_t, typename weight_t>
struct CSRDevice {
rmm::device_buffer offsets;
Expand Down Expand Up @@ -110,7 +116,7 @@ weight_t prims(CSRHost<vertex_t, edge_t, weight_t> &csr_h) {

template <typename vertex_t, typename edge_t, typename weight_t>
class MSTTest
: public ::testing::TestWithParam<CSRHost<vertex_t, edge_t, weight_t>> {
: public ::testing::TestWithParam<MSTTestInput<vertex_t, edge_t, weight_t>> {
protected:
std::pair<raft::Graph_COO<vertex_t, edge_t, weight_t>,
raft::Graph_COO<vertex_t, edge_t, weight_t>>
Expand All @@ -130,77 +136,125 @@ class MSTTest

vertex_t *color_ptr = thrust::raw_pointer_cast(color.data());

MST_solver<vertex_t, edge_t, weight_t> symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true);
auto symmetric_result = symmetric_solver.solve();

MST_solver<vertex_t, edge_t, weight_t> non_symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
false);
auto non_symmetric_result = non_symmetric_solver.solve();

std::cout << "number_of_MST_edges: " << symmetric_result.n_edges
<< std::endl;
EXPECT_LE(symmetric_result.n_edges, 2 * v - 2);

return std::make_pair(std::move(symmetric_result),
std::move(non_symmetric_result));
if (iterations == 0) {
MST_solver<vertex_t, edge_t, weight_t> symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true, true, 0);
auto symmetric_result = symmetric_solver.solve();

MST_solver<vertex_t, edge_t, weight_t> non_symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
false, true, 0);
auto non_symmetric_result = non_symmetric_solver.solve();

EXPECT_LE(symmetric_result.n_edges, 2 * v - 2);
EXPECT_LE(non_symmetric_result.n_edges, v - 1);

return std::make_pair(std::move(symmetric_result),
std::move(non_symmetric_result));
} else {
MST_solver<vertex_t, edge_t, weight_t> intermediate_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true, true, iterations);
auto intermediate_result = intermediate_solver.solve();

MST_solver<vertex_t, edge_t, weight_t> symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true, false, 0);
auto symmetric_result = symmetric_solver.solve();

// symmetric_result.n_edges += intermediate_result.n_edges;
auto total_edge_size =
symmetric_result.n_edges + intermediate_result.n_edges;
symmetric_result.src.resize(total_edge_size, handle.get_stream());
symmetric_result.dst.resize(total_edge_size, handle.get_stream());
symmetric_result.weights.resize(total_edge_size, handle.get_stream());

raft::copy(symmetric_result.src.data() + symmetric_result.n_edges,
intermediate_result.src.data(), intermediate_result.n_edges,
handle.get_stream());
raft::copy(symmetric_result.dst.data() + symmetric_result.n_edges,
intermediate_result.dst.data(), intermediate_result.n_edges,
handle.get_stream());
raft::copy(symmetric_result.weights.data() + symmetric_result.n_edges,
intermediate_result.weights.data(),
intermediate_result.n_edges, handle.get_stream());
symmetric_result.n_edges = total_edge_size;

MST_solver<vertex_t, edge_t, weight_t> non_symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
false, true, 0);
auto non_symmetric_result = non_symmetric_solver.solve();

EXPECT_LE(symmetric_result.n_edges, 2 * v - 2);
EXPECT_LE(non_symmetric_result.n_edges, v - 1);

return std::make_pair(std::move(symmetric_result),
std::move(non_symmetric_result));
}
}

void SetUp() override {
csr_h =
::testing::TestWithParam<CSRHost<vertex_t, edge_t, weight_t>>::GetParam();

csr_d.offsets = rmm::device_buffer(csr_h.offsets.data(),
csr_h.offsets.size() * sizeof(edge_t));
csr_d.indices = rmm::device_buffer(csr_h.indices.data(),
csr_h.indices.size() * sizeof(vertex_t));
csr_d.weights = rmm::device_buffer(csr_h.weights.data(),
csr_h.weights.size() * sizeof(weight_t));
mst_input = ::testing::TestWithParam<
MSTTestInput<vertex_t, edge_t, weight_t>>::GetParam();
iterations = mst_input.iterations;

csr_d.offsets =
rmm::device_buffer(mst_input.csr_h.offsets.data(),
mst_input.csr_h.offsets.size() * sizeof(edge_t));
csr_d.indices =
rmm::device_buffer(mst_input.csr_h.indices.data(),
mst_input.csr_h.indices.size() * sizeof(vertex_t));
csr_d.weights =
rmm::device_buffer(mst_input.csr_h.weights.data(),
mst_input.csr_h.weights.size() * sizeof(weight_t));
}

void TearDown() override {}

protected:
CSRHost<vertex_t, edge_t, weight_t> csr_h;
MSTTestInput<vertex_t, edge_t, weight_t> mst_input;
CSRDevice<vertex_t, edge_t, weight_t> csr_d;
rmm::device_vector<bool> mst_edge;
vertex_t v;
edge_t e;
int iterations;

raft::handle_t handle;
};

// connected components tests
// a full MST is produced
const std::vector<CSRHost<int, int, float>> csr_in_h = {
const std::vector<MSTTestInput<int, int, float>> csr_in_h = {
// single iteration
{{0, 3, 5, 7, 8}, {1, 2, 3, 0, 3, 0, 0, 1}, {2, 3, 4, 2, 1, 3, 4, 1}},
{{{0, 3, 5, 7, 8}, {1, 2, 3, 0, 3, 0, 0, 1}, {2, 3, 4, 2, 1, 3, 4, 1}}, 0},

// multiple iterations and cycles
{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{5.0f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 5.0f, 2.0f, 6.0f, 8.0f,
1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}},
{{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{5.0f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 5.0f, 2.0f, 6.0f, 8.0f,
1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}},
1},
// negative weights
{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{-5.0f, -9.0f, -1.0f, 4.0f, -8.0f, -7.0f, -5.0f, -2.0f, -6.0f, -8.0f,
-1.0f, -10.0f, -9.0f, -2.0f, -1.0f, -1.0f, -6.0f, 4.0f, -7.0f, -10.0f}},

// equal weights
{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.2, 0.2, 0.2,
0.1, 0.1, 0.1, 0.2, 0.1, 0.1, 0.2, 0.1, 0.2, 0.1}},

//self loop
{{0, 4, 6, 9, 12, 15, 17, 20},
{0, 4, 5, 6, 3, 6, 2, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{0.5f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 0.5f, 2.0f, 6.0f, 8.0f,
1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}},
};
{{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{-5.0f, -9.0f, -1.0f, 4.0f, -8.0f, -7.0f, -5.0f, -2.0f, -6.0f, -8.0f,
-1.0f, -10.0f, -9.0f, -2.0f, -1.0f, -1.0f, -6.0f, 4.0f, -7.0f, -10.0f}},
0},

// // equal weights
{{{0, 4, 6, 9, 12, 15, 17, 20},
{2, 4, 5, 6, 3, 6, 0, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.2, 0.2, 0.2,
0.1, 0.1, 0.1, 0.2, 0.1, 0.1, 0.2, 0.1, 0.2, 0.1}},
0},

// //self loop
{{{0, 4, 6, 9, 12, 15, 17, 20},
{0, 4, 5, 6, 3, 6, 2, 4, 5, 1, 4, 6, 0, 2, 3, 0, 2, 0, 1, 3},
{0.5f, 9.0f, 1.0f, 4.0f, 8.0f, 7.0f, 0.5f, 2.0f, 6.0f, 8.0f,
1.0f, 10.0f, 9.0f, 2.0f, 1.0f, 1.0f, 6.0f, 4.0f, 7.0f, 10.0f}},
0}};

// disconnected
const std::vector<CSRHost<int, int, float>> csr_in4_h = {
Expand All @@ -224,7 +278,7 @@ TEST_P(MSTTestSequential, Sequential) {

// do assertions here
// in this case, running sequential MST
auto prims_result = prims(csr_h);
auto prims_result = prims(mst_input.csr_h);

auto symmetric_sum =
thrust::reduce(thrust::device, symmetric_result.weights.data(),
Expand Down