diff --git a/cpp/src/randomforest/randomforest.cuh b/cpp/src/randomforest/randomforest.cuh index d98b889bac..f4cfb24d68 100644 --- a/cpp/src/randomforest/randomforest.cuh +++ b/cpp/src/randomforest/randomforest.cuh @@ -24,10 +24,9 @@ #include #include -#include +#include #include -#include #ifdef _OPENMP #include diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index caee5d4cff..a35943a8a2 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -24,6 +24,7 @@ import cupyx import cudf import ctypes import warnings +import math import cuml.internals from cuml.common.base import Base @@ -178,7 +179,7 @@ class NearestNeighbors(Base, - ``'rbc'``: for the random ball algorithm, which partitions the data space and uses the triangle inequality to lower the number of potential distances. Currently, this algorithm - supports 2d Euclidean and Haversine. + supports Haversine (2d) and Euclidean in 2d and 3d. - ``'brute'``: for brute-force, slow but produces exact results - ``'ivfflat'``: for inverted file, divide the dataset in partitions and perform search on relevant partitions only @@ -347,15 +348,17 @@ class NearestNeighbors(Base, self.n_dims = X.shape[1] if self.algorithm == "auto": - if self.n_dims == 2 and self.metric in \ - cuml.neighbors.VALID_METRICS["rbc"]: + if (self.n_dims == 2 or self.n_dims == 3) and \ + not is_sparse(X) and \ + self.metric in cuml.neighbors.VALID_METRICS["rbc"] and \ + math.sqrt(X.shape[0]) >= self.n_neighbors: self.working_algorithm_ = "rbc" else: self.working_algorithm_ = "brute" - if self.algorithm == "rbc" and self.n_dims > 2: + if self.algorithm == "rbc" and self.n_dims > 3: raise ValueError("The rbc algorithm is not supported for" - " >2 dimensions currently.") + " >3 dimensions currently.") if is_sparse(X): valid_metrics = cuml.neighbors.VALID_METRICS_SPARSE @@ -703,7 +706,16 @@ class NearestNeighbors(Base, cdef BallCoverIndex[int64_t, float, uint32_t]* rbc_index = \ 0 - if self.working_algorithm_ == 'brute': + fallback_to_brute = self.working_algorithm_ == "rbc" and \ + n_neighbors > math.sqrt(self.X_m.shape[0]) + + if fallback_to_brute: + warnings.warn("algorithm='rbc' requires sqrt(%s) be " + "> n_neighbors (%s). falling back to " + "brute force search" % + (self.X_m.shape[0], n_neighbors)) + + if self.working_algorithm_ == 'brute' or fallback_to_brute: inputs.push_back(self.X_m.ptr) sizes.push_back(self.X_m.shape[0]) @@ -886,12 +898,15 @@ class NearestNeighbors(Base, def __del__(self): cdef knnIndex* knn_index = 0 cdef BallCoverIndex* rbc_index = 0 - if self.knn_index is not None: + + kidx = self.__dict__['knn_index'] \ + if 'knn_index' in self.__dict__ else None + if kidx is not None: if self.working_algorithm_ in ["ivfflat", "ivfpq", "ivfsq"]: - knn_index = self.knn_index + knn_index = kidx del knn_index else: - rbc_index = self.knn_index + rbc_index = kidx del rbc_index diff --git a/python/cuml/tests/test_kneighbors_classifier.py b/python/cuml/tests/test_kneighbors_classifier.py index 05a7882d3f..8627139fc9 100644 --- a/python/cuml/tests/test_kneighbors_classifier.py +++ b/python/cuml/tests/test_kneighbors_classifier.py @@ -271,7 +271,7 @@ def test_nonmonotonic_labels(n_classes, n_rows, n_cols, @pytest.mark.parametrize("output_type", ["cudf", "numpy", "cupy"]) def test_predict_multioutput(input_type, output_type): - X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) + X = np.array([[0, 0, 1, 0], [1, 0, 1, 0]]).astype(np.float32) y = np.array([[15, 2], [5, 4]]).astype(np.int32) if input_type == "cudf": @@ -300,7 +300,7 @@ def test_predict_multioutput(input_type, output_type): @pytest.mark.parametrize("output_type", ["cudf", "numpy", "cupy"]) def test_predict_proba_multioutput(input_type, output_type): - X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) + X = np.array([[0, 0, 1, 0], [1, 0, 1, 0]]).astype(np.float32) y = np.array([[15, 2], [5, 4]]).astype(np.int32) if input_type == "cudf": diff --git a/python/cuml/tests/test_kneighbors_regressor.py b/python/cuml/tests/test_kneighbors_regressor.py index e50f7e4dfa..0f04764414 100644 --- a/python/cuml/tests/test_kneighbors_regressor.py +++ b/python/cuml/tests/test_kneighbors_regressor.py @@ -125,7 +125,7 @@ def test_score_dtype(dtype): @pytest.mark.parametrize("output_type", ["cudf", "numpy", "cupy"]) def test_predict_multioutput(input_type, output_type): - X = np.array([[0, 0, 1], [1, 0, 1]]).astype(np.float32) + X = np.array([[0, 0, 1, 0], [1, 0, 1, 0]]).astype(np.float32) y = np.array([[15, 2], [5, 4]]).astype(np.int32) if input_type == "cudf": diff --git a/python/cuml/tests/test_nearest_neighbors.py b/python/cuml/tests/test_nearest_neighbors.py index b26d655d6f..ea95a8cb83 100644 --- a/python/cuml/tests/test_nearest_neighbors.py +++ b/python/cuml/tests/test_nearest_neighbors.py @@ -516,21 +516,25 @@ def test_knn_graph(input_type, mode, output_type, as_instance, assert isspmatrix_csr(sparse_cu) -@pytest.mark.parametrize('distance', ["euclidean", "haversine"]) +@pytest.mark.parametrize('distance_dims', [("euclidean", 2), + ("euclidean", 3), + ("haversine", 2)]) @pytest.mark.parametrize('n_neighbors', [4, 25]) @pytest.mark.parametrize('nrows', [unit_param(10000), stress_param(70000)]) -def test_nearest_neighbors_rbc(distance, n_neighbors, nrows): +def test_nearest_neighbors_rbc(distance_dims, n_neighbors, nrows): + distance, dims = distance_dims + X, y = make_blobs(n_samples=nrows, centers=25, shuffle=True, - n_features=2, + n_features=dims, cluster_std=3.0, random_state=42) knn_cu = cuKNN(metric=distance, algorithm="rbc") knn_cu.fit(X) - query_rows = int(nrows/2) + query_rows = int(nrows / 2) rbc_d, rbc_i = knn_cu.kneighbors(X[:query_rows, :], n_neighbors=n_neighbors) @@ -548,7 +552,11 @@ def test_nearest_neighbors_rbc(distance, n_neighbors, nrows): X[:query_rows, :], n_neighbors=n_neighbors) assert len(brute_d[brute_d != rbc_d]) == 0 - assert len(brute_i[brute_i != rbc_i]) == 0 + + # All the distances match so allow a couple mismatched indices + # through from potential non-determinism in exact matching + # distances + assert len(brute_i[brute_i != rbc_i]) <= 3 @pytest.mark.parametrize("metric", valid_metrics_sparse())