From 6545f7642498fb492d6e35d7ea9598d5d72c79ba Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 29 Sep 2021 13:17:19 -0400 Subject: [PATCH] Fixing neighbors tests --- python/cuml/neighbors/nearest_neighbors.pyx | 10 ++++++--- python/cuml/test/test_nearest_neighbors.py | 24 +++++++++++++++------ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index d324558e13..d63ff3ba00 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -378,12 +378,14 @@ class NearestNeighbors(Base, if is_sparse(X): valid_metrics = cuml.neighbors.VALID_METRICS_SPARSE + value_metric_str = "_SPARSE" self.X_m = SparseCumlArray(X, convert_to_dtype=cp.float32, convert_format=False) self.n_rows = self.X_m.shape[0] else: valid_metrics = cuml.neighbors.VALID_METRICS + valid_metric_str = "" self.X_m, self.n_rows, n_cols, dtype = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, convert_to_dtype=(np.float32 @@ -391,11 +393,13 @@ class NearestNeighbors(Base, else None)) if self.metric not in \ - cuml.neighbors.VALID_METRICS[self.working_algorithm_]: + valid_metrics[self.working_algorithm_]: raise ValueError("Metric %s is not valid. " - "Use sorted(cuml.neighbors.VALID_METRICS[%s]) " + "Use sorted(cuml.neighbors.VALID_METRICS%s[%s]) " "to get valid options." % - (self.metric, self.working_algorithm_)) + (valid_metric_str, + self.metric, + self.working_algorithm_)) cdef handle_t* handle_ = self.handle.getHandle() cdef knnIndexParam* algo_params = 0 diff --git a/python/cuml/test/test_nearest_neighbors.py b/python/cuml/test/test_nearest_neighbors.py index 7cfcdc46c2..35f1c5229e 100644 --- a/python/cuml/test/test_nearest_neighbors.py +++ b/python/cuml/test/test_nearest_neighbors.py @@ -513,8 +513,8 @@ def test_knn_graph(input_type, mode, output_type, as_instance, @pytest.mark.parametrize('distance', ["euclidean", "haversine"]) -@pytest.mark.parametrize('n_neighbors', [2, 35]) -@pytest.mark.parametrize('nrows', [unit_param(500), stress_param(70000)]) +@pytest.mark.parametrize('n_neighbors', [2, 12]) +@pytest.mark.parametrize('nrows', [unit_param(1000), stress_param(70000)]) def test_nearest_neighbors_rbc(distance, n_neighbors, nrows): X, y = make_blobs(n_samples=nrows, n_features=2, random_state=0) @@ -522,12 +522,22 @@ def test_nearest_neighbors_rbc(distance, n_neighbors, nrows): knn_cu = cuKNN(metric=distance, algorithm="rbc") knn_cu.fit(X) - rbc_d, rbc_i = knn_cu.kneighbors(X[:int(nrows/2), :], + query_rows = int(nrows/2) + + rbc_d, rbc_i = knn_cu.kneighbors(X[:query_rows, :], n_neighbors=n_neighbors) - pw_dists = cuPW(X, metric="l2") - brute_i = cp.argsort(X, axis=1) - brute_d = pw_dists[brute_i][:, :n_neighbors] + if distance == 'euclidean': + # Need to use unexpanded euclidean distance + pw_dists = cuPW(X, metric="l2") + brute_i = cp.argsort(pw_dists, axis=1)[:query_rows, :n_neighbors] + brute_d = cp.sort(pw_dists, axis=1)[:query_rows, :n_neighbors] + else: + knn_cu_brute = cuKNN(metric=distance, algorithm="brute") + knn_cu_brute.fit(X) + + brute_d, brute_i = knn_cu_brute.kneighbors( + X[:query_rows, :], n_neighbors=n_neighbors) cp.testing.assert_allclose(rbc_d, brute_d, atol=5e-2, rtol=1e-3) @@ -537,7 +547,7 @@ def test_nearest_neighbors_rbc(distance, n_neighbors, nrows): diff = rbc_i != brute_i # Using a very small tolerance for subtle differences - # in indices that result from + # in indices that result from non-determinism assert diff.ravel().sum() < 5