From bd899e184d63a34e073bc84172a9c16f883b614d Mon Sep 17 00:00:00 2001
From: Divye Gala
Date: Thu, 19 Jan 2023 16:20:34 -0500
Subject: [PATCH] HDBSCAN CPU/GPU Interop (#5137)

Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Victor Lafargue (https://github.com/viclafargue)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: https://github.com/rapidsai/cuml/pull/5137
---
 python/cuml/cluster/hdbscan/hdbscan.pyx    | 128 +++++++++++++++++++--
 python/cuml/cluster/hdbscan/prediction.pyx |  64 ++++++++++-
 python/cuml/tests/test_device_selection.py |  60 +++++++++-
 3 files changed, 240 insertions(+), 12 deletions(-)

diff --git a/python/cuml/cluster/hdbscan/hdbscan.pyx b/python/cuml/cluster/hdbscan/hdbscan.pyx
index 035ca15f22..f41c86ea01 100644
--- a/python/cuml/cluster/hdbscan/hdbscan.pyx
+++ b/python/cuml/cluster/hdbscan/hdbscan.pyx
@@ -1,4 +1,4 @@
-# Copyright (c) 2021-2022, NVIDIA CORPORATION.
+# Copyright (c) 2021-2023, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -28,7 +28,7 @@ cp = gpu_only_import('cupy')
 from warnings import warn
 
 from cuml.internals.array import CumlArray
-from cuml.internals.base import Base
+from cuml.internals.base import UniversalBase
 from cuml.common.doc_utils import generate_docstring
 from pylibraft.common.handle cimport handle_t
 from rmm._lib.device_uvector cimport device_uvector
@@ -36,6 +36,8 @@ from rmm._lib.device_uvector cimport device_uvector
 from pylibraft.common.handle import Handle
 from cuml.common import input_to_cuml_array
 from cuml.common.array_descriptor import CumlArrayDescriptor
+from cuml.internals.api_decorators import device_interop_preparation
+from cuml.internals.api_decorators import enable_device_interop
 from cuml.internals.mixins import ClusterMixin
 from cuml.internals.mixins import CMajorInputTagMixin
 from cuml.internals import logger
@@ -321,7 +323,7 @@ def delete_hdbscan_output(obj):
         del obj.hdbscan_output_
 
 
-class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
+class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin):
     """
     HDBSCAN Clustering
 
@@ -470,7 +472,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin):
         Even then in some optimized cases a tree may not be generated.
 
""" - sk_import_path_ = 'hdbscan' + _cpu_estimator_import_path = 'hdbscan.HDBSCAN' labels_ = CumlArrayDescriptor() probabilities_ = CumlArrayDescriptor() @@ -487,6 +489,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): mst_dst_ = CumlArrayDescriptor() mst_weights_ = CumlArrayDescriptor() + @device_interop_preparation def __init__(self, *, min_cluster_size=5, min_samples=None, @@ -591,9 +594,6 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): @property def prediction_data_(self): - if not self.fit_called_: - raise ValueError( - 'The model is not trained yet (call fit() first).') if not self.prediction_data: raise ValueError( @@ -648,6 +648,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): MinimumSpanningTree(raw_tree, X.to_output("numpy")) return self.minimum_spanning_tree_ + @enable_device_interop def generate_prediction_data(self): """ Create data that caches intermediate results used for predicting @@ -701,6 +702,13 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): self.prediction_data_ptr free(prediction_data_ptr) del self.prediction_data_ptr + # this is only constructed when trying to gpu predict + # with a cpu model + if hasattr(self, "inverse_label_map_ptr"): + inverse_label_map_ptr = \ + self.inverse_label_map_ptr + free(inverse_label_map_ptr) + del self.inverse_label_map_ptr def _construct_output_attributes(self): @@ -746,6 +754,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): self.inverse_label_map = CumlArray.empty((0,), dtype="int32") @generate_docstring() + @enable_device_interop def fit(self, X, y=None, convert_dtype=True) -> "HDBSCAN": """ Fit HDBSCAN model from features. @@ -864,6 +873,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): 'type': 'dense', 'description': 'Cluster indexes', 'shape': '(n_samples, 1)'}) + @enable_device_interop def fit_predict(self, X, y=None) -> CumlArray: """ Fit the HDBSCAN model from features and return @@ -997,6 +1007,108 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): self.__dict__.update(state) + def _prep_cpu_to_gpu_prediction(self, convert_dtype=True): + """ + This is an internal function, to be called when HDBSCAN + is trained on CPU but GPU inference is desired. + """ + if not self.prediction_data: + raise ValueError("PredictionData not generated. 
" + "Please call clusterer.fit again with " + "prediction_data=True or call " + "clusterer.generate_prediction_data()") + + if self._cpu_to_gpu_interop_prepped: + return + + cdef handle_t* handle_ = self.handle.getHandle() + + self.X_m, self.n_rows, self.n_cols, dtype = \ + input_to_cuml_array(self._cpu_model._raw_data, order='C', + check_dtype=[np.float32], + convert_to_dtype=(np.float32 + if convert_dtype + else None)) + + self.condensed_parent_, n_edges, _, _ = \ + input_to_cuml_array(self.condensed_tree_.to_numpy()['parent'], + order='C', + convert_to_dtype=np.int32) + + self.condensed_child_, _, _, _ = \ + input_to_cuml_array(self.condensed_tree_.to_numpy()['child'], + order='C', + convert_to_dtype=np.int32) + + self.condensed_lambdas_, _, _, _ = \ + input_to_cuml_array(self.condensed_tree_.to_numpy()['lambda_val'], + order='C', + convert_to_dtype=np.float32) + + self.condensed_sizes_, _, _, _ = \ + input_to_cuml_array(self.condensed_tree_.to_numpy()['child_size'], + order='C', + convert_to_dtype=np.int32) + + cdef uintptr_t parent_ptr = self.condensed_parent_.ptr + cdef uintptr_t child_ptr = self.condensed_child_.ptr + cdef uintptr_t lambdas_ptr = self.condensed_lambdas_.ptr + cdef uintptr_t sizes_ptr = self.condensed_sizes_.ptr + + cdef CondensedHierarchy[int, float] *condensed_tree = \ + new CondensedHierarchy[int, float]( + handle_[0], self.n_rows, n_edges, + parent_ptr, child_ptr, + lambdas_ptr, sizes_ptr) + self.condensed_tree_ptr = condensed_tree + + self.core_dists = CumlArray.empty(self.n_rows, dtype="float32") + metric = _metrics_mapping[self.metric] + + cdef uintptr_t X_ptr = self.X_m.ptr + cdef uintptr_t core_dists_ptr = self.core_dists.ptr + + compute_core_dists(handle_[0], + X_ptr, + core_dists_ptr, + self.n_rows, + self.n_cols, + metric, + self.min_samples) + + cdef device_uvector[int] *inverse_label_map = \ + new device_uvector[int](0, handle_[0].get_stream()) + + cdef CLUSTER_SELECTION_METHOD cluster_selection_method + if self.cluster_selection_method == 'eom': + cluster_selection_method = CLUSTER_SELECTION_METHOD.EOM + elif self.cluster_selection_method == 'leaf': + cluster_selection_method = CLUSTER_SELECTION_METHOD.LEAF + + compute_inverse_label_map(handle_[0], + deref(condensed_tree), + self.n_rows, + + cluster_selection_method, + deref(inverse_label_map), + self.allow_single_cluster, + self.max_cluster_size, + self.cluster_selection_epsilon) + + self.n_clusters_ = inverse_label_map[0].size() + self.inverse_label_map_ptr = inverse_label_map[0].data() + self.inverse_label_map = \ + _cuml_array_from_ptr(self.inverse_label_map_ptr, + self.n_clusters_ * sizeof(int), + (self.n_clusters_, ), "int32", self) + + self.fit_called_ = True + self.generate_prediction_data() + + self.handle.sync() + + self._cpu_to_gpu_interop_prepped = True + def get_param_names(self): return super().get_param_names() + [ "metric", @@ -1013,7 +1125,7 @@ class HDBSCAN(Base, ClusterMixin, CMajorInputTagMixin): "prediction_data" ] - def get_attributes_names(self): + def get_attr_names(self): attr_names = ['labels_', 'probabilities_', 'cluster_persistence_', 'condensed_tree_', 'single_linkage_tree_', 'outlier_scores_'] diff --git a/python/cuml/cluster/hdbscan/prediction.pyx b/python/cuml/cluster/hdbscan/prediction.pyx index d5e829e2ad..63f2dadb45 100644 --- a/python/cuml/cluster/hdbscan/prediction.pyx +++ b/python/cuml/cluster/hdbscan/prediction.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. 
+# Copyright (c) 2021-2023, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,12 +30,18 @@ from cuml.common.doc_utils import generate_docstring
 from pylibraft.common.handle cimport handle_t
 from pylibraft.common.handle import Handle
-from cuml.common import input_to_cuml_array
+from cuml.common import (
+    input_to_cuml_array,
+    input_to_host_array
+)
 from cuml.common.array_descriptor import CumlArrayDescriptor
+from cuml.internals.available_devices import is_cuda_available
+from cuml.internals.device_type import DeviceType
 from cuml.internals.mixins import ClusterMixin
 from cuml.internals.mixins import CMajorInputTagMixin
 from cuml.internals import logger
 from cuml.internals.import_utils import has_hdbscan_plots
+from cuml.internals.import_utils import has_hdbscan_prediction
 
 import cuml
 
 from cuml.metrics.distance_type cimport DistanceType
@@ -134,6 +140,34 @@ def all_points_membership_vectors(clusterer):
         cluster ``j`` is in ``membership_vectors[i, j]``.
 
     """
+    device_type = cuml.global_settings.device_type
+
+    # cpu infer, cpu/gpu train
+    if device_type == DeviceType.host:
+        assert has_hdbscan_prediction()
+        from hdbscan.prediction import all_points_membership_vectors \
+            as cpu_all_points_membership_vectors
+
+        # trained on gpu
+        if not hasattr(clusterer, "_cpu_model"):
+            # the reference HDBSCAN implementation uses @property
+            # for attributes without setters available for them,
+            # so they can't be transferred from the GPU model
+            # to the CPU model
+            raise ValueError("Inferring on CPU is not supported yet when the "
+                             "model has been trained on GPU")
+
+        # this took a long debugging session to figure out, but
+        # this method on cpu does not work without this copy for some reason
+        clusterer._cpu_model.prediction_data_.raw_data = \
+            clusterer._cpu_model.prediction_data_.raw_data.copy()
+        return cpu_all_points_membership_vectors(clusterer._cpu_model)
+
+    elif device_type == DeviceType.device:
+        # trained on cpu
+        if hasattr(clusterer, "_cpu_model"):
+            clusterer._prep_cpu_to_gpu_prediction()
+
     if not clusterer.fit_called_:
         raise ValueError("The clusterer is not fit on data. "
                          "Please call clusterer.fit first")
@@ -209,6 +243,32 @@ def approximate_predict(clusterer, points_to_predict, convert_dtype=True):
         The soft cluster scores for each of the ``points_to_predict``
 
     """
+    device_type = cuml.global_settings.device_type
+
+    # cpu infer, cpu/gpu train
+    if device_type == DeviceType.host:
+        assert has_hdbscan_prediction()
+        from hdbscan.prediction import approximate_predict \
+            as cpu_approximate_predict
+
+        # trained on gpu
+        if not hasattr(clusterer, "_cpu_model"):
+            # the reference HDBSCAN implementation uses @property
+            # for attributes without setters available for them,
+            # so they can't be transferred from the GPU model
+            # to the CPU model
+            raise ValueError("Inferring on CPU is not supported yet when the "
+                             "model has been trained on GPU")
+
+        host_points_to_predict = input_to_host_array(points_to_predict).array
+        return cpu_approximate_predict(clusterer._cpu_model,
+                                       host_points_to_predict)
+
+    elif device_type == DeviceType.device:
+        # trained on cpu
+        if hasattr(clusterer, "_cpu_model"):
+            clusterer._prep_cpu_to_gpu_prediction()
+
     if not clusterer.fit_called_:
         raise ValueError("The clusterer is not fit on data. "
" "Please call clusterer.fit first") diff --git a/python/cuml/tests/test_device_selection.py b/python/cuml/tests/test_device_selection.py index 69b87a2c87..9599a26b13 100644 --- a/python/cuml/tests/test_device_selection.py +++ b/python/cuml/tests/test_device_selection.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,10 +13,13 @@ # limitations under the License. # - from cuml.testing.test_preproc_utils import to_output_type +from cuml.testing.utils import array_equal + +from cuml.cluster.hdbscan import HDBSCAN from cuml.neighbors import NearestNeighbors from cuml.metrics import trustworthiness +from cuml.metrics import adjusted_rand_score from cuml.manifold import UMAP from cuml.linear_model import ( ElasticNet, @@ -29,6 +32,7 @@ from cuml.internals.mem_type import MemoryType from cuml.decomposition import PCA, TruncatedSVD from cuml.common.device_selection import DeviceType, using_device_type +from hdbscan import HDBSCAN as refHDBSCAN from umap import UMAP as refUMAP from sklearn.neighbors import NearestNeighbors as skNearestNeighbors from sklearn.linear_model import Ridge as skRidge @@ -53,6 +57,21 @@ cudf = gpu_only_import('cudf') +def assert_membership_vectors(cu_vecs, sk_vecs): + """ + Assert the membership vectors by taking the adjusted rand score + of the argsorted membership vectors. + """ + if sk_vecs.shape == cu_vecs.shape: + cu_labels_sorted = np.argsort(cu_vecs)[::-1] + sk_labels_sorted = np.argsort(sk_vecs)[::-1] + + k = min(sk_vecs.shape[1], 10) + for i in range(k): + assert adjusted_rand_score(cu_labels_sorted[:, i], + sk_labels_sorted[:, i]) >= 0.85 + + @pytest.mark.parametrize('input', [('cpu', DeviceType.host), ('gpu', DeviceType.device)]) def test_device_type(input): @@ -793,3 +812,40 @@ def test_nn_methods(train_device, infer_device): ref_output = ref_output.todense() output = output.todense() np.testing.assert_allclose(ref_output, output, rtol=0.15) + + +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +def test_hdbscan_methods(train_device, infer_device): + + if train_device == "gpu" and infer_device == "cpu": + pytest.skip("Can't transfer attributes to cpu for now") + + ref_model = refHDBSCAN(prediction_data=True, + approx_min_span_tree=False, + max_cluster_size=0, + min_cluster_size=30) + ref_trained_labels = ref_model.fit_predict(X_train_blob) + + from hdbscan.prediction import all_points_membership_vectors \ + as cpu_all_points_membership_vectors, approximate_predict \ + as cpu_approximate_predict + ref_membership = cpu_all_points_membership_vectors(ref_model) + ref_labels, ref_probs = cpu_approximate_predict(ref_model, X_test_blob) + + model = HDBSCAN(prediction_data=True, + approx_min_span_tree=False, + max_cluster_size=0, + min_cluster_size=30) + with using_device_type(train_device): + trained_labels = model.fit_predict(X_train_blob) + with using_device_type(infer_device): + from cuml.cluster.hdbscan.prediction import \ + all_points_membership_vectors, approximate_predict + membership = all_points_membership_vectors(model) + labels, probs = approximate_predict(model, X_test_blob) + + assert(adjusted_rand_score(trained_labels, ref_trained_labels) >= 0.95) + assert_membership_vectors(membership, ref_membership) + assert(adjusted_rand_score(labels, ref_labels) >= 0.98) + assert(array_equal(probs, 
ref_probs, unit_tol=0.001, total_tol=0.006))
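
Usage sketch (illustrative only, not part of the patch): the CPU-train / GPU-infer path this PR enables mirrors `test_hdbscan_methods` above. The sketch assumes the reference `hdbscan` package is installed for the CPU side; `X_train`, `X_test`, and the `min_cluster_size` value are placeholders standing in for the test's blob data.

```python
import numpy as np

from cuml.cluster.hdbscan import HDBSCAN
from cuml.cluster.hdbscan.prediction import (
    all_points_membership_vectors,
    approximate_predict,
)
from cuml.common.device_selection import using_device_type

# Placeholder data, standing in for X_train_blob / X_test_blob in the test.
X_train = np.random.random_sample((1000, 10)).astype(np.float32)
X_test = np.random.random_sample((100, 10)).astype(np.float32)

# prediction_data=True is required so that the prediction helpers have the
# cached PredictionData they need.
model = HDBSCAN(prediction_data=True, min_cluster_size=30)

# Train on CPU: fit() dispatches to the reference hdbscan package.
with using_device_type('cpu'):
    model.fit(X_train)

# Infer on GPU: the first prediction call triggers
# HDBSCAN._prep_cpu_to_gpu_prediction() to rebuild the GPU-side state.
with using_device_type('gpu'):
    membership = all_points_membership_vectors(model)
    labels, probs = approximate_predict(model, X_test)

# The reverse direction (train on GPU, infer on CPU) is not supported yet;
# the prediction functions raise a ValueError in that case.
```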