Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MST symmetric/non-symmetric output for SLHC #162

Merged
merged 12 commits into from
Feb 24, 2021
55 changes: 41 additions & 14 deletions cpp/include/raft/sparse/mst/detail/mst_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ __global__ void min_edge_per_supervertex(
const vertex_t* color, const vertex_t* color_index, edge_t* new_mst_edge,
bool* mst_edge, const vertex_t* indices, const weight_t* weights,
const weight_t* altered_weights, vertex_t* temp_src, vertex_t* temp_dst,
weight_t* temp_weights, const weight_t* min_edge_color, const vertex_t v) {
weight_t* temp_weights, const weight_t* min_edge_color, const vertex_t v,
bool symmetrize_output) {
auto tid = get_1D_idx<vertex_t>();

if (tid < v) {
vertex_t vertex_color_idx = color_index[tid];
vertex_t vertex_color = color[vertex_color_idx];
Expand All @@ -120,13 +120,38 @@ __global__ void min_edge_per_supervertex(
// if yes, that is part of mst
if (edge_idx != std::numeric_limits<edge_t>::max()) {
weight_t vertex_weight = altered_weights[edge_idx];

bool add_edge = false;
if (min_edge_color[vertex_color] == vertex_weight) {
temp_src[tid] = tid;
temp_dst[tid] = indices[edge_idx];
temp_weights[tid] = weights[edge_idx];
add_edge = true;

auto dst = indices[edge_idx];
if (!symmetrize_output) {
auto dst_edge_idx = new_mst_edge[dst];
auto dst_color = color[color_index[dst]];

// the two vertices picked each other when all of the following hold:
//  - the destination vertex has found an edge
//  - that edge points back to the source vertex
//  - that edge is the minimum edge found for the destination's color
if (dst_edge_idx != std::numeric_limits<edge_t>::max() &&
indices[dst_edge_idx] == tid &&
min_edge_color[dst_color] == altered_weights[dst_edge_idx]) {
if (vertex_color > dst_color) {
add_edge = false;
}
}
}

mst_edge[edge_idx] = true;
} else {
if (add_edge) {
temp_src[tid] = tid;
temp_dst[tid] = dst;
temp_weights[tid] = weights[edge_idx];
mst_edge[edge_idx] = true;
}
}

if (!add_edge) {
new_mst_edge[tid] = std::numeric_limits<edge_t>::max();
}
}
Expand All @@ -138,7 +163,7 @@ __global__ void add_reverse_edge(const edge_t* new_mst_edge,
const vertex_t* indices,
const weight_t* weights, vertex_t* temp_src,
vertex_t* temp_dst, weight_t* temp_weights,
const vertex_t v) {
const vertex_t v, bool symmetrize_output) {
auto tid = get_1D_idx<vertex_t>();

if (tid < v) {
Expand All @@ -155,12 +180,14 @@ __global__ void add_reverse_edge(const edge_t* new_mst_edge,
reverse_needed = true;
} else {
// check what vertex the neighbor vertex picked
vertex_t neighbor_vertex_neighbor = indices[neighbor_edge_idx];

// if vertices did not pick each other
// add a reverse edge
if (tid != neighbor_vertex_neighbor) {
reverse_needed = true;
if (symmetrize_output) {
vertex_t neighbor_vertex_neighbor = indices[neighbor_edge_idx];

// if vertices did not pick each other
// add a reverse edge
if (tid != neighbor_vertex_neighbor) {
reverse_needed = true;
}
}
}

Expand Down
35 changes: 20 additions & 15 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,20 @@ typedef std::chrono::high_resolution_clock Clock;

// curand generator uniform
// Single-precision overload of the uniform-generation helper.
// Forwards `generator`, `outputPtr` and `n` unchanged to
// curandGenerateUniform and returns the curandStatus_t it reports.
// NOTE(review): the parameter line below appears twice in this view; the two
// copies are textually identical, presumably the removed/added sides of a
// whitespace-only change in a diff rendering — confirm only one is present
// in the actual file.
inline curandStatus_t curand_generate_uniformX(curandGenerator_t generator,
float* outputPtr, size_t n) {
float* outputPtr, size_t n) {
return curandGenerateUniform(generator, outputPtr, n);
}
// Double-precision overload of the uniform-generation helper.
// Forwards `generator`, `outputPtr` and `n` unchanged to
// curandGenerateUniformDouble and returns the curandStatus_t it reports.
// NOTE(review): the parameter line below appears twice in this view; the two
// copies are textually identical, presumably the removed/added sides of a
// whitespace-only change in a diff rendering — confirm only one is present
// in the actual file.
inline curandStatus_t curand_generate_uniformX(curandGenerator_t generator,
double* outputPtr, size_t n) {
double* outputPtr, size_t n) {
return curandGenerateUniformDouble(generator, outputPtr, n);
}

template <typename vertex_t, typename edge_t, typename weight_t>
MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_, const vertex_t v_,
const edge_t e_, vertex_t* color_, cudaStream_t stream_)
const edge_t e_, vertex_t* color_, cudaStream_t stream_,
bool symmetrize_output_)
: handle(handle_),
offsets(offsets_),
indices(indices_),
Expand All @@ -70,7 +71,8 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
temp_weights(2 * v_),
mst_edge_count(1, 0),
prev_mst_edge_count(1, 0),
stream(stream_) {
stream(stream_),
symmetrize_output(symmetrize_output_) {
max_blocks = handle_.get_device_properties().maxGridSize[0];
max_threads = handle_.get_device_properties().maxThreadsPerBlock;
sm_count = handle_.get_device_properties().multiProcessorCount;
Expand Down Expand Up @@ -262,9 +264,9 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
// update the colors of both ends until there is no change in colors
thrust::host_vector<edge_t> curr_mst_edge_count = mst_edge_count;

auto min_pair_nthreads = std::min(v, max_threads);
auto min_pair_nblocks =
std::min((v + min_pair_nthreads - 1) / min_pair_nthreads, max_blocks);
auto min_pair_nthreads = std::min(v, (vertex_t)max_threads);
auto min_pair_nblocks = std::min(
(v + min_pair_nthreads - 1) / min_pair_nthreads, (vertex_t)max_blocks);

rmm::device_vector<bool> done(1, false);

Expand Down Expand Up @@ -316,8 +318,8 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_vertex() {
// Finds the minimum edge from each supervertex to the lowest color
template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
int nthreads = std::min(v, max_threads);
int nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);
auto nthreads = std::min(v, max_threads);
auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);

thrust::fill(temp_src.begin(), temp_src.end(),
std::numeric_limits<vertex_t>::max());
Expand All @@ -334,20 +336,23 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
detail::min_edge_per_supervertex<<<nblocks, nthreads, 0, stream>>>(
color, color_index_ptr, new_mst_edge_ptr, mst_edge_ptr, indices, weights,
altered_weights_ptr, temp_src_ptr, temp_dst_ptr, temp_weights_ptr,
min_edge_color_ptr, v);
min_edge_color_ptr, v, symmetrize_output);

// the kernel above only adds directed MST edges when a pair of vertices
// does not pick the same min edge between them, so, when symmetric output
// is requested, we add the reverse edges to make the result undirected
detail::add_reverse_edge<<<nblocks, nthreads, 0, stream>>>(
new_mst_edge_ptr, indices, weights, temp_src_ptr, temp_dst_ptr,
temp_weights_ptr, v);
if (symmetrize_output) {
detail::add_reverse_edge<<<nblocks, nthreads, 0, stream>>>(
new_mst_edge_ptr, indices, weights, temp_src_ptr, temp_dst_ptr,
temp_weights_ptr, v, symmetrize_output);
}
}

template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::check_termination() {
int nthreads = std::min(2 * v, max_threads);
int nblocks = std::min((2 * v + nthreads - 1) / nthreads, max_blocks);
vertex_t nthreads = std::min(2 * v, (vertex_t)max_threads);
vertex_t nblocks =
std::min((2 * v + nthreads - 1) / nthreads, (vertex_t)max_blocks);

// count number of new mst edges
edge_t* mst_edge_count_ptr = mst_edge_count.data().get();
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/sparse/mst/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ template <typename vertex_t, typename edge_t, typename weight_t>
// Convenience entry point for the MST computation: constructs an MST_solver
// from the arguments and returns the raft::Graph_COO produced by its solve().
//
// handle, offsets, indices, weights, v, e, color, stream are passed through
// to the MST_solver constructor unchanged.
// symmetrize_output defaults to true and is forwarded to the solver as its
// last constructor argument; presumably it selects whether each MST edge is
// emitted in both directions — confirm against MST_solver.
//
// NOTE(review): this view shows two variants of the signature tail and of the
// constructor argument list back to back (without and with
// symmetrize_output), presumably the removed/added sides of a diff rendering —
// confirm which lines are present in the actual file.
raft::Graph_COO<vertex_t, edge_t, weight_t> mst(
const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices,
weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color,
cudaStream_t stream) {
cudaStream_t stream, bool symmetrize_output = true) {
MST_solver<vertex_t, edge_t, weight_t> mst_solver(
handle, offsets, indices, weights, v, e, color, stream);
handle, offsets, indices, weights, v, e, color, stream, symmetrize_output);
return mst_solver.solve();
}

Expand Down
10 changes: 6 additions & 4 deletions cpp/include/raft/sparse/mst/mst_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MST_solver {
MST_solver(const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_,
const vertex_t v_, const edge_t e_, vertex_t* color_,
cudaStream_t stream_);
cudaStream_t stream_, bool symmetrize_output_);

raft::Graph_COO<vertex_t, edge_t, weight_t> solve();

Expand All @@ -52,16 +52,18 @@ class MST_solver {
const raft::handle_t& handle;
cudaStream_t stream;

bool symmetrize_output;

//CSR
const edge_t* offsets;
const vertex_t* indices;
const weight_t* weights;
const vertex_t v;
const edge_t e;

int max_blocks;
int max_threads;
int sm_count;
vertex_t max_blocks;
vertex_t max_threads;
vertex_t sm_count;

vertex_t* color; // represent each supervertex as a color
rmm::device_vector<weight_t>
Expand Down
52 changes: 31 additions & 21 deletions cpp/test/mst.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ template <typename vertex_t, typename edge_t, typename weight_t>
class MSTTest
: public ::testing::TestWithParam<CSRHost<vertex_t, edge_t, weight_t>> {
protected:
raft::Graph_COO<vertex_t, edge_t, weight_t> mst_sequential() {
std::pair<raft::Graph_COO<vertex_t, edge_t, weight_t>,
raft::Graph_COO<vertex_t, edge_t, weight_t>>
mst_gpu() {
edge_t *offsets = static_cast<edge_t *>(csr_d.offsets.data());
vertex_t *indices = static_cast<vertex_t *>(csr_d.indices.data());
weight_t *weights = static_cast<weight_t *>(csr_d.weights.data());
Expand All @@ -128,21 +130,22 @@ class MSTTest

vertex_t *color_ptr = thrust::raw_pointer_cast(color.data());

MST_solver<vertex_t, edge_t, weight_t> mst_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream());
auto result = mst_solver.solve();
raft::print_device_vector("Final MST Src: ", result.src.data(),
result.n_edges, std::cout);
raft::print_device_vector("Final MST Dst: ", result.dst.data(),
result.n_edges, std::cout);
raft::print_device_vector("Final MST Weights: ", result.weights.data(),
result.n_edges, std::cout);
raft::print_device_vector("Final MST Colors: ", color_ptr, v, std::cout);

std::cout << "number_of_MST_edges: " << result.n_edges << std::endl;
EXPECT_LE(result.n_edges, 2 * v - 2);

return result;
MST_solver<vertex_t, edge_t, weight_t> symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true);
auto symmetric_result = symmetric_solver.solve();

MST_solver<vertex_t, edge_t, weight_t> non_symmetric_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
false);
auto non_symmetric_result = non_symmetric_solver.solve();

std::cout << "number_of_MST_edges: " << symmetric_result.n_edges
<< std::endl;
EXPECT_LE(symmetric_result.n_edges, 2 * v - 2);

return std::make_pair(std::move(symmetric_result),
std::move(non_symmetric_result));
}

void SetUp() override {
Expand Down Expand Up @@ -215,17 +218,24 @@ const std::vector<CSRHost<int, int, float>> csr_in5_h = {

typedef MSTTest<int, int, float> MSTTestSequential;
// Compares the summed edge weights of the GPU results against the value
// returned by prims(csr_h), using raft::CompareApprox<float>(0.1):
//  - the symmetric result's weight sum is matched against 2 * prims_result
//    (presumably because each edge is stored in both directions — confirm);
//  - the non-symmetric result's weight sum is matched against prims_result.
// NOTE(review): lines referring to mst_sequential(), gpu_result and
// parallel_mst_result sit next to their mst_gpu()/symmetric_* counterparts;
// presumably they are the removed side of a diff rendering — confirm which
// lines are present in the actual file.
TEST_P(MSTTestSequential, Sequential) {
auto gpu_result = mst_sequential();
// mst_gpu() returns a pair: first = symmetric result, second = non-symmetric
auto results_pair = mst_gpu();
auto &symmetric_result = results_pair.first;
auto &non_symmetric_result = results_pair.second;

// reference value for the assertions below,
// computed by prims() on csr_h
auto prims_result = prims(csr_h);

// sum the weights of the first n_edges entries of each result
auto parallel_mst_result =
thrust::reduce(thrust::device, gpu_result.weights.data(),
gpu_result.weights.data() + gpu_result.n_edges);
auto symmetric_sum =
thrust::reduce(thrust::device, symmetric_result.weights.data(),
symmetric_result.weights.data() + symmetric_result.n_edges);
auto non_symmetric_sum = thrust::reduce(
thrust::device, non_symmetric_result.weights.data(),
non_symmetric_result.weights.data() + non_symmetric_result.n_edges);

ASSERT_TRUE(raft::match(2 * prims_result, parallel_mst_result,
ASSERT_TRUE(raft::match(2 * prims_result, symmetric_sum,
raft::CompareApprox<float>(0.1)));
ASSERT_TRUE(raft::match(prims_result, non_symmetric_sum,
raft::CompareApprox<float>(0.1)));
}

Expand Down