Skip to content

Commit

Permalink
HDBSCAN: Lazy-loading (and caching) condensed & single-linkage tree objects (#3986)
Browse files Browse the repository at this point in the history

…ects

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3986
  • Loading branch information
cjnolet authored Jun 16, 2021
1 parent 87f8e90 commit edecd3b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 23 deletions.
43 changes: 20 additions & 23 deletions python/cuml/experimental/cluster/hdbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
cluster_selection_method='eom',
allow_single_cluster=False,
gen_min_span_tree=False,
gen_condensed_tree=False,
gen_single_linkage_tree=False,
handle=None,
verbose=False,
connectivity='knn',
Expand Down Expand Up @@ -308,16 +306,16 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
self.n_clusters_ = None
self.n_leaves_ = None

self._condensed_tree = None
self._single_linkage_tree = None
self.condensed_tree_obj = None
self.single_linkage_tree_obj = None
self.minimum_spanning_tree_ = None

self.gen_min_span_tree_ = gen_min_span_tree
self.gen_condensed_tree = gen_condensed_tree
self.gen_single_linkage_tree = gen_single_linkage_tree

def _build_condensed_tree(self):
@property
def condensed_tree_(self):

if self.gen_condensed_tree:
if self.condensed_tree_obj is None:
raw_tree = np.recarray(shape=(self.condensed_parent_.shape[0],),
formats=[np.intp, np.intp, float, np.intp],
names=('parent', 'child', 'lambda_val',
Expand All @@ -329,14 +327,17 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):

if has_hdbscan_plots():
from hdbscan.plots import CondensedTree
self.condensed_tree_ = \
self.condensed_tree_obj = \
CondensedTree(raw_tree,
self.cluster_selection_epsilon,
self.allow_single_cluster)

def _build_single_linkage_tree(self):
return self.condensed_tree_obj

if self.gen_single_linkage_tree:
@property
def single_linkage_tree_(self):

if self.single_linkage_tree_obj is None:
with cuml.using_output_type("numpy"):
raw_tree = np.column_stack(
(self.children_[0, :self.n_leaves_-1],
Expand All @@ -348,10 +349,13 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):

if has_hdbscan_plots():
from hdbscan.plots import SingleLinkageTree
self.single_linkage_tree_ = SingleLinkageTree(raw_tree)
self.single_linkage_tree_obj = SingleLinkageTree(raw_tree)

return self.single_linkage_tree_obj

def build_minimum_spanning_tree(self, X):

def _build_minimum_spanning_tree(self, X):
if self.gen_min_span_tree_:
if self.gen_min_span_tree_ and self.minimum_spanning_tree_ is None:
with cuml.using_output_type("numpy"):
raw_tree = np.column_stack((self.mst_src_,
self.mst_dst_,
Expand All @@ -363,6 +367,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
from hdbscan.plots import MinimumSpanningTree
self.minimum_spanning_tree_ = \
MinimumSpanningTree(raw_tree, X.to_output("numpy"))
return self.minimum_spanning_tree_

# Cleanup hook: hands this estimator instance to delete_hdbscan_output.
# That helper is defined outside this view; presumably it frees the native
# HDBSCAN output attached to the instance -- TODO confirm.
# NOTE(review): Cython only invokes __dealloc__ automatically on extension
# (cdef class) types, while the hunk headers in this diff show a plain
# `class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin)`. Confirm this
# hook is actually reached, or whether the cleanup belongs in __del__.
def __dealloc__(self):
delete_hdbscan_output(self)
Expand Down Expand Up @@ -499,8 +504,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
else:
raise ValueError("'affinity' %s not supported." % self.affinity)

print("Calling HDBSCAN")

if self.connectivity == 'knn':
hdbscan(handle_[0],
<float*>input_ptr,
Expand All @@ -515,15 +518,11 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):

self.handle.sync()

print("Done")

self.fit_called_ = True

self._construct_output_attributes()

self._build_minimum_spanning_tree(X_m)
self._build_condensed_tree()
self._build_single_linkage_tree()
self.build_minimum_spanning_tree(X_m)

return self

Expand Down Expand Up @@ -553,6 +552,4 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
"n_neighbors",
"alpha",
"gen_min_span_tree",
"gen_single_linkage_tree",
"gen_condensed_tree"
]
22 changes: 22 additions & 0 deletions python/cuml/test/test_hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,25 @@ def test_hdbscan_cluster_patterns(dataset, nrows,

assert np.allclose(np.sort(sk_agg.cluster_persistence_),
np.sort(cuml_agg.cluster_persistence_), rtol=0.1, atol=0.1)


def test_hdbscan_plots():
    """Check the plotting-tree attributes exposed by a fitted HDBSCAN.

    The estimator is fitted twice on the same blob data:

    * with ``gen_min_span_tree=True``, ``condensed_tree_``,
      ``minimum_spanning_tree_`` and ``single_linkage_tree_`` must all be
      non-``None``;
    * with ``gen_min_span_tree=False``, ``minimum_spanning_tree_`` must
      stay ``None``.
    """
    n_samples, n_features, n_centers = 100, 100, 10
    data, _ = make_blobs(n_samples,
                         n_features,
                         n_centers,
                         cluster_std=0.7,
                         shuffle=False,
                         random_state=42)

    # Minimum spanning tree requested: every tree attribute should exist.
    with_mst = HDBSCAN(gen_min_span_tree=True)
    with_mst.fit(data)

    for tree_attr in ("condensed_tree_",
                      "minimum_spanning_tree_",
                      "single_linkage_tree_"):
        assert getattr(with_mst, tree_attr) is not None

    # Minimum spanning tree not requested: the attribute must remain unset.
    without_mst = HDBSCAN(gen_min_span_tree=False)
    without_mst.fit(data)

    assert without_mst.minimum_spanning_tree_ is None

0 comments on commit edecd3b

Please sign in to comment.