From 5e84daa6ad0233e80996505086dde9e616c1321b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 17 Apr 2023 14:19:51 -0700 Subject: [PATCH] Add python bindings for matrix::select_k (#1422) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1422 --- cpp/CMakeLists.txt | 7 +- cpp/include/raft/matrix/select_k.cuh | 20 +-- cpp/include/raft_runtime/matrix/select_k.hpp | 32 +++++ .../raft_internal/matrix/select_k.cuh | 13 +- .../matrix/select_k_float_int64_t.cu | 37 +++++ python/pylibraft/CMakeLists.txt | 1 + .../pylibraft/pylibraft/matrix/CMakeLists.txt | 24 ++++ .../pylibraft/pylibraft/matrix/__init__.pxd | 14 ++ python/pylibraft/pylibraft/matrix/__init__.py | 18 +++ .../pylibraft/matrix/cpp/__init__.pxd | 0 .../pylibraft/matrix/cpp/__init__.py | 14 ++ .../pylibraft/matrix/cpp/select_k.pxd | 39 +++++ .../pylibraft/pylibraft/matrix/select_k.pyx | 133 ++++++++++++++++++ .../pylibraft/neighbors/brute_force.pyx | 3 +- ...test_brue_force.py => test_brute_force.py} | 0 .../pylibraft/pylibraft/test/test_doctests.py | 2 + .../pylibraft/pylibraft/test/test_select_k.py | 54 +++++++ 17 files changed, 389 insertions(+), 22 deletions(-) create mode 100644 cpp/include/raft_runtime/matrix/select_k.hpp create mode 100644 cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu create mode 100644 python/pylibraft/pylibraft/matrix/CMakeLists.txt create mode 100644 python/pylibraft/pylibraft/matrix/__init__.pxd create mode 100644 python/pylibraft/pylibraft/matrix/__init__.py create mode 100644 python/pylibraft/pylibraft/matrix/cpp/__init__.pxd create mode 100644 python/pylibraft/pylibraft/matrix/cpp/__init__.py create mode 100644 python/pylibraft/pylibraft/matrix/cpp/select_k.pxd create mode 100644 python/pylibraft/pylibraft/matrix/select_k.pyx rename python/pylibraft/pylibraft/test/{test_brue_force.py => test_brute_force.py} (100%) create mode 100644 python/pylibraft/pylibraft/test/test_select_k.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 1c705cc786..7f64b92306 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -70,13 +70,11 @@ option(RAFT_COMPILE_LIBRARY "Enable building raft shared library instantiations" ${RAFT_COMPILE_LIBRARY_DEFAULT} ) - -# Needed because GoogleBenchmark changes the state of FindThreads.cmake, causing subsequent runs -# to have different values for the `Threads::Threads` target. Setting this flag ensures +# Needed because GoogleBenchmark changes the state of FindThreads.cmake, causing subsequent runs to +# have different values for the `Threads::Threads` target. Setting this flag ensures # `Threads::Threads` is the same value across all builds so that cache hits occur set(THREADS_PREFER_PTHREAD_FLAG ON) - include(CMakeDependentOption) # cmake_dependent_option( RAFT_USE_FAISS_STATIC "Build and statically link the FAISS library for # nearest neighbors search on GPU" ON RAFT_COMPILE_LIBRARY OFF ) @@ -357,6 +355,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/cluster/update_centroids_float.cu src/raft_runtime/distance/fused_l2_min_arg.cu src/raft_runtime/distance/pairwise_distance.cu + src/raft_runtime/matrix/select_k_float_int64_t.cu src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu src/raft_runtime/neighbors/ivf_flat_build.cu src/raft_runtime/neighbors/ivf_flat_search.cu diff --git a/cpp/include/raft/matrix/select_k.cuh b/cpp/include/raft/matrix/select_k.cuh index 9a1a14fd73..7951cbdb03 100644 --- a/cpp/include/raft/matrix/select_k.cuh +++ b/cpp/include/raft/matrix/select_k.cuh @@ -42,13 +42,13 @@ namespace raft::matrix { * @code{.cpp} * using namespace raft; * // get a 2D row-major array of values to search through - * auto in_values = {... input device_matrix_view ...} + * auto in_values = {... input device_matrix_view ...} * // prepare output arrays - * auto out_extents = make_extents(in_values.extent(0), k); + * auto out_extents = make_extents(in_values.extent(0), k); * auto out_values = make_device_mdarray(handle, out_extents); - * auto out_indices = make_device_mdarray(handle, out_extents); + * auto out_indices = make_device_mdarray(handle, out_extents); * // search `k` smallest values in each row - * matrix::select_k( + * matrix::select_k( * handle, in_values, std::nullopt, out_values.view(), out_indices.view(), true); * @endcode * @@ -76,13 +76,13 @@ namespace raft::matrix { */ template void select_k(const device_resources& handle, - raft::device_matrix_view in_val, - std::optional> in_idx, - raft::device_matrix_view out_val, - raft::device_matrix_view out_idx, + raft::device_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, bool select_min) { - RAFT_EXPECTS(out_val.extent(1) <= size_t(std::numeric_limits::max()), + RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits::max()), "output k must fit the int type."); auto batch_size = in_val.extent(0); auto len = in_val.extent(1); @@ -93,7 +93,7 @@ void select_k(const device_resources& handle, RAFT_EXPECTS(batch_size == in_idx->extent(0), "batch sizes must be equal"); RAFT_EXPECTS(len == in_idx->extent(1), "value and index input lengths must be equal"); } - RAFT_EXPECTS(size_t(k) == out_idx.extent(1), "value and index output lengths must be equal"); + RAFT_EXPECTS(int64_t(k) == out_idx.extent(1), "value and index output lengths must be equal"); return detail::select_k(in_val.data_handle(), in_idx.has_value() ? in_idx->data_handle() : nullptr, batch_size, diff --git a/cpp/include/raft_runtime/matrix/select_k.hpp b/cpp/include/raft_runtime/matrix/select_k.hpp new file mode 100644 index 0000000000..08c0e01d0a --- /dev/null +++ b/cpp/include/raft_runtime/matrix/select_k.hpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +namespace raft::runtime::matrix { +void select_k(const device_resources& handle, + raft::device_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min); + +} // namespace raft::runtime::matrix diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index 8dedec67cb..3d7a11e91e 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -86,12 +86,13 @@ void select_k_impl(const device_resources& handle, auto stream = handle.get_stream(); switch (algo) { case Algo::kPublicApi: { - auto in_extent = make_extents(batch_size, len); - auto out_extent = make_extents(batch_size, k); - auto in_span = make_mdspan(in, in_extent); - auto in_idx_span = make_mdspan(in_idx, in_extent); - auto out_span = make_mdspan(out, out_extent); - auto out_idx_span = make_mdspan(out_idx, out_extent); + auto in_extent = make_extents(batch_size, len); + auto out_extent = make_extents(batch_size, k); + auto in_span = make_mdspan(in, in_extent); + auto in_idx_span = + make_mdspan(in_idx, in_extent); + auto out_span = make_mdspan(out, out_extent); + auto out_idx_span = make_mdspan(out_idx, out_extent); if (in_idx == nullptr) { // NB: std::nullopt prevents automatic inference of the template parameters. return matrix::select_k( diff --git a/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu new file mode 100644 index 0000000000..309ac50c6b --- /dev/null +++ b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +#include + +namespace raft::runtime::matrix { + +void select_k(const device_resources& handle, + raft::device_matrix_view in_val, + std::optional> in_idx, + raft::device_matrix_view out_val, + raft::device_matrix_view out_idx, + bool select_min) +{ + raft::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, select_min); +} +} // namespace raft::runtime::matrix diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index 349a2b08ba..069bd98222 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -86,6 +86,7 @@ rapids_cython_init() add_subdirectory(pylibraft/common) add_subdirectory(pylibraft/distance) +add_subdirectory(pylibraft/matrix) add_subdirectory(pylibraft/neighbors) add_subdirectory(pylibraft/random) add_subdirectory(pylibraft/cluster) diff --git a/python/pylibraft/pylibraft/matrix/CMakeLists.txt b/python/pylibraft/pylibraft/matrix/CMakeLists.txt new file mode 100644 index 0000000000..ffba10dea9 --- /dev/null +++ b/python/pylibraft/pylibraft/matrix/CMakeLists.txt @@ -0,0 +1,24 @@ +# ============================================================================= +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +# Set the list of Cython files to build +set(cython_sources select_k.pyx) +set(linked_libraries raft::raft raft::compiled) + +# Build all of the Cython targets +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX matrix_ +) diff --git a/python/pylibraft/pylibraft/matrix/__init__.pxd b/python/pylibraft/pylibraft/matrix/__init__.pxd new file mode 100644 index 0000000000..a7e7b75096 --- /dev/null +++ b/python/pylibraft/pylibraft/matrix/__init__.pxd @@ -0,0 +1,14 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pylibraft/pylibraft/matrix/__init__.py b/python/pylibraft/pylibraft/matrix/__init__.py new file mode 100644 index 0000000000..5eb35795ed --- /dev/null +++ b/python/pylibraft/pylibraft/matrix/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .select_k import select_k + +__all__ = ["select_k"] diff --git a/python/pylibraft/pylibraft/matrix/cpp/__init__.pxd b/python/pylibraft/pylibraft/matrix/cpp/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/matrix/cpp/__init__.py b/python/pylibraft/pylibraft/matrix/cpp/__init__.py new file mode 100644 index 0000000000..8f2cc34855 --- /dev/null +++ b/python/pylibraft/pylibraft/matrix/cpp/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pylibraft/pylibraft/matrix/cpp/select_k.pxd b/python/pylibraft/pylibraft/matrix/cpp/select_k.pxd new file mode 100644 index 0000000000..ab466fdce6 --- /dev/null +++ b/python/pylibraft/pylibraft/matrix/cpp/select_k.pxd @@ -0,0 +1,39 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport int64_t +from libcpp cimport bool + +from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources + + +cdef extern from "raft_runtime/matrix/select_k.hpp" \ + namespace "raft::runtime::matrix" nogil: + + cdef void select_k(const device_resources & handle, + device_matrix_view[float, int64_t, row_major], + optional[device_matrix_view[int64_t, + int64_t, + row_major]], + device_matrix_view[float, int64_t, row_major], + device_matrix_view[int64_t, int64_t, row_major], + bool) except + diff --git a/python/pylibraft/pylibraft/matrix/select_k.pyx b/python/pylibraft/pylibraft/matrix/select_k.pyx new file mode 100644 index 0000000000..fbb1e2e5d3 --- /dev/null +++ b/python/pylibraft/pylibraft/matrix/select_k.pyx @@ -0,0 +1,133 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from cython.operator cimport dereference as deref +from libc.stdint cimport int64_t +from libcpp cimport bool + +import numpy as np + +from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray +from pylibraft.common.handle import auto_sync_handle +from pylibraft.common.input_validation import is_c_contiguous + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.common.mdspan cimport get_dmv_float, get_dmv_int64 +from pylibraft.matrix.cpp.select_k cimport select_k as c_select_k + + +@auto_sync_handle +@auto_convert_output +def select_k(dataset, k=None, distances=None, indices=None, select_min=True, + handle=None): + """ + Selects the top k items from each row in a matrix + + + Parameters + ---------- + dataset : array interface compliant matrix, row-major layout, + shape (n_rows, dim). Supported dtype [float] + k : int + Number of items to return for each row. Optional if indices or + distances arrays are given (in which case their second dimension + is k). + distances : Optional array interface compliant matrix shape + (n_rows, k), dtype float. If supplied, + distances will be written here in-place. (default None) + indices : Optional array interface compliant matrix shape + (n_rows, k), dtype int64_t. If supplied, neighbor + indices will be written here in-place. (default None) + select_min: : bool + Whether to select the minimum or maximum K items + + {handle_docstring} + + Returns + ------- + distances: array interface compliant object containing resulting distances + shape (n_rows, k) + + indices: array interface compliant object containing resulting indices + shape (n_rows, k) + + Examples + -------- + + >>> import cupy as cp + + >>> from pylibraft.matrix import select_k + + >>> n_features = 50 + >>> n_rows = 1000 + + >>> queries = cp.random.random_sample((n_rows, n_features), + ... dtype=cp.float32) + >>> k = 40 + >>> distances, ids = select_k(queries, k) + >>> distances = cp.asarray(distances) + >>> ids = cp.asarray(ids) + """ + + dataset_cai = cai_wrapper(dataset) + + if k is None: + if indices is not None: + k = cai_wrapper(indices).shape[1] + elif distances is not None: + k = cai_wrapper(distances).shape[1] + else: + raise ValueError("Argument k must be specified if both indices " + "and distances arg is None") + + n_rows = dataset.shape[0] + if indices is None: + indices = device_ndarray.empty((n_rows, k), dtype='int64') + + if distances is None: + distances = device_ndarray.empty((n_rows, k), dtype='float32') + + distances_cai = cai_wrapper(distances) + indices_cai = cai_wrapper(indices) + + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef optional[device_matrix_view[int64_t, int64_t, row_major]] in_idx + + if dataset_cai.dtype == np.float32: + c_select_k(deref(handle_), + get_dmv_float(dataset_cai, check_shape=True), + in_idx, + get_dmv_float(distances_cai, check_shape=True), + get_dmv_int64(indices_cai, check_shape=True), + select_min) + else: + raise TypeError("dtype %s not supported" % dataset_cai.dtype) + + return distances, indices diff --git a/python/pylibraft/pylibraft/neighbors/brute_force.pyx b/python/pylibraft/pylibraft/neighbors/brute_force.pyx index dbd888756d..8836307a5a 100644 --- a/python/pylibraft/pylibraft/neighbors/brute_force.pyx +++ b/python/pylibraft/pylibraft/neighbors/brute_force.pyx @@ -40,7 +40,6 @@ from pylibraft.common.handle cimport device_resources from pylibraft.common.mdspan cimport get_dmv_float, get_dmv_int64 from pylibraft.common.handle import auto_sync_handle -from pylibraft.common.input_validation import is_c_contiguous from pylibraft.common.interruptible import cuda_interruptible from pylibraft.distance.distance_type cimport DistanceType @@ -144,7 +143,7 @@ def knn(dataset, queries, k=None, indices=None, distances=None, raise ValueError("Argument k must be specified if both indices " "and distances arg is None") - n_queries = cai_wrapper(queries).shape[0] + n_queries = queries_cai.shape[0] if indices is None: indices = device_ndarray.empty((n_queries, k), dtype='int64') diff --git a/python/pylibraft/pylibraft/test/test_brue_force.py b/python/pylibraft/pylibraft/test/test_brute_force.py similarity index 100% rename from python/pylibraft/pylibraft/test/test_brue_force.py rename to python/pylibraft/pylibraft/test/test_brute_force.py diff --git a/python/pylibraft/pylibraft/test/test_doctests.py b/python/pylibraft/pylibraft/test/test_doctests.py index 34be6c55f5..19e5c5c22f 100644 --- a/python/pylibraft/pylibraft/test/test_doctests.py +++ b/python/pylibraft/pylibraft/test/test_doctests.py @@ -22,6 +22,7 @@ import pylibraft.cluster import pylibraft.distance +import pylibraft.matrix import pylibraft.neighbors import pylibraft.random @@ -94,6 +95,7 @@ def _find_doctests_in_obj(obj, finder=None, criteria=None): DOC_STRINGS = list(_find_doctests_in_obj(pylibraft.cluster)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.common)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.distance)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.matrix.select_k)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.ivf_pq)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.brute_force)) diff --git a/python/pylibraft/pylibraft/test/test_select_k.py b/python/pylibraft/pylibraft/test/test_select_k.py new file mode 100644 index 0000000000..203e735b9c --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_select_k.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest + +from pylibraft.common import device_ndarray +from pylibraft.matrix import select_k + + +@pytest.mark.parametrize("n_rows", [32, 100]) +@pytest.mark.parametrize("n_cols", [40, 100]) +@pytest.mark.parametrize("k", [1, 5, 16, 35]) +@pytest.mark.parametrize("inplace", [True, False]) +def test_select_k(n_rows, n_cols, k, inplace): + dataset = np.random.random_sample((n_rows, n_cols)).astype("float32") + dataset_device = device_ndarray(dataset) + + indices = np.zeros((n_rows, k), dtype="int64") + distances = np.zeros((n_rows, k), dtype="float32") + indices_device = device_ndarray(indices) + distances_device = device_ndarray(distances) + + ret_distances, ret_indices = select_k( + dataset_device, + k=k, + distances=distances_device, + indices=indices_device, + ) + + distances_device = ret_distances if not inplace else distances_device + actual_distances = distances_device.copy_to_host() + argsort = np.argsort(dataset, axis=1) + + for i in range(dataset.shape[0]): + expected_indices = argsort[i] + gpu_dists = actual_distances[i] + + cpu_ordered = dataset[i, expected_indices] + np.testing.assert_allclose( + cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4 + )