From acd03f1c32bd6f744c174775adfa8c7dbf4e2104 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 15 Feb 2022 18:31:48 -0500 Subject: [PATCH 01/22] Adding support for rbc in 3d --- python/cuml/neighbors/nearest_neighbors.pyx | 6 +++--- python/cuml/test/test_nearest_neighbors.py | 19 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 0790d2cc87..24c2ebdd33 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -366,15 +366,15 @@ class NearestNeighbors(Base, self.n_dims = X.shape[1] if self.algorithm == "auto": - if self.n_dims == 2 and self.metric in \ + if (self.n_dims == 2 or self.n_dims == 3) and self.metric in \ cuml.neighbors.VALID_METRICS["rbc"]: self.working_algorithm_ = "rbc" else: self.working_algorithm_ = "brute" - if self.algorithm == "rbc" and self.n_dims > 2: + if self.algorithm == "rbc" and self.n_dims > 3: raise ValueError("The rbc algorithm is not supported for" - " >2 dimensions currently.") + " >3 dimensions currently.") if is_sparse(X): valid_metrics = cuml.neighbors.VALID_METRICS_SPARSE diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index dfc8711711..d6370902e1 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -516,21 +516,23 @@ def test_knn_graph(input_type, mode, output_type, as_instance, assert isspmatrix_csr(sparse_cu) -@pytest.mark.parametrize('distance', ["euclidean", "haversine"]) +@pytest.mark.parametrize('distance_dims', [("euclidean", 2), ("euclidean", 3), ("haversine", 2)]) @pytest.mark.parametrize('n_neighbors', [4, 25]) @pytest.mark.parametrize('nrows', [unit_param(10000), stress_param(70000)]) -def test_nearest_neighbors_rbc(distance, n_neighbors, nrows): +def test_nearest_neighbors_rbc(distance_dims, n_neighbors, nrows): + distance, dims = distance_dims + X, y = make_blobs(n_samples=nrows, centers=25, shuffle=True, - n_features=2, + n_features=dims, cluster_std=3.0, random_state=42) knn_cu = cuKNN(metric=distance, algorithm="rbc") knn_cu.fit(X) - query_rows = int(nrows/2) + query_rows = int(nrows / 2) rbc_d, rbc_i = knn_cu.kneighbors(X[:query_rows, :], n_neighbors=n_neighbors) @@ -547,14 +549,15 @@ def test_nearest_neighbors_rbc(distance, n_neighbors, nrows): brute_d, brute_i = knn_cu_brute.kneighbors( X[:query_rows, :], n_neighbors=n_neighbors) - rbc_i = cp.sort(rbc_i, axis=1) - brute_i = cp.sort(brute_i, axis=1) + # rbc_i = cp.sort(rbc_i, axis=1) + # brute_i = cp.sort(brute_i, axis=1) # TODO: These are failing with 1 or 2 mismatched elements # for very small values of k: # https://github.com/rapidsai/cuml/issues/4262 - assert len(brute_d[brute_d != rbc_d]) <= 3 - assert len(brute_i[brute_i != rbc_i]) <= 3 + + assert len(brute_d[brute_d != rbc_d]) == 0 + assert len(brute_i[brute_i != rbc_i]) == 0 @pytest.mark.parametrize("metric", valid_metrics_sparse()) From 427dbd0ca7c83240062d8709a42ad1c5d1bd07de Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 15 Feb 2022 18:43:08 -0500 Subject: [PATCH 02/22] Setting raft pin --- cpp/cmake/thirdparty/get_raft.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 63f795d519..14e3551db3 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -57,8 +57,8 @@ set(CUML_BRANCH_VERSION_raft "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}") # To use a different RAFT locally, set the CMake variable # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft} - FORK rapidsai - PINNED_TAG branch-${CUML_BRANCH_VERSION_raft} + FORK cjnolet + PINNED_TAG fea-2204-rbc_3d USE_RAFT_NN ${CUML_USE_RAFT_NN} USE_FAISS_STATIC ${CUML_USE_FAISS_STATIC} ) From 1eec5410110eef8d2933387b771a65a6ece56fa4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 15 Feb 2022 18:43:59 -0500 Subject: [PATCH 03/22] Fixing style --- python/cuml/test/test_nearest_neighbors.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index d6370902e1..dce93f3ff5 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -516,7 +516,9 @@ def test_knn_graph(input_type, mode, output_type, as_instance, assert isspmatrix_csr(sparse_cu) -@pytest.mark.parametrize('distance_dims', [("euclidean", 2), ("euclidean", 3), ("haversine", 2)]) +@pytest.mark.parametrize('distance_dims', [("euclidean", 2), + ("euclidean", 3), + ("haversine", 2)]) @pytest.mark.parametrize('n_neighbors', [4, 25]) @pytest.mark.parametrize('nrows', [unit_param(10000), stress_param(70000)]) def test_nearest_neighbors_rbc(distance_dims, n_neighbors, nrows): From ea8f9fe0f3023421279843117d9c3a0056131e6b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 15 Feb 2022 18:45:24 -0500 Subject: [PATCH 04/22] Removing unecessary code --- python/cuml/test/test_nearest_neighbors.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index dce93f3ff5..ca5dd45164 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -551,13 +551,6 @@ def test_nearest_neighbors_rbc(distance_dims, n_neighbors, nrows): brute_d, brute_i = knn_cu_brute.kneighbors( X[:query_rows, :], n_neighbors=n_neighbors) - # rbc_i = cp.sort(rbc_i, axis=1) - # brute_i = cp.sort(brute_i, axis=1) - - # TODO: These are failing with 1 or 2 mismatched elements - # for very small values of k: - # https://github.com/rapidsai/cuml/issues/4262 - assert len(brute_d[brute_d != rbc_d]) == 0 assert len(brute_i[brute_i != rbc_i]) == 0 From d860150fd1fd75fb05752b954ae2da675cb3dbc6 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 16 Feb 2022 16:03:51 -0500 Subject: [PATCH 05/22] Updating copyright --- python/cuml/neighbors/nearest_neighbors.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 24c2ebdd33..732e26fae7 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 294d815e6e7c00c2d224c4b9a40b1205d7404df3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 23 Feb 2022 21:36:41 -0500 Subject: [PATCH 06/22] Using static linking when internal raft clone is used --- cpp/cmake/thirdparty/get_raft.cmake | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 3b289662e4..653344df59 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -23,9 +23,11 @@ function(find_and_configure_raft) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) + set(STATIC_LINK_LIBRARIES OFF) if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "branch-${CUML_BRANCH_VERSION_raft}") message("Pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.") set(CPM_DOWNLOAD_raft ON) + set(STATIC_LINK_LIBRARIES ON) endif() string(APPEND RAFT_COMPONENTS "distance") @@ -45,6 +47,7 @@ function(find_and_configure_raft) SOURCE_SUBDIR cpp FIND_PACKAGE_ARGUMENTS "COMPONENTS ${RAFT_COMPONENTS}" OPTIONS + "RAFT_STATIC_LINK_LIBRARIES ${STATIC_LINK_LIBRARIES}" "BUILD_TESTS OFF" "RAFT_USE_FAISS_STATIC ${PKG_USE_FAISS_STATIC}" "NVTX ${NVTX}" From 0191abe8e7f52f6b72739d9646950fa2ad95df7b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 23 Feb 2022 21:39:27 -0500 Subject: [PATCH 07/22] RAFT qualify environment var --- cpp/cmake/thirdparty/get_raft.cmake | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 653344df59..de455e14fe 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -23,11 +23,11 @@ function(find_and_configure_raft) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - set(STATIC_LINK_LIBRARIES OFF) + set(RAFT_STATIC_LINK_LIBRARIES OFF) if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "branch-${CUML_BRANCH_VERSION_raft}") message("Pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.") set(CPM_DOWNLOAD_raft ON) - set(STATIC_LINK_LIBRARIES ON) + set(RAFT_STATIC_LINK_LIBRARIES ON) endif() string(APPEND RAFT_COMPONENTS "distance") @@ -47,7 +47,7 @@ function(find_and_configure_raft) SOURCE_SUBDIR cpp FIND_PACKAGE_ARGUMENTS "COMPONENTS ${RAFT_COMPONENTS}" OPTIONS - "RAFT_STATIC_LINK_LIBRARIES ${STATIC_LINK_LIBRARIES}" + "RAFT_STATIC_LINK_LIBRARIES ${RAFT_STATIC_LINK_LIBRARIES}" "BUILD_TESTS OFF" "RAFT_USE_FAISS_STATIC ${PKG_USE_FAISS_STATIC}" "NVTX ${NVTX}" From ad91e411b5074c2af9e30f662cf367078f962cf5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 23 Feb 2022 23:19:56 -0500 Subject: [PATCH 08/22] Adding -fPIC to build flags --- cpp/cmake/thirdparty/get_raft.cmake | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index de455e14fe..4dc34e6158 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -27,6 +27,10 @@ function(find_and_configure_raft) if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "branch-${CUML_BRANCH_VERSION_raft}") message("Pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.") set(CPM_DOWNLOAD_raft ON) + set(RAFT_CXX_FLAGS ${RAFT_CXX_FLAGS} -fPIC) + set(RAFT_CUDA_FLAGS ${RAFT_CUDA_FLAGS} -fPIC) + set(CUML_CXX_FLAGS ${CUML_CXX_FLAGS} -fPIC) + set(CUML_CUDA_FLAGS ${CUML_CUDA_FLAGS} -fPIC) set(RAFT_STATIC_LINK_LIBRARIES ON) endif() From 7b92b2543abb0a3ef30c3c298129d0d5a01d2cc8 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 25 Feb 2022 23:18:21 -0500 Subject: [PATCH 09/22] Updating pin --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 4dc34e6158..77a56aed20 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -71,7 +71,7 @@ endfunction() # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft} FORK cjnolet - PINNED_TAG fea-2204-rbc_3d + PINNED_TAG fea-2204-rbc_3d_2 # When PINNED_TAG above doesn't match cuml, # force local raft clone in build directory From f8c8406504fdb29e9f234005e444ae70b48c1848 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 26 Feb 2022 11:44:07 -0500 Subject: [PATCH 10/22] Fixing destructor of neighbors cython so subclasses can call it. --- python/cuml/neighbors/nearest_neighbors.pyx | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index ea7219976d..9d098a44af 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -905,12 +905,15 @@ class NearestNeighbors(Base, def __del__(self): cdef knnIndex* knn_index = 0 cdef BallCoverIndex* rbc_index = 0 - if self.knn_index is not None: + + kidx = self.__dict__['knn_index'] \ + if 'knn_index' in self.__dict__ else None + if kidx is not None: if self.working_algorithm_ in ["ivfflat", "ivfpq", "ivfsq"]: - knn_index = self.knn_index + knn_index = kidx del knn_index else: - rbc_index = self.knn_index + rbc_index = kidx del rbc_index From e9aed72f07650bd7df3d936d3e61d59a880df265 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 26 Feb 2022 20:16:41 -0500 Subject: [PATCH 11/22] Updating knn regressor and classifier --- python/cuml/test/test_kneighbors_classifier.py | 4 ++-- python/cuml/test/test_kneighbors_regressor.py | 2 +- python/cuml/test/test_umap.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/cuml/test/test_kneighbors_classifier.py b/python/cuml/test/test_kneighbors_classifier.py index 6117a3e958..1033a93dd5 100644 --- a/python/cuml/test/test_kneighbors_classifier.py +++ b/python/cuml/test/test_kneighbors_classifier.py @@ -271,7 +271,7 @@ def test_nonmonotonic_labels(n_classes, n_rows, n_cols, @pytest.mark.parametrize("output_type", ["cudf", "numpy", "cupy"]) def test_predict_multioutput(input_type, output_type): - X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) + X = np.array([[0, 0, 1, 0], [1, 0, 1, 0]]).astype(np.float32) y = np.array([[15, 2], [5, 4]]).astype(np.int32) if input_type == "cudf": @@ -300,7 +300,7 @@ def test_predict_multioutput(input_type, output_type): @pytest.mark.parametrize("output_type", ["cudf", "numpy", "cupy"]) def test_predict_proba_multioutput(input_type, output_type): - X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) + X = np.array([[0, 0, 1, 0], [1, 0, 1, 0]]).astype(np.float32) y = np.array([[15, 2], [5, 4]]).astype(np.int32) if input_type == "cudf": diff --git a/python/cuml/test/test_kneighbors_regressor.py b/python/cuml/test/test_kneighbors_regressor.py index 53ffdd8c1a..c192d4052b 100644 --- a/python/cuml/test/test_kneighbors_regressor.py +++ b/python/cuml/test/test_kneighbors_regressor.py @@ -125,7 +125,7 @@ def test_score_dtype(dtype): @pytest.mark.parametrize("output_type", ["cudf", "numpy", "cupy"]) def test_predict_multioutput(input_type, output_type): - X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) + X = np.array([[0, 0, 1, 0], [1, 0, 1, 0]]).astype(np.float32) y = np.array([[15, 2], [5, 4]]).astype(np.int32) if input_type == "cudf": diff --git a/python/cuml/test/test_umap.py b/python/cuml/test/test_umap.py index f706046282..9974ddea91 100644 --- a/python/cuml/test/test_umap.py +++ b/python/cuml/test/test_umap.py @@ -59,9 +59,9 @@ def test_blobs_cluster(nrows, n_feats): assert score == 1.0 -@pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000), +@pytest.mark.parametrize('nrows', [unit_param(16384), quality_param(5000), stress_param(500000)]) -@pytest.mark.parametrize('n_feats', [unit_param(10), quality_param(100), +@pytest.mark.parametrize('n_feats', [unit_param(1000), quality_param(100), stress_param(1000)]) def test_umap_fit_transform_score(nrows, n_feats): @@ -71,23 +71,23 @@ def test_umap_fit_transform_score(nrows, n_feats): data, labels = make_blobs(n_samples=n_samples, n_features=n_features, centers=10, random_state=42) - model = umap.UMAP(n_neighbors=10, min_dist=0.1) +# model = umap.UMAP(n_neighbors=10, min_dist=0.1) cuml_model = cuUMAP(n_neighbors=10, min_dist=0.01) - embedding = model.fit_transform(data) +# embedding = model.fit_transform(data) cuml_embedding = cuml_model.fit_transform(data, convert_dtype=True) - assert not np.isnan(embedding).any() +# assert not np.isnan(embedding).any() assert not np.isnan(cuml_embedding).any() if nrows < 500000: cuml_score = adjusted_rand_score(labels, KMeans(10).fit_predict( cuml_embedding)) - score = adjusted_rand_score(labels, - KMeans(10).fit_predict(embedding)) +# score = adjusted_rand_score(labels, +# KMeans(10).fit_predict(embedding)) - assert array_equal(score, cuml_score, 1e-2, with_sign=True) +# assert array_equal(score, cuml_score, 1e-2, with_sign=True) def test_supervised_umap_trustworthiness_on_iris(): From afab91c399d92d57ef492a8dfd62507892fe74bb Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sun, 27 Feb 2022 09:54:27 -0500 Subject: [PATCH 12/22] nearest neighbors python to fall back to brute force when rbc shouldn't be used --- python/cuml/neighbors/nearest_neighbors.pyx | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 9d098a44af..c6f34e07f5 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -24,6 +24,7 @@ import cupyx import cudf import ctypes import warnings +import math import cuml.internals from cuml.common.base import Base @@ -366,8 +367,10 @@ class NearestNeighbors(Base, self.n_dims = X.shape[1] if self.algorithm == "auto": - if (self.n_dims == 2 or self.n_dims == 3) and self.metric in \ - cuml.neighbors.VALID_METRICS["rbc"]: + if (self.n_dims == 2 or self.n_dims == 3) and \ + not is_sparse(X) and \ + self.metric in cuml.neighbors.VALID_METRICS["rbc"] and \ + math.sqrt(X.shape[0]) >= self.n_neighbors: self.working_algorithm_ = "rbc" else: self.working_algorithm_ = "brute" @@ -722,7 +725,15 @@ class NearestNeighbors(Base, cdef BallCoverIndex[int64_t, float, uint32_t]* rbc_index = \ 0 - if self.working_algorithm_ == 'brute': + fallback_to_brute = self.working_algorithm_ == "rbc" and \ + n_neighbors > math.sqrt(self.X_m.shape[0]) + + if fallback_to_brute: + warnings.warn("sqrt(%s) < n_neighbors (%s). " + "falling back to brute force search" % + (self.X_m.shape[0], n_neighbors)) + + if self.working_algorithm_ == 'brute' or fallback_to_brute: inputs.push_back(self.X_m.ptr) sizes.push_back(self.X_m.shape[0]) From fe9bfd4028fa769f49d174c876cd0fecfe7f84ab Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sun, 27 Feb 2022 09:55:40 -0500 Subject: [PATCH 13/22] Updating copyrights for python --- python/cuml/test/test_kneighbors_classifier.py | 2 +- python/cuml/test/test_kneighbors_regressor.py | 2 +- python/cuml/test/test_umap.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cuml/test/test_kneighbors_classifier.py b/python/cuml/test/test_kneighbors_classifier.py index 1033a93dd5..0da02fa1f1 100644 --- a/python/cuml/test/test_kneighbors_classifier.py +++ b/python/cuml/test/test_kneighbors_classifier.py @@ -1,5 +1,5 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/cuml/test/test_kneighbors_regressor.py b/python/cuml/test/test_kneighbors_regressor.py index c192d4052b..de518422f8 100644 --- a/python/cuml/test/test_kneighbors_regressor.py +++ b/python/cuml/test/test_kneighbors_regressor.py @@ -1,5 +1,5 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/cuml/test/test_umap.py b/python/cuml/test/test_umap.py index 9974ddea91..8175edca4e 100644 --- a/python/cuml/test/test_umap.py +++ b/python/cuml/test/test_umap.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From abafc35f4ef1bd86c474c3338f3933d0d066a735 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sun, 27 Feb 2022 10:02:09 -0500 Subject: [PATCH 14/22] Reverting changes to test_umap --- python/cuml/test/test_umap.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/cuml/test/test_umap.py b/python/cuml/test/test_umap.py index 8175edca4e..abf203f365 100644 --- a/python/cuml/test/test_umap.py +++ b/python/cuml/test/test_umap.py @@ -71,23 +71,23 @@ def test_umap_fit_transform_score(nrows, n_feats): data, labels = make_blobs(n_samples=n_samples, n_features=n_features, centers=10, random_state=42) -# model = umap.UMAP(n_neighbors=10, min_dist=0.1) + model = umap.UMAP(n_neighbors=10, min_dist=0.1) cuml_model = cuUMAP(n_neighbors=10, min_dist=0.01) -# embedding = model.fit_transform(data) + embedding = model.fit_transform(data) cuml_embedding = cuml_model.fit_transform(data, convert_dtype=True) -# assert not np.isnan(embedding).any() + assert not np.isnan(embedding).any() assert not np.isnan(cuml_embedding).any() if nrows < 500000: cuml_score = adjusted_rand_score(labels, KMeans(10).fit_predict( cuml_embedding)) -# score = adjusted_rand_score(labels, -# KMeans(10).fit_predict(embedding)) + score = adjusted_rand_score(labels, + KMeans(10).fit_predict(embedding)) -# assert array_equal(score, cuml_score, 1e-2, with_sign=True) + assert array_equal(score, cuml_score, 1e-2, with_sign=True) def test_supervised_umap_trustworthiness_on_iris(): From 3cb96c0082106ff6f5d50e13ec9a91a5f802ec25 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sun, 27 Feb 2022 10:08:05 -0500 Subject: [PATCH 15/22] Fixing style --- python/cuml/test/test_umap.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cuml/test/test_umap.py b/python/cuml/test/test_umap.py index abf203f365..06cd03d837 100644 --- a/python/cuml/test/test_umap.py +++ b/python/cuml/test/test_umap.py @@ -84,10 +84,10 @@ def test_umap_fit_transform_score(nrows, n_feats): cuml_score = adjusted_rand_score(labels, KMeans(10).fit_predict( cuml_embedding)) - score = adjusted_rand_score(labels, - KMeans(10).fit_predict(embedding)) + score = adjusted_rand_score(labels, + KMeans(10).fit_predict(embedding)) - assert array_equal(score, cuml_score, 1e-2, with_sign=True) + assert array_equal(score, cuml_score, 1e-2, with_sign=True) def test_supervised_umap_trustworthiness_on_iris(): From 9b5c3ee4a4aab77e3d0b98ba86c5273cfb4cd454 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 28 Apr 2022 16:20:05 -0400 Subject: [PATCH 16/22] Turning off static linking --- cpp/cmake/thirdparty/get_raft.cmake | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 54b2133531..7d57d3e91d 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -22,7 +22,6 @@ function(find_and_configure_raft) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - set(RAFT_STATIC_LINK_LIBRARIES OFF) if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "branch-${CUML_BRANCH_VERSION_raft}") message(STATUS "CUML: RAFT pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.") set(CPM_DOWNLOAD_raft ON) @@ -31,10 +30,6 @@ function(find_and_configure_raft) if(PKG_USE_RAFT_STATIC) message(STATUS "CUML: Cloning raft locally to build static libraries.") set(CPM_DOWNLOAD_raft ON) - set(RAFT_CXX_FLAGS ${RAFT_CXX_FLAGS} -fPIC) - set(RAFT_CUDA_FLAGS ${RAFT_CUDA_FLAGS} -fPIC) - set(CUML_CXX_FLAGS ${CUML_CXX_FLAGS} -fPIC) - set(CUML_CUDA_FLAGS ${CUML_CUDA_FLAGS} -fPIC) set(RAFT_STATIC_LINK_LIBRARIES ON) endif() @@ -63,7 +58,6 @@ function(find_and_configure_raft) SOURCE_SUBDIR cpp FIND_PACKAGE_ARGUMENTS "COMPONENTS ${RAFT_COMPONENTS}" OPTIONS - "RAFT_STATIC_LINK_LIBRARIES ${RAFT_STATIC_LINK_LIBRARIES}" "BUILD_TESTS OFF" "RAFT_COMPILE_LIBRARIES ${RAFT_COMPILE_LIBRARIES}" "RAFT_COMPILE_NN_LIBRARY ${PKG_USE_RAFT_NN}" From a7d363e3b35c6ddb6c40310c315f5fb7030d6b4b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 28 Apr 2022 16:24:04 -0400 Subject: [PATCH 17/22] Not setting static link libraries to on --- cpp/cmake/thirdparty/get_raft.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 7d57d3e91d..e1f8333867 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -30,7 +30,6 @@ function(find_and_configure_raft) if(PKG_USE_RAFT_STATIC) message(STATUS "CUML: Cloning raft locally to build static libraries.") set(CPM_DOWNLOAD_raft ON) - set(RAFT_STATIC_LINK_LIBRARIES ON) endif() if(PKG_USE_RAFT_DIST) From 1564a040ddc9253e3dfe77e171930a5fec9dcc53 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 28 Apr 2022 17:53:44 -0400 Subject: [PATCH 18/22] Fixing duplicate rng import --- cpp/src/randomforest/randomforest.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/randomforest/randomforest.cuh b/cpp/src/randomforest/randomforest.cuh index d98b889bac..f4cfb24d68 100644 --- a/cpp/src/randomforest/randomforest.cuh +++ b/cpp/src/randomforest/randomforest.cuh @@ -24,10 +24,9 @@ #include #include -#include +#include #include -#include #ifdef _OPENMP #include From e5d016eb06021ae18d36b825ac532860d7fe2e6d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 4 May 2022 10:28:33 -0400 Subject: [PATCH 19/22] Allowing a couple mismatched indices just in the case of non-determinisms for conflicting distances --- python/cuml/tests/test_nearest_neighbors.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/cuml/tests/test_nearest_neighbors.py b/python/cuml/tests/test_nearest_neighbors.py index e613570539..ea95a8cb83 100644 --- a/python/cuml/tests/test_nearest_neighbors.py +++ b/python/cuml/tests/test_nearest_neighbors.py @@ -552,7 +552,11 @@ def test_nearest_neighbors_rbc(distance_dims, n_neighbors, nrows): X[:query_rows, :], n_neighbors=n_neighbors) assert len(brute_d[brute_d != rbc_d]) == 0 - assert len(brute_i[brute_i != rbc_i]) == 0 + + # All the distances match so allow a couple mismatched indices + # through from potential non-determinism in exact matching + # distances + assert len(brute_i[brute_i != rbc_i]) <= 3 @pytest.mark.parametrize("metric", valid_metrics_sparse()) From ad4e472811a1b17b19973bad74100a1fae1d3a28 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 12 May 2022 20:47:25 -0400 Subject: [PATCH 20/22] Reverting change --- cpp/cmake/thirdparty/get_raft.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index e1f8333867..d63087a7fb 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -78,8 +78,8 @@ endfunction() # To use a different RAFT locally, set the CMake variable # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft} - FORK cjnolet - PINNED_TAG fea-2204-rbc_3d_2 + FORK rapidsai + PINNED_TAG branch-${CUML_BRANCH_VERSION_raft} # When PINNED_TAG above doesn't match cuml, # force local raft clone in build directory From 589a9c47e4b20de5fcad8fac27cbd326366b3909 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 13 May 2022 10:16:30 -0400 Subject: [PATCH 21/22] Reverting umap changes --- python/cuml/tests/test_umap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cuml/tests/test_umap.py b/python/cuml/tests/test_umap.py index db3c758e56..cd7b5c3bf6 100644 --- a/python/cuml/tests/test_umap.py +++ b/python/cuml/tests/test_umap.py @@ -59,9 +59,9 @@ def test_blobs_cluster(nrows, n_feats): assert score == 1.0 -@pytest.mark.parametrize('nrows', [unit_param(16384), quality_param(5000), +@pytest.mark.parametrize('nrows', [unit_param(500), quality_param(5000), stress_param(500000)]) -@pytest.mark.parametrize('n_feats', [unit_param(1000), quality_param(100), +@pytest.mark.parametrize('n_feats', [unit_param(10), quality_param(100), stress_param(1000)]) def test_umap_fit_transform_score(nrows, n_feats): From aced1870c2557bc595c992c1c30342a125cdb772 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 13 May 2022 13:23:55 -0400 Subject: [PATCH 22/22] Review feedback --- python/cuml/neighbors/nearest_neighbors.pyx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index ab79458fca..a35943a8a2 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -179,7 +179,7 @@ class NearestNeighbors(Base, - ``'rbc'``: for the random ball algorithm, which partitions the data space and uses the triangle inequality to lower the number of potential distances. Currently, this algorithm - supports 2d Euclidean and Haversine. + supports Haversine (2d) and Euclidean in 2d and 3d. - ``'brute'``: for brute-force, slow but produces exact results - ``'ivfflat'``: for inverted file, divide the dataset in partitions and perform search on relevant partitions only @@ -710,8 +710,9 @@ class NearestNeighbors(Base, n_neighbors > math.sqrt(self.X_m.shape[0]) if fallback_to_brute: - warnings.warn("sqrt(%s) < n_neighbors (%s). " - "falling back to brute force search" % + warnings.warn("algorithm='rbc' requires sqrt(%s) be " + "> n_neighbors (%s). falling back to " + "brute force search" % (self.X_m.shape[0], n_neighbors)) if self.working_algorithm_ == 'brute' or fallback_to_brute: