diff --git a/cpp/bench/sg/umap.cu b/cpp/bench/sg/umap.cu index 73366639e8..dd4a4a7dca 100644 --- a/cpp/bench/sg/umap.cu +++ b/cpp/bench/sg/umap.cu @@ -163,8 +163,6 @@ class UmapTransform : public UmapBase { this->data.X.data(), this->params.nrows, this->params.ncols, - nullptr, - nullptr, this->data.X.data(), this->params.nrows, embeddings, diff --git a/cpp/include/cuml/manifold/common.hpp b/cpp/include/cuml/manifold/common.hpp index 78d69987fc..3346f9127e 100644 --- a/cpp/include/cuml/manifold/common.hpp +++ b/cpp/include/cuml/manifold/common.hpp @@ -104,16 +104,10 @@ struct manifold_sparse_inputs_t : public manifold_inputs_t { * @tparam value_t */ template -struct manifold_precomputed_knn_inputs_t : public manifold_dense_inputs_t { - manifold_precomputed_knn_inputs_t(value_idx* knn_indices_, - value_t* knn_dists_, - value_t* X_, - value_t* y_, - int n_, - int d_, - int n_neighbors_) - : manifold_dense_inputs_t(X_, y_, n_, d_), - knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_) +struct manifold_precomputed_knn_inputs_t : public manifold_inputs_t { + manifold_precomputed_knn_inputs_t( + value_idx* knn_indices_, value_t* knn_dists_, value_t* y_, int n_, int d_, int n_neighbors_) + : manifold_inputs_t(y_, n_, d_), knn_graph(n_, n_neighbors_, knn_indices_, knn_dists_) { } diff --git a/cpp/include/cuml/manifold/umap.hpp b/cpp/include/cuml/manifold/umap.hpp index bdc704460e..a160464577 100644 --- a/cpp/include/cuml/manifold/umap.hpp +++ b/cpp/include/cuml/manifold/umap.hpp @@ -119,6 +119,8 @@ void fit(const raft::handle_t& handle, * @param[in] y: pointer to labels array * @param[in] n: n_samples of input array * @param[in] d: n_features of input array + * @param[in] knn_indices: pointer to knn_indices of input (optional) + * @param[in] knn_dists: pointer to knn_dists of input (optional) * @param[in] params: pointer to ML::UMAPParams object * @param[out] embeddings: pointer to embedding produced through projection * @param[out] graph: pointer to fuzzy simplicial set graph @@ -131,6 +133,8 @@ void fit_sparse(const raft::handle_t& handle, float* y, int n, int d, + int* knn_indices, + float* knn_dists, UMAPParams* params, float* embeddings, raft::sparse::COO* graph); @@ -142,8 +146,6 @@ void fit_sparse(const raft::handle_t& handle, * @param[in] X: pointer to input array to be infered * @param[in] n: n_samples of input array to be infered * @param[in] d: n_features of input array to be infered - * @param[in] knn_indices: pointer to knn_indices of input (optional) - * @param[in] knn_dists: pointer to knn_dists of input (optional) * @param[in] orig_X: pointer to original training array * @param[in] orig_n: number of rows in original training array * @param[in] embedding: pointer to embedding created during training @@ -155,8 +157,6 @@ void transform(const raft::handle_t& handle, float* X, int n, int d, - int64_t* knn_indices, - float* knn_dists, float* orig_X, int orig_n, float* embedding, diff --git a/cpp/src/umap/umap.cu b/cpp/src/umap/umap.cu index 9a653772c3..b9281b0d87 100644 --- a/cpp/src/umap/umap.cu +++ b/cpp/src/umap/umap.cu @@ -50,7 +50,7 @@ std::unique_ptr> get_graph( CUML_LOG_DEBUG("Calling UMAP::get_graph() with precomputed KNN"); manifold_precomputed_knn_inputs_t inputs( - knn_indices, knn_dists, X, y, n, d, params->n_neighbors); + knn_indices, knn_dists, y, n, d, params->n_neighbors); if (y != nullptr) { UMAPAlgo::_get_graph_supervised inputs( - knn_indices, knn_dists, X, y, n, d, params->n_neighbors); + knn_indices, knn_dists, y, n, d, params->n_neighbors); if (y != nullptr) { UMAPAlgo::_fit_supervised* graph) { - manifold_sparse_inputs_t inputs(indptr, indices, data, y, nnz, n, d); - if (y != nullptr) { - UMAPAlgo:: - _fit_supervised, TPB_X>( - handle, inputs, params, embeddings, graph); + if (knn_indices != nullptr && knn_dists != nullptr) { + manifold_precomputed_knn_inputs_t inputs( + knn_indices, knn_dists, y, n, d, params->n_neighbors); + if (y != nullptr) { + UMAPAlgo::_fit_supervised, + TPB_X>(handle, inputs, params, embeddings, graph); + } else { + UMAPAlgo::_fit, + TPB_X>(handle, inputs, params, embeddings, graph); + } } else { - UMAPAlgo::_fit, TPB_X>( - handle, inputs, params, embeddings, graph); + manifold_sparse_inputs_t inputs(indptr, indices, data, y, nnz, n, d); + if (y != nullptr) { + UMAPAlgo::_fit_supervised, + TPB_X>(handle, inputs, params, embeddings, graph); + } else { + UMAPAlgo::_fit, + TPB_X>(handle, inputs, params, embeddings, graph); + } } } @@ -158,8 +179,6 @@ void transform(const raft::handle_t& handle, float* X, int n, int d, - knn_indices_dense_t* knn_indices, - float* knn_dists, float* orig_X, int orig_n, float* embedding, @@ -167,20 +186,10 @@ void transform(const raft::handle_t& handle, UMAPParams* params, float* transformed) { - if (knn_indices != nullptr && knn_dists != nullptr) { - manifold_precomputed_knn_inputs_t inputs( - knn_indices, knn_dists, X, nullptr, n, d, params->n_neighbors); - UMAPAlgo::_transform, - TPB_X>( - handle, inputs, inputs, embedding, embedding_n, params, transformed); - } else { - manifold_dense_inputs_t inputs(X, nullptr, n, d); - manifold_dense_inputs_t orig_inputs(orig_X, nullptr, orig_n, d); - UMAPAlgo::_transform, TPB_X>( - handle, inputs, orig_inputs, embedding, embedding_n, params, transformed); - } + manifold_dense_inputs_t inputs(X, nullptr, n, d); + manifold_dense_inputs_t orig_inputs(orig_X, nullptr, orig_n, d); + UMAPAlgo::_transform, TPB_X>( + handle, inputs, orig_inputs, embedding, embedding_n, params, transformed); } void transform_sparse(const raft::handle_t& handle, diff --git a/cpp/test/sg/umap_parametrizable_test.cu b/cpp/test/sg/umap_parametrizable_test.cu index bd73273141..8a259f4c4e 100644 --- a/cpp/test/sg/umap_parametrizable_test.cu +++ b/cpp/test/sg/umap_parametrizable_test.cu @@ -225,8 +225,6 @@ class UMAPParametrizableTest : public ::testing::Test { X, n_samples, umap_params.n_components, - knn_indices, - knn_dists, X, n_samples, model_embedding, diff --git a/python/cuml/common/sparsefuncs.py b/python/cuml/common/sparsefuncs.py index dd0e7bbaba..019821e174 100644 --- a/python/cuml/common/sparsefuncs.py +++ b/python/cuml/common/sparsefuncs.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from cuml.internals.safe_imports import gpu_only_import_from -from cuml.common.kernel_utils import cuda_kernel_factory -import cuml.internals -from cuml.internals.import_utils import has_scipy -from cuml.internals.memory_utils import with_cupy_rmm -from cuml.internals.input_utils import input_to_cuml_array -from cuml.internals.safe_imports import gpu_only_import + import math +import cuml +from cuml.internals.input_utils import input_to_cuml_array, input_to_cupy_array +from cuml.internals.memory_utils import with_cupy_rmm +from cuml.internals.import_utils import has_scipy +from cuml.common.kernel_utils import cuda_kernel_factory from cuml.internals.safe_imports import cpu_only_import +from cuml.internals.safe_imports import gpu_only_import +from cuml.internals.safe_imports import gpu_only_import_from np = cpu_only_import('numpy') cp = gpu_only_import('cupy') cupyx = gpu_only_import('cupyx') @@ -30,6 +31,15 @@ cp_csc_matrix = gpu_only_import_from('cupyx.scipy.sparse', 'csc_matrix') +if has_scipy(): + from scipy.sparse import csr_matrix, coo_matrix, csc_matrix +else: + from cuml.common.import_utils import DummyClass + csr_matrix = DummyClass + coo_matrix = DummyClass + csc_matrix = DummyClass + + def _map_l1_norm_kernel(dtype): """Creates cupy RawKernel for csr_raw_normalize_l1 function.""" @@ -180,20 +190,12 @@ def _insert_zeros(ary, zero_indices): @with_cupy_rmm -def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False): +def extract_knn_graph(knn_graph): """ Converts KNN graph from CSR, COO and CSC formats into separate distance and indice arrays. Input can be a cupy sparse graph (device) or a numpy sparse graph (host). """ - if has_scipy(): - from scipy.sparse import csr_matrix, coo_matrix, csc_matrix - else: - from cuml.internals.import_utils import DummyClass - csr_matrix = DummyClass - coo_matrix = DummyClass - csc_matrix = DummyClass - if isinstance(knn_graph, (csc_matrix, cp_csc_matrix)): knn_graph = cupyx.scipy.sparse.csr_matrix(knn_graph) n_samples = knn_graph.shape[0] @@ -212,25 +214,90 @@ def extract_knn_graph(knn_graph, convert_dtype=True, sparse=False): knn_indices = knn_graph.col if knn_indices is not None: - convert_to_dtype = None - if convert_dtype: - convert_to_dtype = np.int32 if sparse else np.int64 - knn_dists = knn_graph.data + return knn_indices, knn_dists + else: + return None + + +@with_cupy_rmm +def extract_pairwise_dists(pw_dists, n_neighbors): + """ + Extract the nearest neighbors distances and indices + from a pairwise distance matrix. + + Parameters + ---------- + pw_dists: paiwise distances matrix of shape (n_samples, n_samples) + n_neighbors: number of nearest neighbors + + (inspired from Scikit-Learn code) + """ + pw_dists, _, _, _ = input_to_cupy_array(pw_dists) + + n_rows = pw_dists.shape[0] + sample_range = cp.arange(n_rows)[:, None] + knn_indices = cp.argpartition(pw_dists, n_neighbors - 1, axis=1) + knn_indices = knn_indices[:, :n_neighbors] + argdist = cp.argsort(pw_dists[sample_range, knn_indices]) + knn_indices = knn_indices[sample_range, argdist] + knn_dists = pw_dists[sample_range, knn_indices] + return knn_indices, knn_dists + + +@with_cupy_rmm +def extract_knn_infos(knn_info, n_neighbors): + """ + Extract the nearest neighbors distances and indices + from the knn_info parameter. + + Parameters + ---------- + knn_info : array / sparse array / tuple, optional (device or host) + Either one of : + - Tuple (indices, distances) of arrays of + shape (n_samples, n_neighbors) + - Pairwise distances dense array of shape (n_samples, n_samples) + - KNN graph sparse array (preferably CSR/COO) + n_neighbors: number of nearest neighbors + """ + if knn_info is None: + # no KNN was provided + return None + + deepcopy = False + if isinstance(knn_info, tuple): + # dists and indices provided as a tuple + results = knn_info + else: + isaKNNGraph = isinstance(knn_info, (csr_matrix, coo_matrix, csc_matrix, + cp_csr_matrix, cp_coo_matrix, + cp_csc_matrix)) + if isaKNNGraph: + # extract dists and indices from a KNN graph + deepcopy = True + results = extract_knn_graph(knn_info) + else: + # extract dists and indices from a pairwise distance matrix + results = extract_pairwise_dists(knn_info, n_neighbors) + + if results is not None: + knn_indices, knn_dists = results + knn_indices_m, _, _, _ = \ - input_to_cuml_array(knn_indices, order='C', - deepcopy=True, - check_dtype=(np.int64, np.int32), - convert_to_dtype=convert_to_dtype) + input_to_cuml_array(knn_indices.flatten(), + order='C', + deepcopy=deepcopy, + check_dtype=np.int64, + convert_to_dtype=np.int64) knn_dists_m, _, _, _ = \ - input_to_cuml_array(knn_dists, order='C', - deepcopy=True, + input_to_cuml_array(knn_dists.flatten(), + order='C', + deepcopy=deepcopy, check_dtype=np.float32, - convert_to_dtype=(np.float32 - if convert_dtype - else None)) + convert_to_dtype=np.float32) - return (knn_indices_m, knn_indices_m.ptr),\ - (knn_dists_m, knn_dists_m.ptr) - return (None, None), (None, None) + return knn_indices_m, knn_dists_m + else: + return None diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index d8e2aefb1f..62316010c4 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -39,7 +39,7 @@ from cuml.common.sparse_utils import is_sparse from cuml.common.doc_utils import generate_docstring from cuml.common import input_to_cuml_array from cuml.internals.mixins import CMajorInputTagMixin -from cuml.common.sparsefuncs import extract_knn_graph +from cuml.common.sparsefuncs import extract_knn_infos from cuml.metrics.distance_type cimport DistanceType rmm = gpu_only_import('rmm') @@ -201,6 +201,14 @@ class TSNE(Base, 'sqeuclidean' metric, the distances will still be squared when True. Note: This argument should likely be set to False for distance metrics other than 'euclidean' and 'l2'. + precomputed_knn : array / sparse array / tuple, optional (device or host) + Either one of a tuple (indices, distances) of + arrays of shape (n_samples, n_neighbors), a pairwise distances + dense array of shape (n_samples, n_samples) or a KNN graph + sparse array (preferably CSR/COO). This feature allows + the precomputation of the KNN outside of TSNE + and also allows the use of a custom distance function. This function + should match the metric used to train the TSNE embeedings. handle : cuml.Handle Specifies the cuml.handle that holds internal CUDA state for computations in this model. Most importantly, this specifies the CUDA @@ -280,6 +288,7 @@ class TSNE(Base, pre_momentum=0.5, post_momentum=0.8, square_distances=True, + precomputed_knn=None, handle=None, output_type=None): @@ -391,6 +400,9 @@ class TSNE(Base, self.sparse_fit = False + self.precomputed_knn = extract_knn_infos(precomputed_knn, + n_neighbors) + @generate_docstring(skip_parameters_heading=True, X='dense_sparse', convert_dtype_cast='np.float32') @@ -400,22 +412,15 @@ class TSNE(Base, Parameters ---------- - knn_graph : sparse array-like (device or host), \ - shape=(n_samples, n_samples) - A sparse array containing the k-nearest neighbors of X, - where the columns are the nearest neighbor indices - for each row and the values are their distances. - Users using the knn_graph parameter provide t-SNE - with their own run of the KNN algorithm. This allows the user - to pick a custom distance function (sometimes useful - on certain datasets) whereas t-SNE uses euclidean by default. - The custom distance function should match the metric used - to train t-SNE embeddings. Storing and reusing a knn_graph - will also provide a speedup to the t-SNE algorithm - when performing a grid search. - Acceptable formats: sparse SciPy ndarray, CuPy device ndarray, - CSR/COO preferred other formats will go through conversion to CSR - + knn_graph : array / sparse array / tuple, optional (device or host) + Either one of a tuple (indices, distances) of + arrays of shape (n_samples, n_neighbors), a pairwise distances + dense array of shape (n_samples, n_samples) or a KNN graph + sparse array (preferably CSR/COO). This feature allows + the precomputation of the KNN outside of TSNE + and also allows the use of a custom distance function. This function + should match the metric used to train the TSNE embeedings. + Takes precedence over the precomputed_knn parameter. """ cdef int n, p cdef handle_t* handle_ = self.handle.getHandle() @@ -450,11 +455,21 @@ class TSNE(Base, "# of datapoints = {}.".format(self.perplexity, n)) self.perplexity = n - (knn_indices_m, knn_indices_ctype), (knn_dists_m, knn_dists_ctype) =\ - extract_knn_graph(knn_graph, convert_dtype, self.sparse_fit) + cdef uintptr_t knn_dists_ptr = 0 + cdef uintptr_t knn_indices_ptr = 0 + if knn_graph is not None or self.precomputed_knn is not None: + if knn_graph is not None: + knn_indices, knn_dists = extract_knn_infos(knn_graph, + self.n_neighbors) + elif self.precomputed_knn is not None: + knn_indices, knn_dists = self.precomputed_knn + + if self.sparse_fit: + knn_indices, _, _, _ = \ + input_to_cuml_array(knn_indices, convert_to_dtype=np.int32) - cdef uintptr_t knn_indices_raw = knn_indices_ctype or 0 - cdef uintptr_t knn_dists_raw = knn_dists_ctype or 0 + knn_dists_ptr = knn_dists.ptr + knn_indices_ptr = knn_indices.ptr # Prepare output embeddings self.embedding_ = CumlArray.zeros( @@ -516,8 +531,8 @@ class TSNE(Base, self.X_m.nnz, n, p, - knn_indices_raw, - knn_dists_raw, + knn_indices_ptr, + knn_dists_ptr, deref(params), &kl_divergence) else: @@ -526,8 +541,8 @@ class TSNE(Base, embed_ptr, n, p, - knn_indices_raw, - knn_dists_raw, + knn_indices_ptr, + knn_dists_ptr, deref(params), &kl_divergence) @@ -669,5 +684,6 @@ class TSNE(Base, "exaggeration_iter", "pre_momentum", "post_momentum", - "square_distances" + "square_distances", + "precomputed_knn" ] diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index bc2b1b24da..8a55487874 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -34,7 +34,7 @@ cuda = gpu_only_import('numba.cuda') from cuml.manifold.umap_utils cimport * from cuml.manifold.umap_utils import GraphHolder, find_ab_params -from cuml.common.sparsefuncs import extract_knn_graph +from cuml.common.sparsefuncs import extract_knn_infos from cuml.internals.safe_imports import gpu_only_import_from cp_csr_matrix = gpu_only_import_from('cupyx.scipy.sparse', 'csr_matrix') cp_coo_matrix = gpu_only_import_from('cupyx.scipy.sparse', 'coo_matrix') @@ -96,6 +96,8 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": float *y, int n, int d, + int * knn_indices, + float * knn_dists, UMAPParams *params, float *embeddings, COO * graph) except + @@ -104,8 +106,6 @@ cdef extern from "cuml/manifold/umap.hpp" namespace "ML::UMAP": float * X, int n, int d, - int64_t * knn_indices, - float * knn_dists, float * orig_X, int orig_n, float * embedding, @@ -221,13 +221,6 @@ class UMAP(UniversalBase, More specific parameters controlling the embedding. If None these values are set automatically as determined by ``min_dist`` and ``spread``. - handle : cuml.Handle - Specifies the cuml.handle that holds internal CUDA state for - computations in this model. Most importantly, this specifies the CUDA - stream that will be used for the model's computations, so users can - run different models concurrently in different streams by creating - handles in several streams. - If it is None, a new one is created. hash_input: bool, optional (default = False) UMAP can hash the training input so that exact embeddings are returned when transform is called on the same data upon @@ -238,6 +231,14 @@ class UMAP(UniversalBase, feature is made optional in the GPU version due to the significant overhead in copying memory to the host for computing the hash. + precomputed_knn : array / sparse array / tuple, optional (device or host) + Either one of a tuple (indices, distances) of + arrays of shape (n_samples, n_neighbors), a pairwise distances + dense array of shape (n_samples, n_samples) or a KNN graph + sparse array (preferably CSR/COO). This feature allows + the precomputation of the KNN outside of UMAP + and also allows the use of a custom distance function. This function + should match the metric used to train the UMAP embeedings. random_state : int, RandomState instance or None, optional (default=None) random_state is the seed used by the random number generator during embedding initialization and during sampling used by the optimizer. @@ -268,6 +269,13 @@ class UMAP(UniversalBase, def on_train_end(self, embeddings): print(embeddings.copy_to_host()) + handle : cuml.Handle + Specifies the cuml.handle that holds internal CUDA state for + computations in this model. Most importantly, this specifies the CUDA + stream that will be used for the model's computations, so users can + run different models concurrently in different streams by creating + handles in several streams. + If it is None, a new one is created. verbose : int or boolean, default=False Sets logging level. It must be one of `cuml.common.logger.level_*`. See :ref:`verbosity-levels` for more info. @@ -325,16 +333,17 @@ class UMAP(UniversalBase, negative_sample_rate=5, transform_queue_size=4.0, init="spectral", - verbose=False, a=None, b=None, target_n_neighbors=-1, target_weight=0.5, target_metric="categorical", - handle=None, hash_input=False, random_state=None, + precomputed_knn=None, callback=None, + handle=None, + verbose=False, output_type=None): super().__init__(handle=handle, @@ -404,6 +413,9 @@ class UMAP(UniversalBase, self._input_hash = None self._small_data = False + self.precomputed_knn = extract_knn_infos(precomputed_knn, + n_neighbors) + def validate_hyperparams(self): if self.min_dist > self.spread: @@ -497,24 +509,15 @@ class UMAP(UniversalBase, Parameters ---------- - knn_graph : sparse array-like (device or host) - shape=(n_samples, n_samples) - A sparse array containing the k-nearest neighbors of X, - where the columns are the nearest neighbor indices - for each row and the values are their distances. - It's important that `k>=n_neighbors`, - so that UMAP can model the neighbors from this graph, - instead of building its own internally. - Users using the knn_graph parameter provide UMAP - with their own run of the KNN algorithm. This allows the user - to pick a custom distance function (sometimes useful - on certain datasets) whereas UMAP uses euclidean by default. - The custom distance function should match the metric used - to train UMAP embeddings. Storing and reusing a knn_graph - will also provide a speedup to the UMAP algorithm - when performing a grid search. - Acceptable formats: sparse SciPy ndarray, CuPy device ndarray, - CSR/COO preferred other formats will go through conversion to CSR + knn_graph : array / sparse array / tuple, optional (device or host) + Either one of a tuple (indices, distances) of + arrays of shape (n_samples, n_neighbors), a pairwise distances + dense array of shape (n_samples, n_samples) or a KNN graph + sparse array (preferably CSR/COO). This feature allows + the precomputation of the KNN outside of UMAP + and also allows the use of a custom distance function. This function + should match the metric used to train the UMAP embeedings. + Takes precedence over the precomputed_knn parameter. """ if len(X.shape) != 2: raise ValueError("data should be two dimensional") @@ -544,11 +547,21 @@ class UMAP(UniversalBase, raise ValueError("There needs to be more than 1 sample to " "build nearest the neighbors graph") - (knn_indices_m, knn_indices_ctype), (knn_dists_m, knn_dists_ctype) =\ - extract_knn_graph(knn_graph, convert_dtype) + cdef uintptr_t knn_dists_ptr = 0 + cdef uintptr_t knn_indices_ptr = 0 + if knn_graph is not None or self.precomputed_knn is not None: + if knn_graph is not None: + knn_indices, knn_dists = extract_knn_infos(knn_graph, + self.n_neighbors) + elif self.precomputed_knn is not None: + knn_indices, knn_dists = self.precomputed_knn + + if self.sparse_fit: + knn_indices, _, _, _ = \ + input_to_cuml_array(knn_indices, convert_to_dtype=np.int32) - cdef uintptr_t knn_indices_raw = knn_indices_ctype or 0 - cdef uintptr_t knn_dists_raw = knn_dists_ctype or 0 + knn_dists_ptr = knn_dists.ptr + knn_indices_ptr = knn_indices.ptr self.n_neighbors = min(self.n_rows, self.n_neighbors) @@ -588,6 +601,8 @@ class UMAP(UniversalBase, y_raw, self.n_rows, self.n_dims, + knn_indices_ptr, + knn_dists_ptr, umap_params, embed_raw, fss_graph.get()) @@ -598,8 +613,8 @@ class UMAP(UniversalBase, y_raw, self.n_rows, self.n_dims, - knn_indices_raw, - knn_dists_raw, + knn_indices_ptr, + knn_dists_ptr, umap_params, embed_raw, fss_graph.get()) @@ -669,7 +684,7 @@ class UMAP(UniversalBase, low-dimensional space.', 'shape': '(n_samples, n_components)'}) @enable_device_interop - def transform(self, X, convert_dtype=True, knn_graph=None) -> CumlArray: + def transform(self, X, convert_dtype=True) -> CumlArray: """ Transform X into the existing embedded space and return that transformed output. @@ -680,28 +695,6 @@ class UMAP(UniversalBase, Specifically, the transform() function is stochastic: https://github.com/lmcinnes/umap/issues/158 - - Parameters - ---------- - knn_graph : sparse array-like (device or host) - shape=(n_samples, n_samples) - A sparse array containing the k-nearest neighbors of X, - where the columns are the nearest neighbor indices - for each row and the values are their distances. - It's important that `k>=n_neighbors`, - so that UMAP can model the neighbors from this graph, - instead of building its own internally. - Users using the knn_graph parameter provide UMAP - with their own run of the KNN algorithm. This allows the user - to pick a custom distance function (sometimes useful - on certain datasets) whereas UMAP uses euclidean by default. - The custom distance function should match the metric used - to train UMAP embeddings. Storing and reusing a knn_graph - will also provide a speedup to the UMAP algorithm - when performing a grid search. - Acceptable formats: sparse SciPy ndarray, CuPy device ndarray, - CSR/COO preferred other formats will go through conversion to CSR - """ if len(X.shape) != 2: raise ValueError("X should be two dimensional") @@ -741,18 +734,11 @@ class UMAP(UniversalBase, del X_m return self.embedding_ - embedding = CumlArray.zeros((X_m.shape[0], - self.n_components), + embedding = CumlArray.zeros((n_rows, self.n_components), order="C", dtype=np.float32, index=index) cdef uintptr_t xformed_ptr = embedding.ptr - (knn_indices_m, knn_indices_ctype), (knn_dists_m, knn_dists_ctype) =\ - extract_knn_graph(knn_graph, convert_dtype) - - cdef uintptr_t knn_indices_raw = knn_indices_ctype or 0 - cdef uintptr_t knn_dists_raw = knn_dists_ctype or 0 - cdef handle_t * handle_ = \ self.handle.getHandle() @@ -767,28 +753,26 @@ class UMAP(UniversalBase, X_m.indices.ptr, X_m.data.ptr, X_m.nnz, - X_m.shape[0], - X_m.shape[1], + n_rows, + n_cols, self._raw_data.indptr.ptr, self._raw_data.indices.ptr, self._raw_data.data.ptr, self._raw_data.nnz, self._raw_data.shape[0], embed_ptr, - self._raw_data.shape[0], + n_rows, umap_params, xformed_ptr) else: transform(handle_[0], X_m.ptr, - X_m.shape[0], - X_m.shape[1], - knn_indices_raw, - knn_dists_raw, + n_rows, + n_cols, self._raw_data.ptr, self._raw_data.shape[0], embed_ptr, - self._raw_data.shape[0], + n_rows, umap_params, xformed_ptr) self.handle.sync() @@ -821,7 +805,8 @@ class UMAP(UniversalBase, "random_state", "callback", "metric", - "metric_kwds" + "metric_kwds", + "precomputed_knn" ] def get_attr_names(self): diff --git a/python/cuml/tests/dask/test_dask_umap.py b/python/cuml/tests/dask/test_dask_umap.py index 425c5c0f15..0a6fdc095d 100644 --- a/python/cuml/tests/dask/test_dask_umap.py +++ b/python/cuml/tests/dask/test_dask_umap.py @@ -130,13 +130,13 @@ def _run_mnmg_test(n_parts, n_rows, sampling_ratio, supervised, trust_diff = loc_umap - dist_umap - return trust_diff <= 0.1 + return trust_diff <= 0.15 @pytest.mark.mg @pytest.mark.parametrize("n_parts", [2, 9]) @pytest.mark.parametrize("n_rows", [100, 500]) -@pytest.mark.parametrize("sampling_ratio", [0.4, 0.9]) +@pytest.mark.parametrize("sampling_ratio", [0.55, 0.9]) @pytest.mark.parametrize("supervised", [True, False]) @pytest.mark.parametrize("dataset", ["digits", "iris"]) @pytest.mark.parametrize("n_neighbors", [10]) diff --git a/python/cuml/tests/test_tsne.py b/python/cuml/tests/test_tsne.py index 2ade330ef5..964b2478eb 100644 --- a/python/cuml/tests/test_tsne.py +++ b/python/cuml/tests/test_tsne.py @@ -13,16 +13,18 @@ # limitations under the License. # +import pytest from sklearn.manifold import TSNE as skTSNE from sklearn import datasets from sklearn.manifold import trustworthiness from sklearn.datasets import make_blobs +from sklearn.neighbors import NearestNeighbors +from cuml.manifold import TSNE from cuml.neighbors import NearestNeighbors as cuKNN +from cuml.metrics import pairwise_distances from cuml.testing.utils import array_equal, stress_param -from cuml.manifold import TSNE -from cuml.internals.safe_imports import gpu_only_import -import pytest from cuml.internals.safe_imports import cpu_only_import +from cuml.internals.safe_imports import gpu_only_import np = cpu_only_import('numpy') scipy = cpu_only_import('scipy') cupyx = gpu_only_import('cupyx') @@ -138,6 +140,41 @@ def test_tsne_knn_parameters(test_datasets, type_knn_graph, method): validate_embedding(X, embed) +@pytest.mark.parametrize('precomputed_type', ['knn_graph', 'tuple', + 'pairwise']) +@pytest.mark.parametrize('sparse_input', [False, True]) +def test_tsne_precomputed_knn(precomputed_type, sparse_input): + data, labels = make_blobs(n_samples=2000, n_features=10, + centers=5, random_state=0) + data = data.astype(np.float32) + + if sparse_input: + sparsification = np.random.choice([0., 1.], + p=[0.1, 0.9], + size=data.shape) + data = np.multiply(data, sparsification) + data = scipy.sparse.csr_matrix(data) + + n_neighbors = DEFAULT_N_NEIGHBORS + + if precomputed_type == 'knn_graph': + nn = NearestNeighbors(n_neighbors=n_neighbors) + nn.fit(data) + precomputed_knn = nn.kneighbors_graph(data, mode="distance") + elif precomputed_type == 'tuple': + nn = NearestNeighbors(n_neighbors=n_neighbors) + nn.fit(data) + precomputed_knn = nn.kneighbors(data, return_distance=True) + precomputed_knn = (precomputed_knn[1], precomputed_knn[0]) + elif precomputed_type == 'pairwise': + precomputed_knn = pairwise_distances(data) + + model = TSNE(n_neighbors=n_neighbors, precomputed_knn=precomputed_knn) + embedding = model.fit_transform(data) + trust = trustworthiness(data, embedding, n_neighbors=n_neighbors) + assert trust >= 0.92 + + @pytest.mark.parametrize('method', ['fft', 'barnes_hut']) def test_tsne(test_datasets, method): """ @@ -348,4 +385,4 @@ def test_tsne_distance_metrics_on_sparse_input(method, metric): assert cu_trust > 0.85 assert nans == 0 - assert array_equal(sk_trust, cu_trust, 0.05, with_sign=True) + assert array_equal(sk_trust, cu_trust, 0.06, with_sign=True) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index 9173570bcd..8faed4af75 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -17,25 +17,25 @@ # Please install UMAP before running the code # use 'conda install -c conda-forge umap-learn' command to install it +import pytest +import copy +import joblib +import umap from sklearn.metrics import adjusted_rand_score from sklearn.manifold import trustworthiness from sklearn.datasets import make_blobs from sklearn.cluster import KMeans +from sklearn.neighbors import NearestNeighbors from sklearn import datasets from cuml.internals import logger -import joblib -from sklearn.neighbors import NearestNeighbors +from cuml.metrics import pairwise_distances from cuml.testing.utils import array_equal, unit_param, \ quality_param, stress_param from cuml.manifold.umap import UMAP as cuUMAP -import cupy as cp -from cuml.internals.safe_imports import gpu_only_import -import copy -import umap -import pytest from cuml.internals.safe_imports import cpu_only_import +from cuml.internals.safe_imports import gpu_only_import np = cpu_only_import('numpy') - +cp = gpu_only_import('cupy') cupyx = gpu_only_import('cupyx') scipy_sparse = cpu_only_import('scipy.sparse') @@ -477,7 +477,7 @@ def compare_exp_decay_params(a=None, b=None, min_dist=0.1, spread=1.0): @pytest.mark.parametrize('n_neighbors', [5, 15]) -def test_umap_knn_parameters(n_neighbors): +def test_umap_knn_graph(n_neighbors): data, labels = datasets.make_blobs( n_samples=2000, n_features=10, centers=5, random_state=0) data = data.astype(np.float32) @@ -494,8 +494,7 @@ def transform_embed(knn_graph=None): init='random', n_neighbors=n_neighbors) model.fit(data, knn_graph=knn_graph, convert_dtype=True) - return model.transform(data, knn_graph=knn_graph, - convert_dtype=True) + return model.transform(data, convert_dtype=True) def test_trustworthiness(embedding): trust = trustworthiness(data, embedding, n_neighbors=n_neighbors) @@ -532,6 +531,41 @@ def test_equality(e1, e2): test_equality(embedding6, embedding7) +@pytest.mark.parametrize('precomputed_type', ['knn_graph', 'tuple', + 'pairwise']) +@pytest.mark.parametrize('sparse_input', [False, True]) +def test_umap_precomputed_knn(precomputed_type, sparse_input): + data, labels = make_blobs(n_samples=2000, n_features=10, + centers=5, random_state=0) + data = data.astype(np.float32) + + if sparse_input: + sparsification = np.random.choice([0., 1.], + p=[0.1, 0.9], + size=data.shape) + data = np.multiply(data, sparsification) + data = scipy_sparse.csr_matrix(data) + + n_neighbors = 8 + + if precomputed_type == 'knn_graph': + nn = NearestNeighbors(n_neighbors=n_neighbors) + nn.fit(data) + precomputed_knn = nn.kneighbors_graph(data, mode="distance") + elif precomputed_type == 'tuple': + nn = NearestNeighbors(n_neighbors=n_neighbors) + nn.fit(data) + precomputed_knn = nn.kneighbors(data, return_distance=True) + precomputed_knn = (precomputed_knn[1], precomputed_knn[0]) + elif precomputed_type == 'pairwise': + precomputed_knn = pairwise_distances(data) + + model = cuUMAP(n_neighbors=n_neighbors, precomputed_knn=precomputed_knn) + embedding = model.fit_transform(data) + trust = trustworthiness(data, embedding, n_neighbors=n_neighbors) + assert trust >= 0.92 + + def correctness_sparse(a, b, atol=0.1, rtol=0.2, threshold=0.95): n_ref_zeros = (a == 0).sum() n_ref_non_zero_elms = a.size - n_ref_zeros