Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Fixing remaining hdbscan bug #4179

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cpp/include/cuml/cluster/hdbscan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ enum CLUSTER_SELECTION_METHOD { EOM = 0, LEAF = 1 };

class RobustSingleLinkageParams {
public:
int k = 5;
int min_samples = 5;
int min_cluster_size = 5;
int max_cluster_size = 0;
Expand Down
28 changes: 13 additions & 15 deletions cpp/src/hdbscan/detail/reachability.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@ namespace Reachability {
* @tparam value_t data type for distance
* @tparam tpb block size for kernel
* @param[in] knn_dists knn distance array (size n * k)
* @param[in] k neighborhood size
* @param[in] min_samples this neighbor will be selected for core distances
* @param[in] n number of samples
* @param[out] out output array (size n)
* @param[in] stream stream for which to order cuda operations
*/
template <typename value_idx, typename value_t, int tpb = 256>
void core_distances(
value_t* knn_dists, int k, int min_samples, size_t n, value_t* out, cudaStream_t stream)
value_t* knn_dists, int min_samples, size_t n, value_t* out, cudaStream_t stream)
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
{
int blocks = raft::ceildiv(n, (size_t)tpb);

Expand All @@ -67,7 +66,7 @@ void core_distances(
auto indices = thrust::make_counting_iterator<value_idx>(0);

thrust::transform(exec_policy, indices, indices + n, out, [=] __device__(value_idx row) {
return knn_dists[row * k + (min_samples - 1)];
return knn_dists[row * min_samples + (min_samples - 1)];
});
}

Expand Down Expand Up @@ -118,7 +117,6 @@ void mutual_reachability_graph(const raft::handle_t& handle,
size_t m,
size_t n,
raft::distance::DistanceType metric,
int k,
int min_samples,
value_t alpha,
value_idx* indptr,
Expand All @@ -139,10 +137,10 @@ void mutual_reachability_graph(const raft::handle_t& handle,

// This is temporary. Once faiss is updated, we should be able to
// pass value_idx through to knn.
rmm::device_uvector<value_idx> coo_rows(k * m, stream);
rmm::device_uvector<int64_t> int64_indices(k * m, stream);
rmm::device_uvector<value_idx> inds(k * m, stream);
rmm::device_uvector<value_t> dists(k * m, stream);
rmm::device_uvector<value_idx> coo_rows(min_samples * m, stream);
rmm::device_uvector<int64_t> int64_indices(min_samples * m, stream);
rmm::device_uvector<value_idx> inds(min_samples * m, stream);
rmm::device_uvector<value_t> dists(min_samples * m, stream);

// perform knn
brute_force_knn(handle,
Expand All @@ -153,7 +151,7 @@ void mutual_reachability_graph(const raft::handle_t& handle,
m,
int64_indices.data(),
dists.data(),
k,
min_samples,
true,
true,
metric);
Expand All @@ -166,24 +164,24 @@ void mutual_reachability_graph(const raft::handle_t& handle,
[] __device__(int64_t in) -> value_idx { return in; });

// Slice core distances (distances to kth nearest neighbor)
core_distances<value_idx>(dists.data(), k, min_samples, m, core_dists, stream);
core_distances<value_idx>(dists.data(), min_samples, m, core_dists, stream);

/**
* Compute L2 norm
*/
mutual_reachability_knn_l2(
handle, inds.data(), dists.data(), X, m, n, k, core_dists, (value_t)1.0 / alpha);
handle, inds.data(), dists.data(), X, m, n, min_samples, core_dists, (value_t)1.0 / alpha);

// self-loops get max distance
auto coo_rows_counting_itr = thrust::make_counting_iterator<value_idx>(0);
thrust::transform(exec_policy,
coo_rows_counting_itr,
coo_rows_counting_itr + (m * k),
coo_rows_counting_itr + (m * min_samples),
coo_rows.data(),
[k] __device__(value_idx c) -> value_idx { return c / k; });
[min_samples] __device__(value_idx c) -> value_idx { return c / min_samples; });

raft::sparse::linalg::symmetrize(
handle, coo_rows.data(), inds.data(), dists.data(), m, m, k * m, out);
handle, coo_rows.data(), inds.data(), dists.data(), m, m, min_samples * m, out);

raft::sparse::convert::sorted_coo_to_csr(out.rows(), out.nnz, indptr, m + 1, stream);

Expand All @@ -205,4 +203,4 @@ void mutual_reachability_graph(const raft::handle_t& handle,
}; // end namespace Reachability
}; // end namespace detail
}; // end namespace HDBSCAN
}; // end namespace ML
}; // end namespace ML
26 changes: 14 additions & 12 deletions cpp/src/hdbscan/detail/reachability_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ __global__ void l2SelectMinK(faiss::gpu::Tensor<value_t, 2, true> inner_products
faiss::gpu::Tensor<value_t, 1, true> core_dists,
faiss::gpu::Tensor<value_t, 2, true> out_dists,
faiss::gpu::Tensor<int, 2, true> out_inds,
int batch_offset,
int batch_offset_i,
int batch_offset_j,
int k,
value_t initK,
value_t alpha)
Expand Down Expand Up @@ -85,19 +86,19 @@ __global__ void l2SelectMinK(faiss::gpu::Tensor<value_t, 2, true> inner_products

for (; i < limit; i += blockDim.x) {
value_t v = sqrt(faiss::gpu::Math<value_t>::add(
sq_norms[row + batch_offset],
faiss::gpu::Math<value_t>::add(sq_norms[i], inner_products[row][i])));
sq_norms[row + batch_offset_i],
faiss::gpu::Math<value_t>::add(sq_norms[i + batch_offset_j], inner_products[row][i])));

