Skip to content

Commit

Permalink
Single-linkage Hierarchical Clustering C++ (#3545)
Browse files Browse the repository at this point in the history
Closes #3518 

I've closed the original PR (#3308) which included both SLHC & HDBSCAN and opened this PR to only include the SLHC changes. 

This PR contains an implementation of SLHC which is currently split across RAFT & cuML. Once we move the dense pairwise distance primitive over to RAFT, the entire SLHC algorithm can live in RAFT so that it can be shared with cugraph, and it will simply be exposed through cuML.

If reviewing this PR, please also review the corresponding RAFT PR: rapidsai/raft#140

Authors:
  - Corey J. Nolet (@cjnolet)

Approvers:
  - Divye Gala (@divyegala)
  - Dante Gama Dessavre (@dantegd)
  - Mike Wendt (@mike-wendt)

URL: #3545
  • Loading branch information
cjnolet authored Mar 18, 2021
1 parent 14bd6c1 commit 0d02e76
Show file tree
Hide file tree
Showing 16 changed files with 882 additions and 797 deletions.
2 changes: 1 addition & 1 deletion conda/recipes/cuml/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ requirements:
- libcumlprims {{ minor_version }}
- cupy>=7.8.0,<9.0.0a0
- treelite=1.0.0
- nccl>=2.5
- nccl>=2.8.4
- ucx-py {{ minor_version }}
- ucx-proc=*=gpu
- dask>=2.12.0
Expand Down
4 changes: 2 additions & 2 deletions conda/recipes/libcuml/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ requirements:
- clang=8.0.1
- clang-tools=8.0.1
host:
- nccl >=2.5
- nccl >=2.8.4
- cudf {{ minor_version }}
- cudatoolkit {{ cuda_version }}.*
- ucx-py {{ minor_version }}
Expand All @@ -52,7 +52,7 @@ requirements:
run:
- libcumlprims {{ minor_version }}
- cudf {{ minor_version }}
- nccl>=2.5
- nccl>=2.8.4
- ucx-py {{ minor_version }}
- ucx-proc=*=gpu
- {{ pin_compatible('cudatoolkit', max_pin='x.x') }}
Expand Down
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ if(BUILD_CUML_CPP_LIBRARY)
src/kmeans/kmeans.cu
src/knn/knn.cu
src/knn/knn_sparse.cu
src/hierarchy/linkage.cu
src/metrics/accuracy_score.cu
src/metrics/adjusted_rand_index.cu
src/metrics/completeness_score.cu
Expand Down
1 change: 1 addition & 0 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ if(BUILD_CUML_BENCH)
sg/arima_loglikelihood.cu
sg/dbscan.cu
sg/kmeans.cu
sg/linkage.cu
sg/main.cpp
sg/rf_classifier.cu
# FIXME: RF Regressor is having an issue where the tests now seem to take
Expand Down
100 changes: 100 additions & 0 deletions cpp/bench/sg/linkage.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/linalg/distance_type.h>
#include <raft/sparse/hierarchy/common.h>
#include <cuml/cluster/linkage.hpp>
#include <cuml/common/logger.hpp>
#include <cuml/cuml.hpp>
#include <utility>
#include "benchmark.cuh"

namespace ML {
namespace Bench {
namespace linkage {

// Parameter bundle for one Linkage benchmark case. Both members are forwarded
// to the BlobsFixture constructor by Linkage.
struct Params {
// Dataset description; this file sets rowMajor, nrows, ncols and nclasses.
DatasetParams data;
// Blob-generation settings; this file sets cluster_std, shuffle,
// center_box_min/max and seed.
BlobsParams blobs;
};

/**
 * Benchmark fixture that runs ML::single_linkage_neighbors on the data held
 * by BlobsFixture<D>.
 *
 * @tparam D value type forwarded to BlobsFixture and used as the second
 *           template argument of the linkage_output container.
 */
template <typename D>
class Linkage : public BlobsFixture<D> {
 public:
  /**
   * @param name benchmark name, forwarded to BlobsFixture
   * @param p    dataset and blob parameters, forwarded to BlobsFixture
   */
  Linkage(const std::string& name, const Params& p)
    : BlobsFixture<D>(name, p.data, p.blobs) {}

 protected:
  /**
   * Runs the benchmark loop. Column-major inputs are reported as an error
   * and the benchmark body is not executed for them.
   */
  void runBenchmark(::benchmark::State& state) override {
    if (!this->params.rowMajor) {
      state.SkipWithError("Single-Linkage only supports row-major inputs");
      // Leave immediately so the unsupported configuration never reaches the
      // benchmark body; Google Benchmark permits an early return once
      // SkipWithError() has been called.
      return;
    }

    this->loopOnState(state, [this]() {
      // Point the output container at the buffers from allocateTempBuffers().
      out_arrs.labels = labels;
      out_arrs.children = out_children;

      ML::single_linkage_neighbors(
        *this->handle, this->data.X, this->params.nrows, this->params.ncols,
        &out_arrs, raft::distance::DistanceType::L2Unexpanded, /*c=*/15,
        /*n_clusters=*/50);
    });
  }

  /**
   * Allocates the output buffers: `nrows` labels and `(nrows - 1) * 2`
   * entries for the children array.
   */
  void allocateTempBuffers(const ::benchmark::State& state) override {
    this->alloc(labels, this->params.nrows);
    this->alloc(out_children, (this->params.nrows - 1) * 2);
  }

  /** Releases the buffers using the same sizes as allocateTempBuffers(). */
  void deallocateTempBuffers(const ::benchmark::State& state) override {
    this->dealloc(labels, this->params.nrows);
    this->dealloc(out_children, (this->params.nrows - 1) * 2);
  }

 private:
  // Null until allocateTempBuffers() runs, so the pointers never hold an
  // indeterminate value.
  int* labels = nullptr;
  int* out_children = nullptr;
  raft::hierarchy::linkage_output<int, D> out_arrs;
};

std::vector<Params> getInputs() {
std::vector<Params> out;
Params p;
p.data.rowMajor = true;
p.blobs.cluster_std = 5.0;
p.blobs.shuffle = false;
p.blobs.center_box_min = -10.0;
p.blobs.center_box_max = 10.0;
p.blobs.seed = 12345ULL;
std::vector<std::pair<int, int>> rowcols = {
{35000, 128}, {16384, 128}, {12288, 128}, {8192, 128}, {4096, 128},
};
for (auto& rc : rowcols) {
p.data.nrows = rc.first;
p.data.ncols = rc.second;
for (auto nclass : std::vector<int>({1})) {
p.data.nclasses = nclass;
out.push_back(p);
}
}
return out;
}

// Registers the float instantiation of the Linkage fixture under the name
// "blobs" with the configurations returned by getInputs().
// NOTE(review): the macro is defined outside this file — confirm its exact
// registration semantics there.
ML_BENCH_REGISTER(Params, Linkage<float>, "blobs", getInputs());

} // namespace linkage
} // end namespace Bench
} // end namespace ML
2 changes: 1 addition & 1 deletion cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ else(DEFINED ENV{RAFT_PATH})

ExternalProject_Add(raft
GIT_REPOSITORY https://github.com/rapidsai/raft.git
GIT_TAG 6455e05b3889db2b495cf3189b33c2b07bfbebf2
GIT_TAG fc46618d76d70710b07d445e79d3e07dea6cad2f
PREFIX ${RAFT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand Down
78 changes: 78 additions & 0 deletions cpp/include/cuml/cluster/linkage.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/linalg/distance_type.h>
#include <raft/sparse/hierarchy/common.h>

#include <cuml/cuml.hpp>

namespace ML {

/**
 * @brief Computes single-linkage hierarchical clustering on a dense input
 * feature matrix and outputs the labels, dendrogram, and minimum spanning tree.
 * Connectivities are constructed using the full n^2 pairwise distance matrix.
 * This can be very fast for smaller datasets when there is enough memory
 * available.
 * @param[in] handle raft handle to encapsulate expensive resources
 * @param[in] X dense feature matrix on device
 * @param[in] m number of rows in X
 * @param[in] n number of columns in X
 * @param[out] out container object for output arrays
 * @param[in] metric distance metric to use. Must be supported by the
 *            dense pairwise distances API.
 * @param[in] n_clusters number of clusters to cut from resulting dendrogram
 *            (defaults to 5)
 */
void single_linkage_pairwise(const raft::handle_t &handle, const float *X,
size_t m, size_t n,
raft::hierarchy::linkage_output<int, float> *out,
raft::distance::DistanceType metric,
int n_clusters = 5);

/**
 * @brief Computes single-linkage hierarchical clustering on a dense input
 * feature matrix and outputs the labels, dendrogram, and minimum spanning tree.
 * Connectivities are constructed using a k-nearest neighbors graph. While this
 * strategy enables the algorithm to scale to much higher numbers of rows,
 * it comes with the downside that additional knn steps may need to be
 * executed to connect an otherwise unconnected k-nn graph.
 * @param[in] handle raft handle to encapsulate expensive resources
 * @param[in] X dense feature matrix on device
 * @param[in] m number of rows in X
 * @param[in] n number of columns in X
 * @param[out] out container object for output arrays
 * @param[in] metric distance metric to use. Must be supported by the
 *            dense pairwise distances API. Defaults to L2Unexpanded.
 * @param[in] c the optimal value of k is guaranteed to be at least log(n) + c
 *            where c is some constant. This constant can usually be set to a
 *            fairly low value, like 15 (the default), and still maintain good
 *            performance.
 * @param[in] n_clusters number of clusters to cut from resulting dendrogram
 *            (defaults to 5)
 */
void single_linkage_neighbors(const raft::handle_t &handle, const float *X,
size_t m, size_t n,
raft::hierarchy::linkage_output<int, float> *out,
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Unexpanded,
int c = 15, int n_clusters = 5);

/**
 * @brief Overload of single_linkage_pairwise whose output container is
 * raft::hierarchy::linkage_output<int64_t, float> instead of
 * linkage_output<int, float>. All other parameters match the int overload.
 *
 * NOTE(review): the linkage.cu added alongside this header only shows a
 * definition for the <int, float> overload — confirm that this overload is
 * defined somewhere, otherwise calling it will fail at link time.
 *
 * @param[in] handle raft handle
 * @param[in] X dense feature matrix
 * @param[in] m number of rows in X
 * @param[in] n number of columns in X
 * @param[out] out container object for output arrays
 * @param[in] metric distance metric to use
 * @param[in] n_clusters number of clusters to cut from resulting dendrogram
 *            (defaults to 5)
 */
void single_linkage_pairwise(
const raft::handle_t &handle, const float *X, size_t m, size_t n,
raft::hierarchy::linkage_output<int64_t, float> *out,
raft::distance::DistanceType metric, int n_clusters = 5);

}; // namespace ML
2 changes: 1 addition & 1 deletion cpp/include/cuml/neighbors/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ struct IVFSQParam : IVFParam {
* @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This
* is ignored if the metric_type is not Minkowski.
*/
void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
void brute_force_knn(const raft::handle_t &handle, std::vector<float *> &input,
std::vector<int> &sizes, int D, float *search_items, int n,
int64_t *res_I, float *res_D, int k,
bool rowMajorIndex = false, bool rowMajorQuery = false,
Expand Down
50 changes: 50 additions & 0 deletions cpp/src/hierarchy/linkage.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuml/cluster/linkage.hpp>
#include <raft/sparse/hierarchy/single_linkage.hpp>
#include "pw_dist_graph.cuh"

namespace ML {

/**
 * Forwards to raft::hierarchy::single_linkage instantiated for
 * <int, float, LinkageDistance::PAIRWISE>.
 */
void single_linkage_pairwise(const raft::handle_t &handle, const float *X,
                             size_t m, size_t n,
                             raft::hierarchy::linkage_output<int, float> *out,
                             raft::distance::DistanceType metric,
                             int n_clusters) {
  constexpr auto kGraphKind = raft::hierarchy::LinkageDistance::PAIRWISE;
  // This entry point exposes no `c` parameter, so the slot ahead of
  // n_clusters is always filled with 0 (presumably unused in PAIRWISE mode —
  // confirm against the RAFT implementation).
  constexpr int kFixedC = 0;

  raft::hierarchy::single_linkage<int, float, kGraphKind>(
    handle, X, m, n, metric, out, kFixedC, n_clusters);
}

/**
 * Forwards to raft::hierarchy::single_linkage instantiated for
 * <int, float, LinkageDistance::KNN_GRAPH>, passing `c` and `n_clusters`
 * through unchanged.
 */
void single_linkage_neighbors(const raft::handle_t &handle, const float *X,
                              size_t m, size_t n,
                              raft::hierarchy::linkage_output<int, float> *out,
                              raft::distance::DistanceType metric, int c,
                              int n_clusters) {
  using raft::hierarchy::LinkageDistance;

  raft::hierarchy::single_linkage<int, float, LinkageDistance::KNN_GRAPH>(
    handle, X, m, n, metric, out, c, n_clusters);
}

// Empty structs deriving from the PAIRWISE distance_graph_impl for
// <int, float> and <int, double>.
// NOTE(review): nothing in this file references them; presumably they exist
// to force those template specializations (see "pw_dist_graph.cuh") to be
// instantiated in this translation unit — confirm before removing.
struct distance_graph_impl_int_float
: public raft::hierarchy::detail::distance_graph_impl<
raft::hierarchy::LinkageDistance::PAIRWISE, int, float> {};
struct distance_graph_impl_int_double
: public raft::hierarchy::detail::distance_graph_impl<
raft::hierarchy::LinkageDistance::PAIRWISE, int, double> {};

}; // end namespace ML
Loading

0 comments on commit 0d02e76

Please sign in to comment.