Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some MST updates for single-linkage / HDBSCAN clustering #119

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions cpp/include/raft/sparse/mst/detail/mst_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,32 +107,48 @@ __global__ void min_edge_per_supervertex(
const vertex_t* color, const vertex_t* color_index, edge_t* new_mst_edge,
bool* mst_edge, const vertex_t* indices, const weight_t* weights,
const weight_t* altered_weights, vertex_t* temp_src, vertex_t* temp_dst,
weight_t* temp_weights, const weight_t* min_edge_color, const vertex_t v) {
weight_t* temp_weights, const weight_t* min_edge_color, const vertex_t v,
bool symmetrize_output) {
auto tid = get_1D_idx<vertex_t>();

if (tid < v) {
vertex_t vertex_color_idx = color_index[tid];
vertex_t vertex_color = color[vertex_color_idx];
edge_t edge_idx = new_mst_edge[tid];

// check whether a valid outgoing edge was found for this vertex and
// whether that edge is also the minimum edge of its whole supervertex;
// if so, the edge is a candidate for the mst
if (edge_idx != std::numeric_limits<edge_t>::max()) {
weight_t vertex_weight = altered_weights[edge_idx];
if (min_edge_color[vertex_color] == vertex_weight) {
temp_src[tid] = tid;
temp_dst[tid] = indices[edge_idx];
temp_weights[tid] = weights[edge_idx];

mst_edge[edge_idx] = true;
} else {
new_mst_edge[tid] = std::numeric_limits<edge_t>::max();
auto dst = indices[edge_idx];
auto dst_edge_idx = new_mst_edge[dst];
auto dst_color = color[color_index[dst]];
// the two vertices picked each other only if all of the following hold:
// - the destination has found an edge
// - that edge points back to the source
// - that edge is the minimum edge found for the dst color
bool add_edge = false;
if (dst_edge_idx != std::numeric_limits<edge_t>::max() &&
indices[dst_edge_idx] == tid &&
min_edge_color[dst_color] == altered_weights[dst_edge_idx]) {
if (symmetrize_output || vertex_color < dst_color) {
add_edge = true;
}
} else {
add_edge = true;
}
if (add_edge) {
temp_src[tid] = tid;
temp_dst[tid] = dst;
temp_weights[tid] = weights[edge_idx];
mst_edge[edge_idx] = true;
} else {
new_mst_edge[tid] = std::numeric_limits<edge_t>::max();
}
}
}
}
}

template <typename vertex_t, typename edge_t, typename weight_t>
__global__ void add_reverse_edge(const edge_t* new_mst_edge,
const vertex_t* indices,
Expand Down
35 changes: 20 additions & 15 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,20 @@ typedef std::chrono::high_resolution_clock Clock;

// curand generator uniform
inline curandStatus_t curand_generate_uniformX(curandGenerator_t generator,
float* outputPtr, size_t n) {
float* outputPtr, size_t n) {
return curandGenerateUniform(generator, outputPtr, n);
}
inline curandStatus_t curand_generate_uniformX(curandGenerator_t generator,
double* outputPtr, size_t n) {
double* outputPtr, size_t n) {
return curandGenerateUniformDouble(generator, outputPtr, n);
}

template <typename vertex_t, typename edge_t, typename weight_t>
MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_, const vertex_t v_,
const edge_t e_, vertex_t* color_, cudaStream_t stream_)
const edge_t e_, vertex_t* color_, cudaStream_t stream_,
bool symmetrize_output_)
: handle(handle_),
offsets(offsets_),
indices(indices_),
Expand All @@ -70,7 +71,8 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
temp_weights(2 * v_),
mst_edge_count(1, 0),
prev_mst_edge_count(1, 0),
stream(stream_) {
stream(stream_),
symmetrize_output(symmetrize_output_) {
max_blocks = handle_.get_device_properties().maxGridSize[0];
max_threads = handle_.get_device_properties().maxThreadsPerBlock;
sm_count = handle_.get_device_properties().multiProcessorCount;
Expand Down Expand Up @@ -262,9 +264,9 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
// repeatedly update the colors of both ends until there is no change in colors
thrust::host_vector<edge_t> curr_mst_edge_count = mst_edge_count;

auto min_pair_nthreads = std::min(v, max_threads);
auto min_pair_nblocks =
std::min((v + min_pair_nthreads - 1) / min_pair_nthreads, max_blocks);
auto min_pair_nthreads = std::min(v, (vertex_t)max_threads);
auto min_pair_nblocks = std::min(
(v + min_pair_nthreads - 1) / min_pair_nthreads, (vertex_t)max_blocks);

rmm::device_vector<bool> done(1, false);

Expand Down Expand Up @@ -316,8 +318,8 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_vertex() {
// Finds the minimum edge from each supervertex to the lowest color
template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
int nthreads = std::min(v, max_threads);
int nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);
auto nthreads = std::min(v, max_threads);
auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);

thrust::fill(temp_src.begin(), temp_src.end(),
std::numeric_limits<vertex_t>::max());
Expand All @@ -334,20 +336,23 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
detail::min_edge_per_supervertex<<<nblocks, nthreads, 0, stream>>>(
color, color_index_ptr, new_mst_edge_ptr, mst_edge_ptr, indices, weights,
altered_weights_ptr, temp_src_ptr, temp_dst_ptr, temp_weights_ptr,
min_edge_color_ptr, v);
min_edge_color_ptr, v, symmetrize_output);

