Skip to content

Commit

Permalink
Fixing neighbors tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Sep 29, 2021
1 parent 9e44f67 commit 6545f76
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
10 changes: 7 additions & 3 deletions python/cuml/neighbors/nearest_neighbors.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -378,24 +378,28 @@ class NearestNeighbors(Base,

if is_sparse(X):
valid_metrics = cuml.neighbors.VALID_METRICS_SPARSE
# Suffix spliced into the "VALID_METRICS%s" hint of the invalid-metric error
valid_metric_str = "_SPARSE"
self.X_m = SparseCumlArray(X, convert_to_dtype=cp.float32,
convert_format=False)
self.n_rows = self.X_m.shape[0]

else:
valid_metrics = cuml.neighbors.VALID_METRICS
valid_metric_str = ""
self.X_m, self.n_rows, n_cols, dtype = \
input_to_cuml_array(X, order='C', check_dtype=np.float32,
convert_to_dtype=(np.float32
if convert_dtype
else None))

if self.metric not in \
cuml.neighbors.VALID_METRICS[self.working_algorithm_]:
valid_metrics[self.working_algorithm_]:
raise ValueError("Metric %s is not valid. "
"Use sorted(cuml.neighbors.VALID_METRICS[%s]) "
"Use sorted(cuml.neighbors.VALID_METRICS%s[%s]) "
"to get valid options." %
(self.metric, self.working_algorithm_))
# Order must follow the %s placeholders: metric, table suffix, algorithm
(self.metric,
valid_metric_str,
self.working_algorithm_))

cdef handle_t* handle_ = <handle_t*><uintptr_t> self.handle.getHandle()
cdef knnIndexParam* algo_params = <knnIndexParam*> 0
Expand Down
24 changes: 17 additions & 7 deletions python/cuml/test/test_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,21 +513,31 @@ def test_knn_graph(input_type, mode, output_type, as_instance,


@pytest.mark.parametrize('distance', ["euclidean", "haversine"])
@pytest.mark.parametrize('n_neighbors', [2, 35])
@pytest.mark.parametrize('nrows', [unit_param(500), stress_param(70000)])
@pytest.mark.parametrize('n_neighbors', [2, 12])
@pytest.mark.parametrize('nrows', [unit_param(1000), stress_param(70000)])
def test_nearest_neighbors_rbc(distance, n_neighbors, nrows):
X, y = make_blobs(n_samples=nrows,
n_features=2, random_state=0)

knn_cu = cuKNN(metric=distance, algorithm="rbc")
knn_cu.fit(X)

rbc_d, rbc_i = knn_cu.kneighbors(X[:int(nrows/2), :],
query_rows = int(nrows/2)

rbc_d, rbc_i = knn_cu.kneighbors(X[:query_rows, :],
n_neighbors=n_neighbors)

pw_dists = cuPW(X, metric="l2")
brute_i = cp.argsort(X, axis=1)
brute_d = pw_dists[brute_i][:, :n_neighbors]
if distance == 'euclidean':
# Need to use unexpanded euclidean distance
pw_dists = cuPW(X, metric="l2")
brute_i = cp.argsort(pw_dists, axis=1)[:query_rows, :n_neighbors]
brute_d = cp.sort(pw_dists, axis=1)[:query_rows, :n_neighbors]
else:
knn_cu_brute = cuKNN(metric=distance, algorithm="brute")
knn_cu_brute.fit(X)

brute_d, brute_i = knn_cu_brute.kneighbors(
X[:query_rows, :], n_neighbors=n_neighbors)

cp.testing.assert_allclose(rbc_d, brute_d, atol=5e-2,
rtol=1e-3)
Expand All @@ -537,7 +547,7 @@ def test_nearest_neighbors_rbc(distance, n_neighbors, nrows):
diff = rbc_i != brute_i

# Using a very small tolerance for subtle differences
# in indices that result from
# in indices that result from non-determinism
assert diff.ravel().sum() < 5


Expand Down

0 comments on commit 6545f76

Please sign in to comment.