# Patch summary (free-text preamble; everything before the first "diff --git"
# header is commentary, not patch content):
#
# python/cuml/experimental/cluster/hdbscan.pyx
#   * Drops the `gen_condensed_tree` and `gen_single_linkage_tree` constructor
#     keywords and their entries in the parameter-name list.
#     NOTE(review): callers that still pass these keywords will break --
#     confirm that is acceptable for this experimental class.
#   * `condensed_tree_` and `single_linkage_tree_` become properties. Each
#     builds its hdbscan.plots object on first access, stores it in
#     `condensed_tree_obj` / `single_linkage_tree_obj`, and returns the stored
#     value. The stored value is only assigned inside the
#     `has_hdbscan_plots()` branch, so the property returns None otherwise.
#   * `_build_minimum_spanning_tree` is renamed `build_minimum_spanning_tree`,
#     runs only when `gen_min_span_tree_` is set and no tree is stored yet,
#     and now returns `minimum_spanning_tree_`.
#   * Removes the two debug `print` calls around the hdbscan invocation.
#
# python/cuml/test/test_hdbscan.py
#   * Adds `test_hdbscan_plots`, which checks the three tree attributes after
#     fitting with and without `gen_min_span_tree`.
#
# NOTE(review): this patch was recovered from a copy whose newlines and
# indentation had been collapsed. Blank context lines were placed so that each
# hunk matches the line counts in its header; leading indentation was rebuilt
# from Python block structure -- verify against the target tree before applying.
diff --git a/python/cuml/experimental/cluster/hdbscan.pyx b/python/cuml/experimental/cluster/hdbscan.pyx
index 1ba8e40a9c..986dd51f30 100644
--- a/python/cuml/experimental/cluster/hdbscan.pyx
+++ b/python/cuml/experimental/cluster/hdbscan.pyx
@@ -268,8 +268,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
                  cluster_selection_method='eom',
                  allow_single_cluster=False,
                  gen_min_span_tree=False,
-                 gen_condensed_tree=False,
-                 gen_single_linkage_tree=False,
                  handle=None,
                  verbose=False,
                  connectivity='knn',
@@ -308,16 +306,16 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
         self.n_clusters_ = None
         self.n_leaves_ = None
 
-        self._condensed_tree = None
-        self._single_linkage_tree = None
+        self.condensed_tree_obj = None
+        self.single_linkage_tree_obj = None
+        self.minimum_spanning_tree_ = None
 
         self.gen_min_span_tree_ = gen_min_span_tree
-        self.gen_condensed_tree = gen_condensed_tree
-        self.gen_single_linkage_tree = gen_single_linkage_tree
 
-    def _build_condensed_tree(self):
+    @property
+    def condensed_tree_(self):
 
-        if self.gen_condensed_tree:
+        if self.condensed_tree_obj is None:
             raw_tree = np.recarray(shape=(self.condensed_parent_.shape[0],),
                                    formats=[np.intp, np.intp, float, np.intp],
                                    names=('parent', 'child', 'lambda_val',
@@ -329,14 +327,17 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
 
             if has_hdbscan_plots():
                 from hdbscan.plots import CondensedTree
-                self.condensed_tree_ = \
+                self.condensed_tree_obj = \
                     CondensedTree(raw_tree,
                                   self.cluster_selection_epsilon,
                                   self.allow_single_cluster)
 
-    def _build_single_linkage_tree(self):
+        return self.condensed_tree_obj
 
-        if self.gen_single_linkage_tree:
+    @property
+    def single_linkage_tree_(self):
+
+        if self.single_linkage_tree_obj is None:
             with cuml.using_output_type("numpy"):
                 raw_tree = np.column_stack(
                     (self.children_[0, :self.n_leaves_-1],
@@ -348,10 +349,13 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
 
             if has_hdbscan_plots():
                 from hdbscan.plots import SingleLinkageTree
-                self.single_linkage_tree_ = SingleLinkageTree(raw_tree)
+                self.single_linkage_tree_obj = SingleLinkageTree(raw_tree)
+
+        return self.single_linkage_tree_obj
+
+    def build_minimum_spanning_tree(self, X):
 
-    def _build_minimum_spanning_tree(self, X):
-        if self.gen_min_span_tree_:
+        if self.gen_min_span_tree_ and self.minimum_spanning_tree_ is None:
             with cuml.using_output_type("numpy"):
                 raw_tree = np.column_stack((self.mst_src_,
                                             self.mst_dst_,
@@ -363,6 +367,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
                 from hdbscan.plots import MinimumSpanningTree
                 self.minimum_spanning_tree_ = \
                     MinimumSpanningTree(raw_tree, X.to_output("numpy"))
+        return self.minimum_spanning_tree_
 
     def __dealloc__(self):
         delete_hdbscan_output(self)
@@ -499,8 +504,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
         else:
             raise ValueError("'affinity' %s not supported." % self.affinity)
 
-        print("Calling HDBSCAN")
-
         if self.connectivity == 'knn':
             hdbscan(handle_[0],
                     input_ptr,
@@ -515,15 +518,11 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
 
         self.handle.sync()
 
-        print("Done")
-
         self.fit_called_ = True
 
         self._construct_output_attributes()
 
-        self._build_minimum_spanning_tree(X_m)
-        self._build_condensed_tree()
-        self._build_single_linkage_tree()
+        self.build_minimum_spanning_tree(X_m)
 
         return self
 
@@ -553,6 +552,4 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
             "n_neighbors",
             "alpha",
             "gen_min_span_tree",
-            "gen_single_linkage_tree",
-            "gen_condensed_tree"
         ]
diff --git a/python/cuml/test/test_hdbscan.py b/python/cuml/test/test_hdbscan.py
index c394543197..740d888c46 100644
--- a/python/cuml/test/test_hdbscan.py
+++ b/python/cuml/test/test_hdbscan.py
@@ -213,3 +213,25 @@ def test_hdbscan_cluster_patterns(dataset, nrows,
     assert np.allclose(np.sort(sk_agg.cluster_persistence_),
                        np.sort(cuml_agg.cluster_persistence_),
                        rtol=0.1, atol=0.1)
+
+
+def test_hdbscan_plots():
+
+    X, y = make_blobs(int(100),
+                      100,
+                      10,
+                      cluster_std=0.7,
+                      shuffle=False,
+                      random_state=42)
+
+    cuml_agg = HDBSCAN(gen_min_span_tree=True)
+    cuml_agg.fit(X)
+
+    assert cuml_agg.condensed_tree_ is not None
+    assert cuml_agg.minimum_spanning_tree_ is not None
+    assert cuml_agg.single_linkage_tree_ is not None
+
+    cuml_agg = HDBSCAN(gen_min_span_tree=False)
+    cuml_agg.fit(X)
+
+    assert cuml_agg.minimum_spanning_tree_ is None