From 5a8889d97fced0c499d78240c8420a448486b826 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 29 Sep 2021 19:21:03 -0400 Subject: [PATCH] Changes to NearestNeighbors to call 2d random ball cover (#4003) This PR integrates the [random ball cover PoC](https://github.com/rapidsai/raft/pull/213) into cuml's brute-force knn for executing the random ball cover algorithm for haversine distance. Authors: - Corey J. Nolet (https://github.com/cjnolet) - Dante Gama Dessavre (https://github.com/dantegd) - Paul Taylor (https://github.com/trxcllnt) Approvers: - AJ Schmidt (https://github.com/ajschmidt8) - Dante Gama Dessavre (https://github.com/dantegd) URL: https://github.com/rapidsai/cuml/pull/4003 --- ci/local/build.sh | 6 ++ cpp/include/cuml/neighbors/knn.hpp | 11 +++ cpp/src/knn/knn.cu | 19 ++++ cpp/test/sg/hdbscan_test.cu | 1 - python/cuml/neighbors/__init__.py | 6 +- python/cuml/neighbors/nearest_neighbors.pyx | 101 +++++++++++++++++--- python/cuml/test/test_nearest_neighbors.py | 40 ++++++++ 7 files changed, 170 insertions(+), 14 deletions(-) mode change 100644 => 100755 ci/local/build.sh diff --git a/ci/local/build.sh b/ci/local/build.sh old mode 100644 new mode 100755 index 18ee1cfc53..644647e038 --- a/ci/local/build.sh +++ b/ci/local/build.sh @@ -1,5 +1,11 @@ #!/bin/bash +# Copyright (c) 2018-2021, NVIDIA CORPORATION. 
+############################################## +# cuML local build and test script for CI # +############################################## + + GIT_DESCRIBE_TAG=`git describe --tags` MINOR_VERSION=`echo $GIT_DESCRIBE_TAG | grep -o -E '([0-9]+\.[0-9]+)'` diff --git a/cpp/include/cuml/neighbors/knn.hpp b/cpp/include/cuml/neighbors/knn.hpp index b236aff698..08f726c6af 100644 --- a/cpp/include/cuml/neighbors/knn.hpp +++ b/cpp/include/cuml/neighbors/knn.hpp @@ -18,6 +18,7 @@ #include #include +#include namespace raft { class handle_t; @@ -60,6 +61,16 @@ void brute_force_knn(const raft::handle_t& handle, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, float metric_arg = 2.0f); +void rbc_build_index(const raft::handle_t& handle, + raft::spatial::knn::BallCoverIndex& index); + +void rbc_knn_query(const raft::handle_t& handle, + raft::spatial::knn::BallCoverIndex& index, + uint32_t k, + const float* search_items, + uint32_t n_search_items, + int64_t* out_inds, + float* out_dists); /** * @brief Flat C++ API function to build an approximate nearest neighbors index * from an index array and a set of parameters. 
diff --git a/cpp/src/knn/knn.cu b/cpp/src/knn/knn.cu index cee3d34501..0b9fa1640d 100644 --- a/cpp/src/knn/knn.cu +++ b/cpp/src/knn/knn.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -64,6 +65,24 @@ void brute_force_knn(const raft::handle_t& handle, metric_arg); } +void rbc_build_index(const raft::handle_t& handle, + raft::spatial::knn::BallCoverIndex& index) +{ + raft::spatial::knn::rbc_build_index(handle, index); +} + +void rbc_knn_query(const raft::handle_t& handle, + raft::spatial::knn::BallCoverIndex& index, + uint32_t k, + const float* search_items, + uint32_t n_search_items, + int64_t* out_inds, + float* out_dists) +{ + raft::spatial::knn::rbc_knn_query( + handle, index, k, search_items, n_search_items, out_inds, out_dists); +} + void approx_knn_build_index(raft::handle_t& handle, raft::spatial::knn::knnIndex* index, raft::spatial::knn::knnIndexParam* params, diff --git a/cpp/test/sg/hdbscan_test.cu b/cpp/test/sg/hdbscan_test.cu index a9299cb1d7..11e65e8553 100644 --- a/cpp/test/sg/hdbscan_test.cu +++ b/cpp/test/sg/hdbscan_test.cu @@ -116,7 +116,6 @@ class HDBSCANTest : public ::testing::TestWithParam> { protected: HDBSCANInputs params; IdxT* labels_ref; - int k; double score; }; diff --git a/python/cuml/neighbors/__init__.py b/python/cuml/neighbors/__init__.py index 68a301bfc7..2cece8b2f7 100644 --- a/python/cuml/neighbors/__init__.py +++ b/python/cuml/neighbors/__init__.py @@ -33,6 +33,10 @@ "inner_product", "sqeuclidean", "haversine" ]), + "rbc": set([ + "euclidean", "haversine", + "l2" + ]), "ivfflat": set([ "l2", "euclidean", "sqeuclidean", "inner_product", "cosine", "correlation" @@ -45,7 +49,7 @@ "l2", "euclidean", "sqeuclidean", "inner_product", "cosine", "correlation" ]) - } +} VALID_METRICS_SPARSE = { "brute": set(["euclidean", "l2", "inner_product", diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index b4bc1eae39..d63ff3ba00 100644 --- 
a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -48,7 +48,7 @@ from cython.operator cimport dereference as deref from libcpp cimport bool from libcpp.memory cimport shared_ptr -from libc.stdint cimport uintptr_t, int64_t +from libc.stdint cimport uintptr_t, int64_t, uint32_t from libc.stdlib cimport calloc, malloc, free from libcpp.vector cimport vector @@ -63,6 +63,16 @@ cimport cuml.common.cuda if has_scipy(): import scipy.sparse + +cdef extern from "raft/spatial/knn/ball_cover_common.h" \ + namespace "raft::spatial::knn": + cdef cppclass BallCoverIndex[int64_t, float, uint32_t]: + BallCoverIndex(const handle_t &handle, + float *X, + uint32_t n_rows, + uint32_t n_cols, + DistanceType metric) except + + cdef extern from "cuml/neighbors/knn.hpp" namespace "ML": void brute_force_knn( const handle_t &handle, @@ -80,6 +90,21 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML": float metric_arg ) except + + void rbc_build_index( + const handle_t &handle, + BallCoverIndex[int64_t, float, uint32_t] &index, + ) except + + + void rbc_knn_query( + const handle_t &handle, + BallCoverIndex[int64_t, float, uint32_t] &index, + uint32_t k, + float *search_items, + uint32_t n_search_items, + int64_t *out_inds, + float *out_dists + ) except + + void approx_knn_build_index( handle_t &handle, knnIndex* index, @@ -101,6 +126,7 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML": int n ) except + + cdef extern from "cuml/neighbors/knn_sparse.hpp" namespace "ML::Sparse": void brute_force_knn(handle_t &handle, const int *idxIndptr, @@ -148,6 +174,12 @@ class NearestNeighbors(Base, algorithm : string (default='brute') The query algorithm to use. 
Valid options are: + - ``'auto'``: to automatically select brute-force or + random ball cover based on data shape and metric + - ``'rbc'``: for the random ball cover algorithm, which partitions + the data space and uses the triangle inequality to lower the + number of potential distances. Currently, this algorithm + supports 2d Euclidean and Haversine. - ``'brute'``: for brute-force, slow but produces exact results - ``'ivfflat'``: for inverted file, divide the dataset in partitions and perform search on relevant partitions only @@ -299,7 +331,7 @@ class NearestNeighbors(Base, n_neighbors=5, verbose=False, handle=None, - algorithm="brute", + algorithm="auto", metric="euclidean", p=2, algo_params=None, @@ -318,8 +350,10 @@ class NearestNeighbors(Base, self.algo_params = algo_params self.p = p self.algorithm = algorithm + self.working_algorithm_ = self.algorithm + self.selected_algorithm_ = algorithm self.algo_params = algo_params - self.knn_index = 0 + self.knn_index = None @generate_docstring(X='dense_sparse') def fit(self, X, convert_dtype=True) -> "NearestNeighbors": @@ -332,29 +366,44 @@ class NearestNeighbors(Base, self.n_dims = X.shape[1] + if self.algorithm == "auto": + if self.n_dims == 2 and not is_sparse(X) and self.metric in \ + cuml.neighbors.VALID_METRICS["rbc"]: + self.working_algorithm_ = "rbc" + else: + self.working_algorithm_ = "brute" + + if self.algorithm == "rbc" and self.n_dims > 2: + raise ValueError("rbc algorithm currently only supports 2d data") + if is_sparse(X): valid_metrics = cuml.neighbors.VALID_METRICS_SPARSE + valid_metric_str = "_SPARSE" self.X_m = SparseCumlArray(X, convert_to_dtype=cp.float32, convert_format=False) self.n_rows = self.X_m.shape[0] else: valid_metrics = cuml.neighbors.VALID_METRICS + valid_metric_str = "" self.X_m, self.n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype else None)) - if self.metric not in valid_metrics[self.algorithm]: + if self.metric not in \ + 
valid_metrics[self.working_algorithm_]: raise ValueError("Metric %s is not valid. " - "Use sorted(cuml.neighbors.VALID_METRICS[%s]) " + "Use sorted(cuml.neighbors.VALID_METRICS%s[%s]) " "to get valid options." % - (self.metric, self.algorithm)) + (self.metric, + valid_metric_str, + self.working_algorithm_)) cdef handle_t* handle_ = self.handle.getHandle() cdef knnIndexParam* algo_params = 0 - if self.algorithm in ['ivfflat', 'ivfpq', 'ivfsq']: + if self.working_algorithm_ in ['ivfflat', 'ivfpq', 'ivfsq']: warnings.warn("\nWarning: Approximate Nearest Neighbor methods " "might be unstable in this version of cuML. " "This is due to a known issue in the FAISS " @@ -370,7 +419,7 @@ class NearestNeighbors(Base, knn_index = new knnIndex() self.knn_index = knn_index algo_params = \ - build_algo_params(self.algorithm, self.algo_params, + build_algo_params(self.working_algorithm_, self.algo_params, additional_info) metric = self._build_metric_type(self.metric) @@ -387,6 +436,16 @@ class NearestNeighbors(Base, destroy_algo_params(algo_params) del self.X_m + elif self.working_algorithm_ == "rbc": + metric = self._build_metric_type(self.metric) + + rbc_index = new BallCoverIndex[int64_t, float, uint32_t]( + handle_[0], self.X_m.ptr, + self.n_rows, n_cols, + metric) + rbc_build_index(handle_[0], + deref(rbc_index)) + self.knn_index = rbc_index self.n_indices = 1 return self @@ -654,8 +713,10 @@ class NearestNeighbors(Base, cdef vector[float*] *inputs = new vector[float*]() cdef vector[int] *sizes = new vector[int]() cdef knnIndex* knn_index = 0 + cdef BallCoverIndex[int64_t, float, uint32_t]* rbc_index = \ + 0 - if self.algorithm == 'brute': + if self.working_algorithm_ == 'brute': inputs.push_back(self.X_m.ptr) sizes.push_back(self.X_m.shape[0]) @@ -675,6 +736,16 @@ class NearestNeighbors(Base, # minkowski order is currently the only metric argument. 
self.p ) + elif self.working_algorithm_ == "rbc": + rbc_index = \ + self.knn_index + rbc_knn_query(handle_[0], + deref(rbc_index), + n_neighbors, + X_m.ptr, + N, + I_ptr, + D_ptr) else: knn_index = self.knn_index approx_knn_search( @@ -826,9 +897,15 @@ class NearestNeighbors(Base, return sparse_csr def __del__(self): - cdef knnIndex* knn_index = self.knn_index - if knn_index: - del knn_index + cdef knnIndex* knn_index = 0 + cdef BallCoverIndex* rbc_index = 0 + if self.knn_index is not None: + if self.working_algorithm_ in ["ivfflat", "ivfpq", "ivfsq"]: + knn_index = self.knn_index + del knn_index + else: + rbc_index = self.knn_index + del rbc_index @cuml.internals.api_return_sparse_array() diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index 1aa0e86290..35f1c5229e 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -25,6 +25,7 @@ from cuml.datasets import make_blobs from sklearn.metrics import pairwise_distances +from cuml.metrics import pairwise_distances as cuPW from cuml.common import logger @@ -511,6 +512,45 @@ def test_knn_graph(input_type, mode, output_type, as_instance, assert isspmatrix_csr(sparse_cu) +@pytest.mark.parametrize('distance', ["euclidean", "haversine"]) +@pytest.mark.parametrize('n_neighbors', [2, 12]) +@pytest.mark.parametrize('nrows', [unit_param(1000), stress_param(70000)]) +def test_nearest_neighbors_rbc(distance, n_neighbors, nrows): + X, y = make_blobs(n_samples=nrows, + n_features=2, random_state=0) + + knn_cu = cuKNN(metric=distance, algorithm="rbc") + knn_cu.fit(X) + + query_rows = int(nrows/2) + + rbc_d, rbc_i = knn_cu.kneighbors(X[:query_rows, :], + n_neighbors=n_neighbors) + + if distance == 'euclidean': + # Need to use unexpanded euclidean distance + pw_dists = cuPW(X, metric="l2") + brute_i = cp.argsort(pw_dists, axis=1)[:query_rows, :n_neighbors] + brute_d = cp.sort(pw_dists, axis=1)[:query_rows, :n_neighbors] + else: + 
knn_cu_brute = cuKNN(metric=distance, algorithm="brute") + knn_cu_brute.fit(X) + + brute_d, brute_i = knn_cu_brute.kneighbors( + X[:query_rows, :], n_neighbors=n_neighbors) + + cp.testing.assert_allclose(rbc_d, brute_d, atol=5e-2, + rtol=1e-3) + rbc_i = cp.sort(rbc_i, axis=1) + brute_i = cp.sort(brute_i, axis=1) + + diff = rbc_i != brute_i + + # Using a very small tolerance for subtle differences + # in indices that result from non-determinism + assert diff.ravel().sum() < 5 + + @pytest.mark.parametrize("metric", valid_metrics_sparse()) @pytest.mark.parametrize( 'nrows,ncols,density,n_neighbors,batch_size_index,batch_size_query',