diff --git a/python/pylibraft/pylibraft/common/device_ndarray.py b/python/pylibraft/pylibraft/common/device_ndarray.py index eebbca2f06..f267e0c644 100644 --- a/python/pylibraft/pylibraft/common/device_ndarray.py +++ b/python/pylibraft/pylibraft/common/device_ndarray.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -89,12 +89,8 @@ def c_contiguous(self): """ Is the current device_ndarray laid out in row-major format? """ - array_interface = self.ndarray_.__array_interface__ strides = self.strides - return ( - strides is None - or array_interface["strides"][1] == self.dtype.itemsize - ) + return strides is None or strides[1] == self.dtype.itemsize @property def f_contiguous(self): @@ -125,11 +121,7 @@ def strides(self): Strides of the current device_ndarray instance """ array_interface = self.ndarray_.__array_interface__ - return ( - None - if "strides" not in array_interface - else array_interface["strides"] - ) + return array_interface.get("strides") @property def __cuda_array_interface__(self): diff --git a/python/pylibraft/pylibraft/neighbors/brute_force.pyx b/python/pylibraft/pylibraft/neighbors/brute_force.pyx index 8836307a5a..2d118072ab 100644 --- a/python/pylibraft/pylibraft/neighbors/brute_force.pyx +++ b/python/pylibraft/pylibraft/neighbors/brute_force.pyx @@ -47,6 +47,7 @@ from pylibraft.distance.distance_type cimport DistanceType # TODO: Centralize this from pylibraft.distance.pairwise_distance import DISTANCE_TYPES +from pylibraft.neighbors.common import _check_input_array from pylibraft.common.cpp.mdspan cimport ( device_matrix_view, @@ -143,6 +144,11 @@ def knn(dataset, queries, k=None, indices=None, distances=None, raise ValueError("Argument k must be specified if both indices " "and distances arg is None") + # we require c-contiguous (rowmajor) inputs here + _check_input_array(dataset_cai, [np.dtype("float32")]) + _check_input_array(queries_cai, [np.dtype("float32")], + exp_cols=dataset_cai.shape[1]) + n_queries = queries_cai.shape[0] if indices is None: diff --git a/python/pylibraft/pylibraft/test/test_brute_force.py b/python/pylibraft/pylibraft/test/test_brute_force.py index 0bd5e6eaaf..2e118d210d 100644 --- a/python/pylibraft/pylibraft/test/test_brute_force.py +++ b/python/pylibraft/pylibraft/test/test_brute_force.py @@ -40,11 +40,8 @@ ], ) @pytest.mark.parametrize("inplace", [True, False]) -@pytest.mark.parametrize("order", ["F", "C"]) @pytest.mark.parametrize("dtype", [np.float32]) -def test_knn( - n_index_rows, n_query_rows, n_cols, k, inplace, metric, order, dtype -): +def test_knn(n_index_rows, n_query_rows, n_cols, k, inplace, metric, dtype): index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype) queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype) @@ -94,3 +91,21 @@ def test_knn( np.testing.assert_allclose( cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4 ) + + +def test_knn_check_col_major_inputs(): + # make sure that we get an exception if passed col-major inputs, + # instead of returning incorrect results + cp = pytest.importorskip("cupy") + n_index_rows, n_query_rows, n_cols = 128, 16, 32 + index = cp.random.random_sample((n_index_rows, n_cols), dtype="float32") + queries = cp.random.random_sample((n_query_rows, n_cols), dtype="float32") + + with pytest.raises(ValueError): + knn(cp.asarray(index, order="F"), queries, k=4) + + with pytest.raises(ValueError): + knn(index, cp.asarray(queries, order="F"), k=4) + + # shouldn't throw an exception with c-contiguous inputs + knn(index, queries, k=4) diff --git a/python/pylibraft/pylibraft/test/test_handle.py b/python/pylibraft/pylibraft/test/test_handle.py index ae519ea965..bb07df1000 100644 --- a/python/pylibraft/pylibraft/test/test_handle.py +++ b/python/pylibraft/pylibraft/test/test_handle.py @@ -19,10 +19,7 @@ from pylibraft.common import DeviceResources, Stream, device_ndarray from pylibraft.distance import pairwise_distance -try: - import cupy -except ImportError: - pytest.skip(reason="cupy not installed.") +cupy = pytest.importorskip("cupy") @pytest.mark.parametrize("stream", [cupy.cuda.Stream().ptr, Stream()])