Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Adding simple support for n_clusters=1 in agglomerative clustering and ignoring self-loops #217

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 65 additions & 59 deletions cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -280,81 +280,87 @@ void extract_flattened_clusters(const raft::handle_t &handle, value_idx *labels,
auto stream = handle.get_stream();
auto thrust_policy = rmm::exec_policy(stream);

/**
* Compute levels for each node
*
* 1. Initialize "levels" array of size n_leaves * 2
*
* 2. For each entry in children, write parent
* out for each of the children
*/
// Handle special case where n_clusters == 1
if (n_clusters == 1) {
thrust::fill(thrust_policy, labels, labels + n_leaves, 0);
} else {
/**
* Compute levels for each node
*
* 1. Initialize "levels" array of size n_leaves * 2
*
* 2. For each entry in children, write parent
* out for each of the children
*/

size_t n_edges = (n_leaves - 1) * 2;
size_t n_edges = (n_leaves - 1) * 2;

thrust::device_ptr<const value_idx> d_ptr =
thrust::device_pointer_cast(children);
value_idx n_vertices =
*(thrust::max_element(thrust_policy, d_ptr, d_ptr + n_edges)) + 1;
thrust::device_ptr<const value_idx> d_ptr =
thrust::device_pointer_cast(children);
value_idx n_vertices =
*(thrust::max_element(thrust_policy, d_ptr, d_ptr + n_edges)) + 1;

// Prevent potential infinite loop from labeling disconnected
// connectivities graph.
RAFT_EXPECTS(n_vertices == (n_leaves - 1) * 2,
"Multiple components found in MST or MST is invalid. "
"Cannot find single-linkage solution.");
// Prevent potential infinite loop from labeling disconnected
// connectivities graph.
RAFT_EXPECTS(n_vertices == (n_leaves - 1) * 2,
"Multiple components found in MST or MST is invalid. "
"Cannot find single-linkage solution.");

rmm::device_uvector<value_idx> levels(n_vertices, stream);
rmm::device_uvector<value_idx> levels(n_vertices, stream);

value_idx n_blocks = ceildiv(n_vertices, (value_idx)tpb);
write_levels_kernel<<<n_blocks, tpb, 0, stream>>>(children, levels.data(),
n_vertices);
/**
* Step 1: Find label roots:
*
* 1. Copying children[children.size()-(n_clusters-1):] entries to
* separate array
* 2. sort array
* 3. take first n_clusters entries
*/
value_idx n_blocks = ceildiv(n_vertices, (value_idx)tpb);
write_levels_kernel<<<n_blocks, tpb, 0, stream>>>(children, levels.data(),
n_vertices);
/**
* Step 1: Find label roots:
*
* 1. Copying children[children.size()-(n_clusters-1):] entries to
* separate array
* 2. sort array
* 3. take first n_clusters entries
*/

value_idx child_size = (n_clusters - 1) * 2;
rmm::device_uvector<value_idx> label_roots(child_size, stream);
value_idx child_size = (n_clusters - 1) * 2;
rmm::device_uvector<value_idx> label_roots(child_size, stream);

value_idx children_cpy_start = n_edges - child_size;
raft::copy_async(label_roots.data(), children + children_cpy_start,
child_size, stream);
value_idx children_cpy_start = n_edges - child_size;
raft::copy_async(label_roots.data(), children + children_cpy_start,
child_size, stream);

thrust::sort(thrust_policy, label_roots.data(),
label_roots.data() + (child_size), thrust::greater<value_idx>());
thrust::sort(thrust_policy, label_roots.data(),
label_roots.data() + (child_size),
thrust::greater<value_idx>());

rmm::device_uvector<value_idx> tmp_labels(n_vertices, stream);
rmm::device_uvector<value_idx> tmp_labels(n_vertices, stream);

// Init labels to -1
thrust::fill(thrust_policy, tmp_labels.data(), tmp_labels.data() + n_vertices,
-1);
// Init labels to -1
thrust::fill(thrust_policy, tmp_labels.data(),
tmp_labels.data() + n_vertices, -1);

// Write labels for cluster roots to "tmp_labels"
thrust::counting_iterator<uint> first(0);
// Write labels for cluster roots to "tmp_labels"
thrust::counting_iterator<uint> first(0);

auto z_iter = thrust::make_zip_iterator(thrust::make_tuple(
first, label_roots.data() + (label_roots.size() - n_clusters)));
auto z_iter = thrust::make_zip_iterator(thrust::make_tuple(
first, label_roots.data() + (label_roots.size() - n_clusters)));

thrust::for_each(thrust_policy, z_iter, z_iter + n_clusters,
init_label_roots<value_idx>(tmp_labels.data()));
thrust::for_each(thrust_policy, z_iter, z_iter + n_clusters,
init_label_roots<value_idx>(tmp_labels.data()));

/**
* Step 2: Propagate labels by having children iterate through their parents
* 1. Initialize labels to -1
* 2. For each element in levels array, propagate until parent's
* label is !=-1
*/
value_idx cut_level = (n_edges / 2) - (n_clusters - 1);
/**
* Step 2: Propagate labels by having children iterate through their parents
* 1. Initialize labels to -1
* 2. For each element in levels array, propagate until parent's
* label is !=-1
*/
value_idx cut_level = (n_edges / 2) - (n_clusters - 1);

inherit_labels<<<n_blocks, tpb, 0, stream>>>(children, levels.data(),
n_leaves, tmp_labels.data(),
cut_level, n_vertices);
inherit_labels<<<n_blocks, tpb, 0, stream>>>(children, levels.data(),
n_leaves, tmp_labels.data(),
cut_level, n_vertices);

// copy tmp labels to actual labels
raft::copy_async(labels, tmp_labels.data(), n_leaves, stream);
// copy tmp labels to actual labels
raft::copy_async(labels, tmp_labels.data(), n_leaves, stream);
}
}

}; // namespace detail
Expand Down
29 changes: 17 additions & 12 deletions cpp/include/raft/sparse/hierarchy/detail/connectivities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <raft/linalg/unary_op.cuh>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <raft/linalg/distance_type.h>
#include <raft/sparse/hierarchy/common.h>
Expand Down Expand Up @@ -61,6 +62,7 @@ struct distance_graph_impl<raft::hierarchy::LinkageDistance::KNN_GRAPH,
rmm::device_uvector<value_t> &data, int c) {
auto d_alloc = handle.get_device_allocator();
auto stream = handle.get_stream();
auto exec_policy = rmm::exec_policy(stream);

// Need to symmetrize knn into undirected graph
raft::sparse::COO<value_t, value_idx> knn_graph_coo(d_alloc, stream);
Expand All @@ -71,10 +73,25 @@ struct distance_graph_impl<raft::hierarchy::LinkageDistance::KNN_GRAPH,
indices.resize(knn_graph_coo.nnz, stream);
data.resize(knn_graph_coo.nnz, stream);

// self-loops get max distance
auto transform_in = thrust::make_zip_iterator(thrust::make_tuple(
knn_graph_coo.rows(), knn_graph_coo.cols(), knn_graph_coo.vals()));

thrust::transform(
exec_policy, transform_in, transform_in + knn_graph_coo.nnz,
knn_graph_coo.vals(),
[=] __device__(const thrust::tuple<value_idx, value_idx, value_t> &tup) {
bool self_loop = thrust::get<0>(tup) == thrust::get<1>(tup);
return (self_loop * std::numeric_limits<value_t>::max()) +
(!self_loop * thrust::get<2>(tup));
});

raft::sparse::convert::sorted_coo_to_csr(knn_graph_coo.rows(),
knn_graph_coo.nnz, indptr.data(),
m + 1, d_alloc, stream);

// TODO: Wouldn't need to copy here if we could compute knn
// graph directly on the device uvectors
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
raft::copy_async(indices.data(), knn_graph_coo.cols(), knn_graph_coo.nnz,
stream);
raft::copy_async(data.data(), knn_graph_coo.vals(), knn_graph_coo.nnz,
Expand Down Expand Up @@ -111,18 +128,6 @@ void get_distance_graph(const raft::handle_t &handle, const value_t *X,

distance_graph_impl<dist_type, value_idx, value_t> dist_graph;
dist_graph.run(handle, X, m, n, metric, indptr, indices, data, c);

// a little adjustment for distances of 0.
// TODO: This will only need to be done when src_v==dst_v
raft::linalg::unaryOp<value_t>(
data.data(), data.data(), data.size(),
[] __device__(value_t input) {
if (input == 0)
return std::numeric_limits<value_t>::max();
else
return input;
},
stream);
}

}; // namespace detail
Expand Down
29 changes: 14 additions & 15 deletions cpp/include/raft/sparse/mst/detail/mst_solver_inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ MST_solver<vertex_t, edge_t, weight_t>::MST_solver(
//Initially, color holds the vertex id as color
auto policy = rmm::exec_policy(stream);
if (initialize_colors_) {
thrust::sequence(policy->on(stream), color.begin(), color.end(), 0);
thrust::sequence(policy->on(stream), color_index, color_index + v, 0);
thrust::sequence(policy, color.begin(), color.end(), 0);
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
thrust::sequence(policy, color_index, color_index + v, 0);
} else {
raft::copy(color.data().get(), color_index, v, stream);
}
thrust::sequence(policy->on(stream), next_color.begin(), next_color.end(), 0);
thrust::sequence(policy, next_color.begin(), next_color.end(), 0);
}

template <typename vertex_t, typename edge_t, typename weight_t>
Expand Down Expand Up @@ -213,22 +213,22 @@ weight_t MST_solver<vertex_t, edge_t, weight_t>::alteration_max() {
auto policy = rmm::exec_policy(stream);
rmm::device_vector<weight_t> tmp(e);
thrust::device_ptr<const weight_t> weights_ptr(weights);
thrust::copy(policy->on(stream), weights_ptr, weights_ptr + e, tmp.begin());
thrust::copy(policy, weights_ptr, weights_ptr + e, tmp.begin());
//sort tmp weights
thrust::sort(policy->on(stream), tmp.begin(), tmp.end());
thrust::sort(policy, tmp.begin(), tmp.end());

//remove duplicates
auto new_end = thrust::unique(policy->on(stream), tmp.begin(), tmp.end());
auto new_end = thrust::unique(policy, tmp.begin(), tmp.end());

//min(a[i+1]-a[i])/2
auto begin =
thrust::make_zip_iterator(thrust::make_tuple(tmp.begin(), tmp.begin() + 1));
auto end =
thrust::make_zip_iterator(thrust::make_tuple(new_end - 1, new_end));
auto init = tmp[1] - tmp[0];
auto max = thrust::transform_reduce(policy->on(stream), begin, end,
alteration_functor<weight_t>(), init,
thrust::minimum<weight_t>());
auto max =
thrust::transform_reduce(policy, begin, end, alteration_functor<weight_t>(),
init, thrust::minimum<weight_t>());
return max / static_cast<weight_t>(2);
}

Expand Down Expand Up @@ -307,9 +307,9 @@ void MST_solver<vertex_t, edge_t, weight_t>::label_prop(vertex_t* mst_src,
template <typename vertex_t, typename edge_t, typename weight_t>
void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_vertex() {
auto policy = rmm::exec_policy(stream);
thrust::fill(policy->on(stream), min_edge_color.begin(), min_edge_color.end(),
thrust::fill(policy, min_edge_color.begin(), min_edge_color.end(),
std::numeric_limits<weight_t>::max());
thrust::fill(policy->on(stream), new_mst_edge.begin(), new_mst_edge.end(),
thrust::fill(policy, new_mst_edge.begin(), new_mst_edge.end(),
std::numeric_limits<weight_t>::max());

int n_threads = 32;
Expand All @@ -332,7 +332,7 @@ void MST_solver<vertex_t, edge_t, weight_t>::min_edge_per_supervertex() {
auto nblocks = std::min((v + nthreads - 1) / nthreads, max_blocks);

auto policy = rmm::exec_policy(stream);
thrust::fill(policy->on(stream), temp_src.begin(), temp_src.end(),
thrust::fill(policy, temp_src.begin(), temp_src.end(),
std::numeric_limits<vertex_t>::max());

vertex_t* color_ptr = color.data().get();
Expand Down Expand Up @@ -402,9 +402,8 @@ void MST_solver<vertex_t, edge_t, weight_t>::append_src_dst_pair(
thrust::make_tuple(temp_src.end(), temp_dst.end(), temp_weights.end()));

// copy new mst edges to final output
thrust::copy_if(policy->on(stream), temp_src_dst_zip_begin,
temp_src_dst_zip_end, src_dst_zip_end,
new_edges_functor<vertex_t, weight_t>());
thrust::copy_if(policy, temp_src_dst_zip_begin, temp_src_dst_zip_end,
src_dst_zip_end, new_edges_functor<vertex_t, weight_t>());
}

} // namespace mst
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/mst/detail/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

#pragma once

#include <rmm/thrust_rmm_allocator.h>
#include <iostream>
#include <rmm/device_vector.hpp>
#define MST_TIME

namespace raft {
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/sparse/mst/mst_solver.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

#pragma once

#include <rmm/thrust_rmm_allocator.h>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/device_vector.hpp>

namespace raft {

Expand Down
1 change: 0 additions & 1 deletion cpp/test/mst.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <bits/stdc++.h>

#include <gtest/gtest.h>
#include <rmm/thrust_rmm_allocator.h>
#include <iostream>
#include <rmm/device_buffer.hpp>
#include <vector>
Expand Down