// the above kernel only adds directed mst edges in the case where
// a pair of vertices doesn't pick the same min edge between them,
// so now we add the reverse edge to make it undirected
detail::add_reverse_edge<<<nblocks, nthreads, 0, stream>>>(
new_mst_edge_ptr, indices, weights, temp_src_ptr, temp_dst_ptr,
temp_weights_ptr, v);
if (symmetrize_output) {
detail::add_reverse_edge<<<nblocks, nthreads, 0, stream>>>(
new_mst_edge_ptr, indices, weights, temp_src_ptr, temp_dst_ptr,
temp_weights_ptr, v);
}
}

template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::check_termination() {
int nthreads = std::min(2 * v, max_threads);
int nblocks = std::min((2 * v + nthreads - 1) / nthreads, max_blocks);
vertex_t nthreads = std::min(2 * v, (vertex_t)max_threads);
vertex_t nblocks =
std::min((2 * v + nthreads - 1) / nthreads, (vertex_t)max_blocks);

// count number of new mst edges
edge_t* mst_edge_count_ptr = mst_edge_count.data().get();
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/sparse/mst/mst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ template <typename vertex_t, typename edge_t, typename weight_t>
raft::Graph_COO<vertex_t, edge_t, weight_t> mst(
const raft::handle_t& handle, edge_t const* offsets, vertex_t const* indices,
weight_t const* weights, vertex_t const v, edge_t const e, vertex_t* color,
cudaStream_t stream) {
cudaStream_t stream, bool symmetrize_output = true) {
MST_solver<vertex_t, edge_t, weight_t> mst_solver(
handle, offsets, indices, weights, v, e, color, stream);
handle, offsets, indices, weights, v, e, color, stream, symmetrize_output);
return mst_solver.solve();
}

Expand Down
10 changes: 6 additions & 4 deletions cpp/include/raft/sparse/mst/mst_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MST_solver {
MST_solver(const raft::handle_t& handle_, const edge_t* offsets_,
const vertex_t* indices_, const weight_t* weights_,
const vertex_t v_, const edge_t e_, vertex_t* color_,
cudaStream_t stream_);
cudaStream_t stream_, bool symmetrize_output_);

raft::Graph_COO<vertex_t, edge_t, weight_t> solve();

Expand All @@ -52,16 +52,18 @@ class MST_solver {
const raft::handle_t& handle;
cudaStream_t stream;

bool symmetrize_output;

//CSR
const edge_t* offsets;
const vertex_t* indices;
const weight_t* weights;
const vertex_t v;
const edge_t e;

int max_blocks;
int max_threads;
int sm_count;
vertex_t max_blocks;
vertex_t max_threads;
vertex_t sm_count;

vertex_t* color; // represent each supervertex as a color
rmm::device_vector<weight_t>
Expand Down
3 changes: 2 additions & 1 deletion cpp/test/mst.cu
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ class MSTTest
vertex_t *color_ptr = thrust::raw_pointer_cast(color.data());

MST_solver<vertex_t, edge_t, weight_t> mst_solver(
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream());
handle, offsets, indices, weights, v, e, color_ptr, handle.get_stream(),
true);
auto result = mst_solver.solve();
raft::print_device_vector("Final MST Src: ", result.src.data(),
result.n_edges, std::cout);
Expand Down