Skip to content

Commit

Permalink
AgglomerativeClustering support single cluster and ignore only zero distances from self-loops (rapidsai#3824)
Browse files Browse the repository at this point in the history

Closes rapidsai#3801 
Closes rapidsai#3802 

Corresponding RAFT PR: rapidsai/raft#217

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#3824
  • Loading branch information
cjnolet authored May 20, 2021
1 parent e5a2ac6 commit eaadf44
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
15 changes: 15 additions & 0 deletions cpp/src/hierarchy/pw_dist_graph.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/distance/distance.cuh>

#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>

#include <raft/linalg/distance_type.h>
#include <raft/mr/device/buffer.hpp>
Expand Down Expand Up @@ -71,6 +72,7 @@ void pairwise_distances(const raft::handle_t &handle, const value_t *X,
value_idx *indptr, value_idx *indices, value_t *data) {
auto d_alloc = handle.get_device_allocator();
auto stream = handle.get_stream();
auto exec_policy = rmm::exec_policy(stream);

value_idx nnz = m * m;

Expand All @@ -90,6 +92,19 @@ void pairwise_distances(const raft::handle_t &handle, const value_t *X,
// usage to hand it a sparse array here.
raft::distance::pairwise_distance<value_t, value_idx>(
X, X, data, m, m, n, workspace, metric, stream);

// self-loops get max distance
auto transform_in = thrust::make_zip_iterator(
thrust::make_tuple(thrust::make_counting_iterator(0), data));

thrust::transform(
exec_policy, transform_in, transform_in + nnz, data,
[=] __device__(const thrust::tuple<value_idx, value_t> &tup) {
value_idx idx = thrust::get<0>(tup);
bool self_loop = idx % m == idx / m;
return (self_loop * std::numeric_limits<value_t>::max()) +
(!self_loop * thrust::get<1>(tup));
});
}

/**
Expand Down
28 changes: 21 additions & 7 deletions python/cuml/test/test_agglomerative.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,27 @@
import cupy as cp


@pytest.mark.parametrize('connectivity', ['knn', 'pairwise'])
def test_duplicate_distances(connectivity):
    """Check single-linkage labels against the reference when rows coincide.

    The first two input rows are identical, so the distance between them is
    zero. The labels produced by ``AgglomerativeClustering`` for each
    ``connectivity`` mode must agree perfectly (adjusted Rand score of 1.0)
    with those produced by ``cluster.AgglomerativeClustering``.
    """
    # Two coincident points plus one distinct point.
    points = cp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [2.0, 2.0, 2.0]])

    # Settings common to both estimators.
    shared_params = dict(n_clusters=2, affinity="euclidean", linkage="single")

    gpu_model = AgglomerativeClustering(n_neighbors=3,
                                        connectivity=connectivity,
                                        **shared_params)
    reference_model = cluster.AgglomerativeClustering(**shared_params)

    gpu_model.fit(points)
    # The reference estimator is fitted on the array returned by `.get()`.
    reference_model.fit(points.get())

    agreement = adjusted_rand_score(gpu_model.labels_,
                                    reference_model.labels_)
    assert agreement == 1.0


@pytest.mark.parametrize('nrows', [100, 1000])
@pytest.mark.parametrize('ncols', [25, 50])
@pytest.mark.parametrize('nclusters', [2, 10, 50])
@pytest.mark.parametrize('nclusters', [1, 2, 10, 50])
@pytest.mark.parametrize('k', [3, 5, 15])
@pytest.mark.parametrize('connectivity', ['knn', 'pairwise'])
def test_single_linkage_sklearn_compare(nrows, ncols, nclusters,
Expand All @@ -37,17 +55,13 @@ def test_single_linkage_sklearn_compare(nrows, ncols, nclusters,
ncols,
nclusters,
cluster_std=1.0,
shuffle=False,
random_state=42)
shuffle=False)

cuml_agg = AgglomerativeClustering(
n_clusters=nclusters, affinity='euclidean', linkage='single',
n_neighbors=k, connectivity=connectivity)

try:
cuml_agg.fit(X)
except Exception:
cuml_agg.fit(X)
cuml_agg.fit(X)

sk_agg = cluster.AgglomerativeClustering(
n_clusters=nclusters, affinity='euclidean', linkage='single')
Expand Down

0 comments on commit eaadf44

Please sign in to comment.