v = max(core_dists[i], max(core_dists[row + batch_offset], alpha * v));
v = max(core_dists[i + batch_offset_j], max(core_dists[row + batch_offset_i], alpha * v));
heap.add(v, i);
}

if (i < inner_products.getSize(1)) {
value_t v = sqrt(faiss::gpu::Math<value_t>::add(
sq_norms[row + batch_offset],
faiss::gpu::Math<value_t>::add(sq_norms[i], inner_products[row][i])));
sq_norms[row + batch_offset_i],
faiss::gpu::Math<value_t>::add(sq_norms[i + batch_offset_j], inner_products[row][i])));

v = max(core_dists[i], max(core_dists[row + batch_offset], alpha * v));
v = max(core_dists[i + batch_offset_j], max(core_dists[row + batch_offset_i], alpha * v));
heap.addThreadQ(v, i);
}

Expand Down Expand Up @@ -127,7 +128,8 @@ void runL2SelectMin(faiss::gpu::Tensor<value_t, 2, true>& productDistances,
faiss::gpu::Tensor<value_t, 1, true>& coreDistances,
faiss::gpu::Tensor<value_t, 2, true>& outDistances,
faiss::gpu::Tensor<int, 2, true>& outIndices,
int batch_offset,
int batch_offset_i,
int batch_offset_j,
int k,
value_t alpha,
cudaStream_t stream)
Expand All @@ -149,7 +151,8 @@ void runL2SelectMin(faiss::gpu::Tensor<value_t, 2, true>& productDistances,
coreDistances, \
outDistances, \
outIndices, \
batch_offset, \
batch_offset_i, \
batch_offset_j, \
k, \
faiss::gpu::Limits<value_t>::getMax(), \
alpha); \
Expand Down Expand Up @@ -323,7 +326,6 @@ void mutual_reachability_knn_l2(const raft::handle_t& handle,
auto outIndexView = out_inds_tensor.narrow(0, i, curQuerySize);

auto queryView = x_tensor.narrow(0, i, curQuerySize);
norms_tensor.narrow(0, i, curQuerySize);

auto outDistanceBufRowView = outDistanceBufs[curStream]->narrow(0, 0, curQuerySize);
auto outIndexBufRowView = outIndexBufs[curStream]->narrow(0, 0, curQuerySize);
Expand Down Expand Up @@ -365,19 +367,19 @@ void mutual_reachability_knn_l2(const raft::handle_t& handle,
outDistanceView,
outIndexView,
i,
j,
k,
alpha,
streams[curStream]);
} else {
norms_tensor.narrow(0, j, curCentroidSize);

// Write into our intermediate output
runL2SelectMin<value_t>(distanceBufView,
norms_tensor,
core_dists_tensor,
outDistanceBufColView,
outIndexBufColView,
i,
j,
k,
alpha,
streams[curStream]);
Expand Down
5 changes: 1 addition & 4 deletions cpp/src/hdbscan/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,18 @@ void build_linkage(const raft::handle_t& handle,
{
auto stream = handle.get_stream();

int k = params.k + 1;

/**
* Mutual reachability graph
*/
rmm::device_uvector<value_idx> mutual_reachability_indptr(m + 1, stream);
raft::sparse::COO<value_t, value_idx> mutual_reachability_coo(stream, k * m * 2);
raft::sparse::COO<value_t, value_idx> mutual_reachability_coo(stream, params.min_samples * m * 2);
rmm::device_uvector<value_t> core_dists(m, stream);

detail::Reachability::mutual_reachability_graph(handle,
X,
(size_t)m,
(size_t)n,
metric,
k,
params.min_samples,
params.alpha,
mutual_reachability_indptr.data(),
Expand Down
3 changes: 1 addition & 2 deletions cpp/test/sg/hdbscan_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ class HDBSCANTest : public ::testing::TestWithParam<HDBSCANInputs<T, IdxT>> {
mst_weights.data());

HDBSCAN::Common::HDBSCANParams hdbscan_params;
hdbscan_params.k = params.k;
hdbscan_params.min_cluster_size = params.min_cluster_size;
hdbscan_params.min_samples = params.min_pts;

Expand All @@ -116,6 +115,7 @@ class HDBSCANTest : public ::testing::TestWithParam<HDBSCANInputs<T, IdxT>> {

protected:
HDBSCANInputs<T, IdxT> params;
IdxT* labels_ref;
int k;

double score;
Expand Down Expand Up @@ -218,7 +218,6 @@ class ClusterCondensingTest : public ::testing::TestWithParam<ClusterCondensingI

protected:
ClusterCondensingInputs<T, IdxT> params;
int k;

double score;
};
Expand Down
16 changes: 5 additions & 11 deletions python/cuml/experimental/cluster/hdbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ cdef extern from "cuml/cluster/hdbscan.hpp" namespace "ML::HDBSCAN::Common":
CondensedHierarchy[int, float] &get_condensed_tree()

cdef cppclass HDBSCANParams:
int k
int min_samples
int min_cluster_size
int max_cluster_size,
Expand Down Expand Up @@ -435,12 +434,11 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
handle=None,
verbose=False,
connectivity='knn',
n_neighbors=10,
output_type=None):

super().__init__(handle,
verbose,
output_type)
super().__init__(handle=handle,
verbose=verbose,
output_type=output_type)

if min_samples is None:
min_samples = min_cluster_size
Expand All @@ -449,8 +447,8 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
raise ValueError("'connectivity' can only be one of "
"{'knn', 'pairwise'}")

if n_neighbors > 1023 or n_neighbors < 2:
raise ValueError("'n_neighbors' must be a positive number "
if 2 < min_samples and min_samples > 1023:
raise ValueError("'min_samples' must be a positive number "
"between 2 and 1023")

self.min_cluster_size = min_cluster_size
Expand All @@ -462,7 +460,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
self.alpha = alpha
self.cluster_selection_method = cluster_selection_method
self.allow_single_cluster = allow_single_cluster
self.n_neighbors = n_neighbors
self.connectivity = connectivity

self.fit_called_ = False
Expand Down Expand Up @@ -619,7 +616,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
self.hdbscan_output_ = <size_t>linkage_output

cdef HDBSCANParams params
params.k = self.n_neighbors
params.min_samples = self.min_samples
# params.alpha = self.alpha
params.min_cluster_size = self.min_cluster_size
Expand Down Expand Up @@ -730,7 +726,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):

def get_param_names(self):
return super().get_param_names() + [
"n_neighbors",
"metric",
"min_cluster_size",
"max_cluster_size",
Expand All @@ -740,7 +735,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
"p",
"allow_single_cluster",
"connectivity",
"n_neighbors",
"alpha",
"gen_min_span_tree",
]
22 changes: 16 additions & 6 deletions python/cuml/test/test_hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from cuml.experimental.cluster import condense_hierarchy
from sklearn.datasets import make_blobs


from cuml.metrics import adjusted_rand_score
from cuml.test.utils import get_pattern

Expand Down Expand Up @@ -159,7 +158,6 @@ def test_hdbscan_blobs(nrows, ncols, nclusters,

cuml_agg = HDBSCAN(verbose=logger.level_info,
allow_single_cluster=allow_single_cluster,
n_neighbors=min_samples+1,
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
min_samples=min_samples,
max_cluster_size=max_cluster_size,
min_cluster_size=min_cluster_size,
Expand Down Expand Up @@ -210,7 +208,6 @@ def test_hdbscan_sklearn_datasets(dataset,

cuml_agg = HDBSCAN(verbose=logger.level_info,
allow_single_cluster=allow_single_cluster,
n_neighbors=min_samples,
gen_min_span_tree=True,
min_samples=min_samples,
max_cluster_size=max_cluster_size,
Expand Down Expand Up @@ -263,7 +260,6 @@ def test_hdbscan_sklearn_extract_clusters(dataset,

cuml_agg = HDBSCAN(verbose=logger.level_info,
allow_single_cluster=allow_single_cluster,
n_neighbors=min_samples,
gen_min_span_tree=True,
min_samples=min_samples,
max_cluster_size=max_cluster_size,
Expand Down Expand Up @@ -313,7 +309,6 @@ def test_hdbscan_cluster_patterns(dataset, nrows,

cuml_agg = HDBSCAN(verbose=logger.level_info,
allow_single_cluster=allow_single_cluster,
n_neighbors=min_samples,
min_samples=min_samples,
max_cluster_size=max_cluster_size,
min_cluster_size=min_cluster_size,
Expand Down Expand Up @@ -367,7 +362,6 @@ def test_hdbscan_cluster_patterns_extract_clusters(dataset, nrows,

cuml_agg = HDBSCAN(verbose=logger.level_info,
allow_single_cluster=allow_single_cluster,
n_neighbors=min_samples,
min_samples=min_samples,
max_cluster_size=max_cluster_size,
min_cluster_size=min_cluster_size,
Expand All @@ -393,6 +387,22 @@ def test_hdbscan_cluster_patterns_extract_clusters(dataset, nrows,
sk_agg.probabilities_)


def test_hdbscan_core_dists_bug_4054():
    """
    Regression test for the core-distance bug reported in
    https://github.com/rapidsai/cuml/issues/4054

    Runs the minimal reproducer from that issue (two noisy moons,
    min_samples == min_cluster_size == 25) through both cuML's HDBSCAN
    and the reference `hdbscan` package, and checks that the two
    labelings agree.
    """

    # Fixed seed so that a failure is reproducible instead of depending
    # on whichever moons happened to be sampled for this run.
    X, _ = datasets.make_moons(n_samples=10000,
                               noise=0.12,
                               random_state=0)

    cu_labels_ = HDBSCAN(min_samples=25, min_cluster_size=25).fit_predict(X)

    # Ask the reference implementation for an exact minimum spanning tree;
    # its default approximate MST can legitimately yield a slightly
    # different hierarchy, which is unrelated to the bug under test.
    sk_labels_ = hdbscan.HDBSCAN(min_samples=25,
                                 min_cluster_size=25,
                                 approx_min_span_tree=False).fit_predict(X)

    # The score is a computed float: require near-perfect agreement rather
    # than bit-exact equality with 1.0.
    assert adjusted_rand_score(cu_labels_, sk_labels_) > 0.99


def test_hdbscan_plots():

X, y = make_blobs(int(100),
Expand Down