Skip to content

Commit

Permalink
Adding simple support for n_clusters=1 in agglomerative clustering an…
Browse files Browse the repository at this point in the history
…d ignoring self-loops (#217)

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #217
  • Loading branch information
cjnolet authored May 12, 2021
1 parent d1b02ca commit 7d83f0d
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 89 deletions.
124 changes: 65 additions & 59 deletions cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -280,81 +280,87 @@ void extract_flattened_clusters(const raft::handle_t &handle, value_idx *labels,
auto stream = handle.get_stream();
auto thrust_policy = rmm::exec_policy(stream);

/**
* Compute levels for each node
*
* 1. Initialize "levels" array of size n_leaves * 2
*
* 2. For each entry in children, write parent
* out for each of the children
*/
// Handle special case where n_clusters == 1
if (n_clusters == 1) {
thrust::fill(thrust_policy, labels, labels + n_leaves, 0);
} else {
/**
* Compute levels for each node
*
* 1. Initialize "levels" array of size n_leaves * 2
*
* 2. For each entry in children, write parent
* out for each of the children
*/

size_t n_edges = (n_leaves - 1) * 2;
size_t n_edges = (n_leaves - 1) * 2;

thrust::device_ptr<const value_idx> d_ptr =
thrust::device_pointer_cast(children);
value_idx n_vertices =
*(thrust::max_element(thrust_policy, d_ptr, d_ptr + n_edges)) + 1;
thrust::device_ptr<const value_idx> d_ptr =
thrust::device_pointer_cast(children);
value_idx n_vertices =
*(thrust::max_element(thrust_policy, d_ptr, d_ptr + n_edges)) + 1;

// Prevent potential infinite loop from labeling disconnected
// connectivities graph.
RAFT_EXPECTS(n_vertices == (n_leaves - 1) * 2,
"Multiple components found in MST or MST is invalid. "
"Cannot find single-linkage solution.");
// Prevent potential infinite loop from labeling disconnected
// connectivities graph.
RAFT_EXPECTS(n_vertices == (n_leaves - 1) * 2,
"Multiple components found in MST or MST is invalid. "
"Cannot find single-linkage solution.");

rmm::device_uvector<value_idx> levels(n_vertices, stream);
rmm::device_uvector<value_idx> levels(n_vertices, stream);

value_idx n_blocks = ceildiv(n_vertices, (value_idx)tpb);
write_levels_kernel<<<n_blocks, tpb, 0, stream>>>(children, levels.data(),
n_vertices);
/**
* Step 1: Find label roots:
*
* 1. Copying children[children.size()-(n_clusters-1):] entries to
* separate array
* 2. sort array
* 3. take first n_clusters entries
*/
value_idx n_blocks = ceildiv(n_vertices, (value_idx)tpb);
write_levels_kernel<<<n_blocks, tpb, 0, stream>>>(children, levels.data(),
n_vertices);
/**
* Step 1: Find label roots:
*
* 1. Copying children[children.size()-(n_clusters-1):] entries to
* separate array
* 2. sort array
* 3. take first n_clusters entries
*/

value_idx child_size = (n_clusters - 1) * 2;
rmm::device_uvector<value_idx> label_roots(child_size, stream);
value_idx child_size = (n_clusters - 1) * 2;
rmm::device_uvector<value_idx> label_roots(child_size, stream);

value_idx children_cpy_start = n_edges - child_size;
raft::copy_async(label_roots.data(), children + children_cpy_start,
child_size, stream);
value_idx children_cpy_start = n_edges - child_size;
raft::copy_async(label_roots.data(), children + children_cpy_start,
child_size, stream);

thrust::sort(thrust_policy, label_roots.data(),
label_roots.data() + (child_size), thrust::greater<value_idx>());
thrust::sort(thrust_policy, label_roots.data(),
label_roots.data() + (child_size),
thrust::greater<value_idx>());

rmm::device_uvector<value_idx> tmp_labels(n_vertices, stream);
rmm::device_uvector<value_idx> tmp_labels(n_vertices, stream);

// Init labels to -1
thrust::fill(thrust_policy, tmp_labels.data(), tmp_labels.data() + n_vertices,
-1);
// Init labels to -1
thrust::fill(thrust_policy, tmp_labels.data(),
tmp_labels.data() + n_vertices, -1);

// Write labels for cluster roots to "labels"
thrust::counting_iterator<uint> first(0);
// Write labels for cluster roots to "labels"
thrust::counting_iterator<uint> first(0);

auto z_iter = thrust::make_zip_iterator(thrust::make_tuple(
first, label_roots.data() + (label_roots.size() - n_clusters)));
auto z_iter = thrust::make_zip_iterator(thrust::make_tuple(
first, label_roots.data() + (label_roots.size() - n_clusters)));

thrust::for_each(thrust_policy, z_iter, z_iter + n_clusters,
init_label_roots<value_idx>(tmp_labels.data()));
thrust::for_each(thrust_policy, z_iter, z_iter + n_clusters,
init_label_roots<value_idx>(tmp_labels.data()));

/**
* Step 2: Propagate labels by having children iterate through their parents
* 1. Initialize labels to -1
* 2. For each element in levels array, propagate until parent's
* label is !=-1
*/
value_idx cut_level = (n_edges / 2) - (n_clusters - 1);
/**
* Step 2: Propagate labels by having children iterate through their parents
* 1. Initialize labels to -1
* 2. For each element in levels array, propagate until parent's
* label is !=-1
*/
value_idx cut_level = (n_edges / 2) - (n_clusters - 1);

inherit_labels<<<n_blocks, tpb, 0, stream>>>(children, levels.data(),
n_leaves, tmp_labels.data(),
cut_level, n_vertices);
inherit_labels<<<n_blocks, tpb, 0, stream>>>(children, levels.data(),
n_leaves, tmp_labels.data(),
cut_level, n_vertices);

// copy tmp labels to actual labels
raft::copy_async(labels, tmp_labels.data(), n_leaves, stream);
// copy tmp labels to actual labels
raft::copy_async(labels, tmp_labels.data(), n_leaves, stream);
}
}

}; // namespace detail
Expand Down
30 changes: 18 additions & 12 deletions cpp/include/raft/sparse/hierarchy/detail/connectivities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <raft/linalg/unary_op.cuh>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <raft/linalg/distance_type.h>
#include <raft/sparse/hierarchy/common.h>
Expand Down Expand Up @@ -61,6 +62,7 @@ struct distance_graph_impl<raft::hierarchy::LinkageDistance::KNN_GRAPH,
rmm::device_uvector<value_t> &data, int c) {
auto d_alloc = handle.get_device_allocator();
auto stream = handle.get_stream();
auto exec_policy = rmm::exec_policy(stream);

// Need to symmetrize knn into undirected graph
raft::sparse::COO<value_t, value_idx> knn_graph_coo(d_alloc, stream);
Expand All @@ -71,10 +73,26 @@ struct distance_graph_impl<raft::hierarchy::LinkageDistance::KNN_GRAPH,
indices.resize(knn_graph_coo.nnz, stream);
data.resize(knn_graph_coo.nnz, stream);

// self-loops get max distance
auto transform_in = thrust::make_zip_iterator(thrust::make_tuple(
knn_graph_coo.rows(), knn_graph_coo.cols(), knn_graph_coo.vals()));

thrust::transform(
exec_policy, transform_in, transform_in + knn_graph_coo.nnz,
knn_graph_coo.vals(),
[=] __device__(const thrust::tuple<value_idx, value_idx, value_t> &tup) {
bool self_loop = thrust::get<0>(tup) == thrust::get<1>(tup);
return (self_loop * std::numeric_limits<value_t>::max()) +
(!self_loop * thrust::get<2>(tup));
});

raft::sparse::convert::sorted_coo_to_csr(knn_graph_coo.rows(),
knn_graph_coo.nnz, indptr.data(),
m + 1, d_alloc, stream);

// TODO: Wouldn't need to copy here if we could compute knn
// graph directly on the device uvectors
// ref: https://github.com/rapidsai/raft/issues/227
raft::copy_async(indices.data(), knn_graph_coo.cols(), knn_graph_coo.nnz,
stream);
raft::copy_async(data.data(), knn_graph_coo.vals(), knn_graph_coo.nnz,
Expand Down Expand Up @@ -111,18 +129,6 @@ void get_distance_graph(const raft::handle_t &handle, const value_t *X,

distance_graph_impl<dist_type, value_idx, value_t> dist_graph;
dist_graph.run(handle, X, m, n, metric, indptr, indices, data, c);

// a little adjustment for distances of 0.
// TODO: This will only need to be done when src_v==dst_v
raft::linalg::unaryOp<value_t>(
data.data(), data.data(), data.size(),
[] __device__(value_t input) {
if (input == 0)
return std::numeric_limits<value_t>::max();
else
return input;
},
stream);
}

}; // namespace detail
Expand Down
31 changes: 16 additions & 15 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
#include <iostream>

#include <raft/cudart_utils.h>

#include <rmm/device_buffer.hpp>
#include <rmm/exec_policy.hpp>

namespace raft {
namespace mst {
Expand Down Expand Up @@ -83,12 +85,12 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
//Initially, color holds the vertex id as color
auto policy = rmm::exec_policy(stream);
if (initialize_colors_) {
thrust::sequence(policy->on(stream), color.begin(), color.end(), 0);
thrust::sequence(policy->on(stream), color_index, color_index + v, 0);
thrust::sequence(policy, color.begin(), color.end(), 0);
thrust::sequence(policy, color_index, color_index + v, 0);
} else {
raft::copy(color.data().get(), color_index, v, stream);
}
thrust::sequence(policy->on(stream), next_color.begin(), next_color.end(), 0);
thrust::sequence(policy, next_color.begin(), next_color.end(), 0);
}

template <typename vertex_t, typename edge_t, typename weight_t>
Expand Down Expand Up @@ -213,22 +215,22 @@ weight_t MST_solver<vertex_t, edge_t, weight_t>::alteration_max() {
auto policy = rmm::exec_policy(stream);
rmm::device_vector<weight_t> tmp(e);
thrust::device_ptr<const weight_t> weights_ptr(weights);
thrust::copy(policy->on(stream), weights_ptr, weights_ptr + e, tmp.begin());
thrust::copy(policy, weights_ptr, weights_ptr + e, tmp.begin());
//sort tmp weights
thrust::sort(policy->on(stream), tmp.begin(), tmp.end());
thrust::sort(policy, tmp.begin(), tmp.end());

//remove duplicates
auto new_end = thrust::unique(policy->on(stream), tmp.begin(), tmp.end());
auto new_end = thrust::unique(policy, tmp.begin(), tmp.end());

//min(a[i+1]-a[i])/2
auto begin =
thrust::make_zip_iterator(thrust::make_tuple(tmp.begin(), tmp.begin() + 1));
auto end =
thrust::make_zip_iterator(thrust::make_tuple(new_end - 1, new_end));
auto init = tmp[1] - tmp[0];
auto max = thrust::transform_reduce(policy->on(stream), begin, end,
alteration_functor<weight_t>(), init,
thrust::minimum<weight_t>());
auto max =
thrust::transform_reduce(policy, begin, end, alteration_functor<weight_t>(),
init, thrust::minimum<weight_t>());
return max / static_cast<weight_t>(2);
}

Expand Down Expand Up @@ -307,9 +309,9 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_vertex() {
auto policy = rmm::exec_policy(stream);
thrust::fill(policy->on(stream), min_edge_color.begin(), min_edge_color.end(),
thrust::fill(policy, min_edge_color.begin(), min_edge_color.end(),
std::numeric_limits<weight_t>::max());
thrust::fill(policy->on(stream), new_mst_edge.begin(), new_mst_edge.end(),
thrust::fill(policy, new_mst_edge.begin(), new_mst_edge.end(),
std::numeric_limits<weight_t>::max());

int n_threads = 32;
Expand All @@ -332,7 +334,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);

auto policy = rmm::exec_policy(stream);
thrust::fill(policy->on(stream), temp_src.begin(), temp_src.end(),
thrust::fill(policy, temp_src.begin(), temp_src.end(),
std::numeric_limits<vertex_t>::max());

vertex_t* color_ptr = color.data().get();
Expand Down Expand Up @@ -402,9 +404,8 @@ void MST_solver<vertex_t, edge_t, weight_t>::append_src_dst_pair(
thrust::make_tuple(temp_src.end(), temp_dst.end(), temp_weights.end()));

// copy new mst edges to final output
thrust::copy_if(policy->on(stream), temp_src_dst_zip_begin,
temp_src_dst_zip_end, src_dst_zip_end,
new_edges_functor<vertex_t, weight_t>());
thrust::copy_if(policy, temp_src_dst_zip_begin, temp_src_dst_zip_end,
src_dst_zip_end, new_edges_functor<vertex_t, weight_t>());
}

} // namespace mst
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/mst/detail/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

#pragma once

#include <rmm/thrust_rmm_allocator.h>
#include <iostream>
#include <rmm/device_vector.hpp>
#define MST_TIME

namespace raft {
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/mst/mst_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

#pragma once

#include <rmm/thrust_rmm_allocator.h>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/device_vector.hpp>

namespace raft {

Expand Down
1 change: 0 additions & 1 deletion cpp/test/mst.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <bits/stdc++.h>

#include <gtest/gtest.h>
#include <rmm/thrust_rmm_allocator.h>
#include <iostream>
#include <rmm/device_buffer.hpp>
#include <vector>
Expand Down

0 comments on commit 7d83f0d

Please sign in to comment.