diff --git a/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh b/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh index 69fd8d4dc9..1ac075489a 100644 --- a/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh +++ b/cpp/include/raft/sparse/hierarchy/detail/agglomerative.cuh @@ -73,9 +73,15 @@ class UnionFind { }; /** - * Standard single-threaded agglomerative labeling on host. This should work - * well for smaller sizes of m. This is a C++ port of the original reference - * implementation of HDBSCAN. + * Agglomerative labeling on host. This has not been found to be a bottleneck + * in the algorithm. A parallel version of this can be done using a parallel + * variant of Kruskal's MST algorithm + * (ref http://cucis.ece.northwestern.edu/publications/pdf/HenPat12.pdf), + * which breaks apart the sorted MST results into overlapping subsets and + * independently runs Kruskal's algorithm on each subset, merging them back + * together into a single hierarchy when complete. Unfortunately, + * this is nontrivial and the speedup wouldn't be useful until this + * becomes a bottleneck. * * @tparam value_idx * @tparam value_t @@ -91,9 +97,8 @@ class UnionFind { template void build_dendrogram_host(const handle_t &handle, const value_idx *rows, const value_idx *cols, const value_t *data, - size_t nnz, value_idx *children, - rmm::device_uvector &out_delta, - rmm::device_uvector &out_size) { + size_t nnz, value_idx *children, value_t *out_delta, + value_idx *out_size) { auto d_alloc = handle.get_device_allocator(); auto stream = handle.get_stream(); @@ -103,8 +108,6 @@ void build_dendrogram_host(const handle_t &handle, const value_idx *rows, std::vector mst_dst_h(n_edges); std::vector mst_weights_h(n_edges); - printf("n_edges: %d\n", n_edges); - update_host(mst_src_h.data(), rows, n_edges, stream); update_host(mst_dst_h.data(), cols, n_edges, stream); update_host(mst_weights_h.data(), data, n_edges, stream); @@ -113,12 +116,14 @@ void build_dendrogram_host(const handle_t &handle, const value_idx *rows, std::vector children_h(n_edges * 2); std::vector out_size_h(n_edges); + std::vector out_delta_h(n_edges); UnionFind U(nnz + 1); for (value_idx i = 0; i < nnz; i++) { value_idx a = mst_src_h[i]; value_idx b = mst_dst_h[i]; + value_t delta = mst_weights_h[i]; value_idx aa = U.find(a); value_idx bb = U.find(b); @@ -127,72 +132,15 @@ void build_dendrogram_host(const handle_t &handle, const value_idx *rows, children_h[children_idx] = aa; children_h[children_idx + 1] = bb; + out_delta_h[i] = delta; out_size_h[i] = U.size[aa] + U.size[bb]; U.perform_union(aa, bb); } - out_size.resize(n_edges, stream); - - printf("Finished dendrogram\n"); - raft::update_device(children, children_h.data(), n_edges * 2, stream); - raft::update_device(out_size.data(), out_size_h.data(), n_edges, stream); -} - -/** - * Parallel agglomerative labeling. This amounts to a parallel Kruskal's - * MST algorithm, which breaks apart the sorted MST results into overlapping - * subsets and independently runs Kruskal's algorithm on each subset, - * merging them back together into a single hierarchy when complete. - * - * This outputs the same format as the reference HDBSCAN, but as 4 separate - * arrays, rather than a single 2D array. - * - * Reference: http://cucis.ece.northwestern.edu/publications/pdf/HenPat12.pdf - * - * TODO: Investigate potential for the following end-to-end single-hierarchy batching: - * For each of k (independent) batches over the input: - * - Sample n elements from X - * - Compute mutual reachability graph of batch - * - Construct labels from batch - * - * The sampled datasets should have some overlap across batches. This will - * allow for the cluster hierarchies to be merged. Being able to batch - * will reduce the memory cost so that the full n^2 pairwise distances - * don't need to be materialized in memory all at once. - * - * @tparam value_idx - * @tparam value_t - * @param[in] handle the raft handle - * @param[in] rows src edges of the sorted MST - * @param[in] cols dst edges of the sorted MST - * @param[in] nnz the number of edges in the sorted MST - * @param[out] out_src parents of output - * @param[out] out_dst children of output - * @param[out] out_delta distances of output - * @param[out] out_size cluster sizes of output - * @param[in] k_folds number of folds for parallelizing label step - */ -template -void build_dendrogram_device(const handle_t &handle, const value_idx *rows, - const value_idx *cols, const value_t *data, - value_idx nnz, value_idx *children, - value_t *out_delta, value_idx *out_size, - value_idx k_folds) { - ASSERT(k_folds < nnz / 2, "k_folds must be < n_edges / 2"); - /** - * divide (sorted) mst coo into overlapping subsets. Easiest way to do this is to - * break it into k-folds and iterate through two folds at a time. - */ - - // 1. Generate ranges for the overlapping subsets - - // 2. Run union-find in parallel for each pair of folds - - // 3. Sort individual label hierarchies - - // 4. Merge label hierarchies together + raft::update_device(out_size, out_size_h.data(), n_edges, stream); + raft::update_device(out_delta, out_delta_h.data(), n_edges, stream); } template diff --git a/cpp/include/raft/sparse/hierarchy/detail/mst.cuh b/cpp/include/raft/sparse/hierarchy/detail/mst.cuh index 8ffcfe0f2b..765a5ad77f 100644 --- a/cpp/include/raft/sparse/hierarchy/detail/mst.cuh +++ b/cpp/include/raft/sparse/hierarchy/detail/mst.cuh @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -35,29 +36,6 @@ namespace raft { namespace hierarchy { namespace detail { -/** - * Sorts a COO by its weight - * @tparam value_idx - * @tparam value_t - * @param[inout] rows source edges - * @param[inout] cols dest edges - * @param[inout] data edge weights - * @param[in] nnz number of edges in edge list - * @param[in] stream cuda stream for which to order cuda operations - */ -template -void sort_coo_by_data(value_idx *rows, value_idx *cols, value_t *data, - value_idx nnz, cudaStream_t stream) { - thrust::device_ptr t_rows = thrust::device_pointer_cast(rows); - thrust::device_ptr t_cols = thrust::device_pointer_cast(cols); - thrust::device_ptr t_data = thrust::device_pointer_cast(data); - - auto first = thrust::make_zip_iterator(thrust::make_tuple(rows, cols)); - - thrust::sort_by_key(thrust::cuda::par.on(stream), t_data, t_data + nnz, - first); -} - template void merge_msts(raft::Graph_COO &coo1, raft::Graph_COO &coo2, @@ -95,10 +73,11 @@ void merge_msts(raft::Graph_COO &coo1, * @param[inout] color the color labels array returned from the mst invocation * @return updated MST edge list */ -template +template void connect_knn_graph(const raft::handle_t &handle, const value_t *X, raft::Graph_COO &msf, size_t m, size_t n, value_idx *color, + red_op reduction_op, raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) { auto d_alloc = handle.get_device_allocator(); @@ -106,8 +85,8 @@ void connect_knn_graph(const raft::handle_t &handle, const value_t *X, raft::sparse::COO connected_edges(d_alloc, stream); - raft::linkage::connect_components(handle, connected_edges, - X, color, m, n); + raft::linkage::connect_components( + handle, connected_edges, X, color, m, n, reduction_op); rmm::device_uvector indptr2(m + 1, stream); raft::sparse::convert::sorted_coo_to_csr(connected_edges.rows(), @@ -147,38 +126,34 @@ void connect_knn_graph(const raft::handle_t &handle, const value_t *X, * @param[in] max_iter maximum iterations to run knn graph connection. This * argument is really just a safeguard against the potential for infinite loops. */ -template +template void build_sorted_mst(const raft::handle_t &handle, const value_t *X, const value_idx *indptr, const value_idx *indices, const value_t *pw_dists, size_t m, size_t n, - rmm::device_uvector &mst_src, - rmm::device_uvector &mst_dst, - rmm::device_uvector &mst_weight, - const size_t nnz, + value_idx *mst_src, value_idx *mst_dst, + value_t *mst_weight, value_idx *color, size_t nnz, + red_op reduction_op, raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded, int max_iter = 10) { auto d_alloc = handle.get_device_allocator(); auto stream = handle.get_stream(); - rmm::device_uvector color(m, stream); - // We want to have MST initialize colors on first call. auto mst_coo = raft::mst::mst( - handle, indptr, indices, pw_dists, (value_idx)m, nnz, color.data(), stream, - false, true); + handle, indptr, indices, pw_dists, (value_idx)m, nnz, color, stream, false, + true); int iters = 1; - int n_components = - linkage::get_n_components(color.data(), m, d_alloc, stream); + int n_components = linkage::get_n_components(color, m, d_alloc, stream); while (n_components > 1 && iters < max_iter) { - connect_knn_graph(handle, X, mst_coo, m, n, - color.data()); + connect_knn_graph(handle, X, mst_coo, m, n, color, + reduction_op); iters++; - n_components = linkage::get_n_components(color.data(), m, d_alloc, stream); + n_components = linkage::get_n_components(color, m, d_alloc, stream); } /** @@ -189,7 +164,7 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X, * 1. There is a bug in this code somewhere * 2. Either the given KNN graph wasn't generated from X or the same metric is not being used * to generate the 1-nn (currently only L2SqrtExpanded is supported). - * 3. max_iter was not large enough to connect the graph. + * 3. max_iter was not large enough to connect the graph (less likely). * * Note that a KNN graph generated from 50 random isotropic balls (with significant overlap) * was able to be connected in a single iteration. @@ -201,20 +176,15 @@ void build_sorted_mst(const raft::handle_t &handle, const value_t *X, " or increase 'max_iter'", max_iter); - sort_coo_by_data(mst_coo.src.data(), mst_coo.dst.data(), - mst_coo.weights.data(), mst_coo.n_edges, stream); - - // TODO: be nice if we could pass these directly into the MST - mst_src.resize(mst_coo.n_edges, stream); - mst_dst.resize(mst_coo.n_edges, stream); - mst_weight.resize(mst_coo.n_edges, stream); + raft::sparse::op::coo_sort_by_weight(mst_coo.src.data(), mst_coo.dst.data(), + mst_coo.weights.data(), mst_coo.n_edges, + stream); - raft::copy_async(mst_src.data(), mst_coo.src.data(), mst_coo.n_edges, stream); - raft::copy_async(mst_dst.data(), mst_coo.dst.data(), mst_coo.n_edges, stream); - raft::copy_async(mst_weight.data(), mst_coo.weights.data(), mst_coo.n_edges, - stream); + raft::copy_async(mst_src, mst_coo.src.data(), mst_coo.n_edges, stream); + raft::copy_async(mst_dst, mst_coo.dst.data(), mst_coo.n_edges, stream); + raft::copy_async(mst_weight, mst_coo.weights.data(), mst_coo.n_edges, stream); } }; // namespace detail }; // namespace hierarchy -}; // namespace raft \ No newline at end of file +}; // namespace raft diff --git a/cpp/include/raft/sparse/hierarchy/single_linkage.hpp b/cpp/include/raft/sparse/hierarchy/single_linkage.hpp index 096e154e87..01a033945c 100644 --- a/cpp/include/raft/sparse/hierarchy/single_linkage.hpp +++ b/cpp/include/raft/sparse/hierarchy/single_linkage.hpp @@ -70,16 +70,19 @@ void single_linkage(const raft::handle_t &handle, const value_t *X, size_t m, detail::get_distance_graph( handle, X, m, n, metric, indptr, indices, pw_dists, c); - rmm::device_uvector mst_rows(EMPTY, stream); - rmm::device_uvector mst_cols(EMPTY, stream); - rmm::device_uvector mst_data(EMPTY, stream); + rmm::device_uvector mst_rows(m - 1, stream); + rmm::device_uvector mst_cols(m - 1, stream); + rmm::device_uvector mst_data(m - 1, stream); /** * 2. Construct MST, sorted by weights */ + rmm::device_uvector color(m, stream); + raft::linkage::FixConnectivitiesRedOp op(color.data(), m); detail::build_sorted_mst( - handle, X, indptr.data(), indices.data(), pw_dists.data(), m, n, mst_rows, - mst_cols, mst_data, indices.size(), metric); + handle, X, indptr.data(), indices.data(), pw_dists.data(), m, n, + mst_rows.data(), mst_cols.data(), mst_data.data(), color.data(), + indices.size(), op, metric); pw_dists.release(); @@ -93,7 +96,7 @@ void single_linkage(const raft::handle_t &handle, const value_t *X, size_t m, // Create dendrogram detail::build_dendrogram_host( handle, mst_rows.data(), mst_cols.data(), mst_data.data(), n_edges, - out->children, out_delta, out_size); + out->children, out_delta.data(), out_size.data()); detail::extract_flattened_clusters(handle, out->labels, out->children, n_clusters, m); diff --git a/cpp/include/raft/sparse/op/reduce.cuh b/cpp/include/raft/sparse/op/reduce.cuh index 5856f8b1d8..53c9f89074 100644 --- a/cpp/include/raft/sparse/op/reduce.cuh +++ b/cpp/include/raft/sparse/op/reduce.cuh @@ -136,8 +136,8 @@ void max_duplicates(const raft::handle_t &handle, compute_duplicates_mask(diff.data(), rows, cols, nnz, stream); - thrust::exclusive_scan(exec_policy, diff.data(), diff.data() + diff.size(), - diff.data()); + thrust::exclusive_scan(thrust::cuda::par.on(stream), diff.data(), + diff.data() + diff.size(), diff.data()); // compute final size value_idx size = 0; diff --git a/cpp/include/raft/sparse/op/sort.h b/cpp/include/raft/sparse/op/sort.h index 792983cc9b..9dbe2b67c5 100644 --- a/cpp/include/raft/sparse/op/sort.h +++ b/cpp/include/raft/sparse/op/sort.h @@ -92,6 +92,29 @@ void coo_sort(COO *const in, coo_sort(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), d_alloc, stream); } + +/** + * Sorts a COO by its weight + * @tparam value_idx + * @tparam value_t + * @param[inout] rows source edges + * @param[inout] cols dest edges + * @param[inout] data edge weights + * @param[in] nnz number of edges in edge list + * @param[in] stream cuda stream for which to order cuda operations + */ +template +void coo_sort_by_weight(value_idx *rows, value_idx *cols, value_t *data, + value_idx nnz, cudaStream_t stream) { + thrust::device_ptr t_rows = thrust::device_pointer_cast(rows); + thrust::device_ptr t_cols = thrust::device_pointer_cast(cols); + thrust::device_ptr t_data = thrust::device_pointer_cast(data); + + auto first = thrust::make_zip_iterator(thrust::make_tuple(rows, cols)); + + thrust::sort_by_key(thrust::cuda::par.on(stream), t_data, t_data + nnz, + first); +} }; // namespace op }; // end NAMESPACE sparse }; // end NAMESPACE raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/selection/connect_components.cuh b/cpp/include/raft/sparse/selection/connect_components.cuh index 386f4f1830..8aae90f1d8 100644 --- a/cpp/include/raft/sparse/selection/connect_components.cuh +++ b/cpp/include/raft/sparse/selection/connect_components.cuh @@ -200,23 +200,22 @@ struct LookupColorOp { * @param[in] d_alloc device allocator to use * @param[in] stream cuda stream for which to order cuda operations */ -template +template void perform_1nn(cub::KeyValuePair *kvp, value_idx *nn_colors, value_idx *colors, const value_t *X, size_t n_rows, size_t n_cols, std::shared_ptr d_alloc, - cudaStream_t stream) { + cudaStream_t stream, red_op reduction_op) { rmm::device_uvector workspace(n_rows, stream); rmm::device_uvector x_norm(n_rows, stream); raft::linalg::rowNorm(x_norm.data(), X, n_cols, n_rows, raft::linalg::L2Norm, true, stream); - FixConnectivitiesRedOp red_op(colors, n_rows); raft::distance::fusedL2NN, - value_idx>(kvp, X, X, x_norm.data(), x_norm.data(), - n_rows, n_rows, n_cols, workspace.data(), - red_op, red_op, true, true, stream); + value_idx>( + kvp, X, X, x_norm.data(), x_norm.data(), n_rows, n_rows, n_cols, + workspace.data(), reduction_op, reduction_op, true, true, stream); LookupColorOp extract_colors_op(colors); thrust::transform(thrust::cuda::par.on(stream), kvp, kvp + n_rows, nn_colors, @@ -318,11 +317,11 @@ void min_components_by_color(raft::sparse::COO &coo, * @param[in] n_rows number of rows in X * @param[in] n_cols number of cols in X */ -template +template void connect_components(const raft::handle_t &handle, raft::sparse::COO &out, const value_t *X, const value_idx *orig_colors, - size_t n_rows, size_t n_cols, + size_t n_rows, size_t n_cols, red_op reduction_op, raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) { auto d_alloc = handle.get_device_allocator(); @@ -352,7 +351,7 @@ void connect_components(const raft::handle_t &handle, rmm::device_uvector src_indices(n_rows, stream); perform_1nn(temp_inds_dists.data(), nn_colors.data(), colors.data(), X, - n_rows, n_cols, d_alloc, stream); + n_rows, n_cols, d_alloc, stream, reduction_op); /** * Sort data points by color (neighbors are not sorted) diff --git a/cpp/test/sparse/connect_components.cu b/cpp/test/sparse/connect_components.cu index 68db00374c..d98f9de9c3 100644 --- a/cpp/test/sparse/connect_components.cu +++ b/cpp/test/sparse/connect_components.cu @@ -92,16 +92,18 @@ class ConnectComponentsTest : public ::testing::TestWithParam< */ rmm::device_uvector colors(params.n_row, stream); - auto mst_coo = raft::mst::mst( + auto mst_coo = raft::mst::mst( handle, indptr.data(), knn_graph_coo.cols(), knn_graph_coo.vals(), params.n_row, knn_graph_coo.nnz, colors.data(), stream, false, true); /** * 3. connect_components to fix connectivities */ + raft::linkage::FixConnectivitiesRedOp red_op( + colors.data(), params.n_row); raft::linkage::connect_components( - handle, out_edges, data.data(), colors.data(), params.n_row, - params.n_col); + handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, + red_op); /** * Construct final edge list