From 31847afbaa55ead3ee99d44fcbe0c41ff8e1f726 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 23 Mar 2023 18:48:01 -0400 Subject: [PATCH] Python API for brute-force KNN (#1292) Closes #1289 Authors: - Corey J. Nolet (https://github.com/cjnolet) - Ben Frederickson (https://github.com/benfred) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1292 --- cpp/CMakeLists.txt | 1 + .../raft_runtime/neighbors/brute_force.hpp | 19 +- .../brute_force_knn_int64_t_float.cu | 46 ++--- python/pylibraft/pylibraft/common/mdspan.pyx | 1 - .../pylibraft/neighbors/CMakeLists.txt | 2 +- .../pylibraft/pylibraft/neighbors/__init__.py | 5 +- .../pylibraft/neighbors/brute_force.pyx | 179 ++++++++++++++++++ .../pylibraft/pylibraft/neighbors/common.pyx | 12 +- .../pylibraft/neighbors/cpp/__init__.pxd | 0 .../pylibraft/neighbors/cpp/__init__.py | 14 ++ .../pylibraft/neighbors/cpp/brute_force.pxd | 55 ++++++ .../pylibraft/test/test_brue_force.py | 99 ++++++++++ .../pylibraft/pylibraft/test/test_doctests.py | 3 +- 13 files changed, 395 insertions(+), 41 deletions(-) create mode 100644 python/pylibraft/pylibraft/neighbors/brute_force.pyx create mode 100644 python/pylibraft/pylibraft/neighbors/cpp/__init__.pxd create mode 100644 python/pylibraft/pylibraft/neighbors/cpp/__init__.py create mode 100644 python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd create mode 100644 python/pylibraft/pylibraft/test/test_brue_force.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 034dc059b0..c1704552ec 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -304,6 +304,7 @@ if(RAFT_COMPILE_LIBRARY) # These are somehow missing a kernel definition which is causing a compile error. # src/distance/specializations/detail/kernels/rbf_kernel_double.cu # src/distance/specializations/detail/kernels/rbf_kernel_float.cu + src/neighbors/brute_force_knn_int64_t_float.cu src/distance/specializations/detail/kernels/tanh_kernel_double.cu src/distance/specializations/detail/kernels/tanh_kernel_float.cu src/distance/specializations/detail/kl_divergence_float_float_float_int.cu diff --git a/cpp/include/raft_runtime/neighbors/brute_force.hpp b/cpp/include/raft_runtime/neighbors/brute_force.hpp index 19904f4f78..12da6ff101 100644 --- a/cpp/include/raft_runtime/neighbors/brute_force.hpp +++ b/cpp/include/raft_runtime/neighbors/brute_force.hpp @@ -21,18 +21,17 @@ namespace raft::runtime::neighbors::brute_force { -#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ - void knn(raft::device_resources const& handle, \ - std::vector> index, \ - raft::device_matrix_view search, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - int k, \ - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ - std::optional metric_arg = std::make_optional(2.0f), \ +#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void knn(raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ + std::optional metric_arg = std::make_optional(2.0f), \ std::optional global_id_offset = std::nullopt); -RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major); +RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); #undef RAFT_INST_BFKNN diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu index b0411a59ce..585084fc97 100644 --- a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu @@ -14,8 +14,6 @@ * limitations under the License. */ -#pragma once - #include #include #include @@ -24,30 +22,34 @@ #include +#include + namespace raft::runtime::neighbors::brute_force { -#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ - void knn(raft::device_resources const& handle, \ - std::vector> index, \ - raft::device_matrix_view search, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ - std::optional metric_arg = std::make_optional(2.0f), \ - std::optional global_id_offset = std::nullopt) \ - { \ - raft::neighbors::brute_force::knn(handle, \ - index, \ - search, \ - indices, \ - distances, \ - static_cast(indices.extent(1)), \ - metric, \ - metric_arg, \ - global_id_offset); \ +#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void knn(raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset) \ + { \ + std::vector> vec; \ + vec.push_back(index); \ + raft::neighbors::brute_force::knn(handle, \ + vec, \ + search, \ + indices, \ + distances, \ + static_cast(distances.extent(1)), \ + metric, \ + metric_arg, \ + global_id_offset); \ } -RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major); +RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); #undef RAFT_INST_BFKNN diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index c7b42ecab7..f35a94bb9c 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -159,7 +159,6 @@ cdef device_matrix_view[float, int64_t, row_major] \ return make_device_matrix_view[float, int64_t, row_major]( cai.data, shape[0], shape[1]) - cdef device_matrix_view[uint8_t, int64_t, row_major] \ get_dmv_uint8(cai, check_shape) except *: if cai.dtype != np.uint8: diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index 98f0d7f67a..7b9c1591c1 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources common.pyx refine.pyx) +set(cython_sources common.pyx refine.pyx brute_force.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index f7510ba2db..a50b6f21a7 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from pylibraft.neighbors import brute_force + from .refine import refine -__all__ = ["common", "refine"] +__all__ = ["common", "refine", "brute_force"] diff --git a/python/pylibraft/pylibraft/neighbors/brute_force.pyx b/python/pylibraft/pylibraft/neighbors/brute_force.pyx new file mode 100644 index 0000000000..dbd888756d --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/brute_force.pyx @@ -0,0 +1,179 @@ +# +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libcpp cimport bool, nullptr +from libcpp.vector cimport vector + +from pylibraft.distance.distance_type cimport DistanceType + +from pylibraft.common import ( + DeviceResources, + auto_convert_output, + cai_wrapper, + device_ndarray, +) + +from libc.stdint cimport int64_t, uintptr_t + +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.common.mdspan cimport get_dmv_float, get_dmv_int64 + +from pylibraft.common.handle import auto_sync_handle +from pylibraft.common.input_validation import is_c_contiguous +from pylibraft.common.interruptible import cuda_interruptible + +from pylibraft.distance.distance_type cimport DistanceType + +# TODO: Centralize this + +from pylibraft.distance.pairwise_distance import DISTANCE_TYPES + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.neighbors.cpp.brute_force cimport knn as c_knn + + +def _get_array_params(array_interface, check_dtype=None): + dtype = np.dtype(array_interface["typestr"]) + if check_dtype is None and dtype != check_dtype: + raise TypeError("dtype %s not supported" % dtype) + shape = array_interface["shape"] + if len(shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(shape)) + data = array_interface["data"][0] + return (shape, dtype, data) + + +@auto_sync_handle +@auto_convert_output +def knn(dataset, queries, k=None, indices=None, distances=None, + metric="sqeuclidean", metric_arg=2.0, + global_id_offset=0, handle=None): + """ + Perform a brute-force nearest neighbors search. + + Parameters + ---------- + dataset : array interface compliant matrix, row-major layout, + shape (n_samples, dim). Supported dtype [float] + queries : array interface compliant matrix, row-major layout, + shape (n_queries, dim) Supported dtype [float] + k : int + Number of neighbors to search (k <= 2048). Optional if indices or + distances arrays are given (in which case their second dimension + is k). + indices : Optional array interface compliant matrix shape + (n_queries, k), dtype int64_t. If supplied, neighbor + indices will be written here in-place. (default None) + Supported dtype uint64 + distances : Optional array interface compliant matrix shape + (n_queries, k), dtype float. If supplied, neighbor + indices will be written here in-place. (default None) + + {handle_docstring} + + Returns + ------- + indices: array interface compliant object containing resulting indices + shape (n_queries, k) + + distances: array interface compliant object containing resulting distances + shape (n_queries, k) + + Examples + -------- + + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors.brute_force import knn + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 40 + >>> distances, neighbors = knn(dataset, queries, k) + >>> distances = cp.asarray(distances) + >>> neighbors = cp.asarray(neighbors) + """ + + if handle is None: + handle = DeviceResources() + + dataset_cai = cai_wrapper(dataset) + queries_cai = cai_wrapper(queries) + + if k is None: + if indices is not None: + k = cai_wrapper(indices).shape[1] + elif distances is not None: + k = cai_wrapper(distances).shape[1] + else: + raise ValueError("Argument k must be specified if both indices " + "and distances arg is None") + + n_queries = cai_wrapper(queries).shape[0] + + if indices is None: + indices = device_ndarray.empty((n_queries, k), dtype='int64') + + if distances is None: + distances = device_ndarray.empty((n_queries, k), dtype='float32') + + cdef DistanceType c_metric = DISTANCE_TYPES[metric] + + distances_cai = cai_wrapper(distances) + indices_cai = cai_wrapper(indices) + + cdef optional[float] c_metric_arg = metric_arg + cdef optional[int64_t] c_global_offset = global_id_offset + + cdef device_resources* handle_ = \ + handle.getHandle() + + if dataset_cai.dtype == np.float32: + with cuda_interruptible(): + c_knn(deref(handle_), + get_dmv_float(dataset_cai, check_shape=True), + get_dmv_float(queries_cai, check_shape=True), + get_dmv_int64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True), + c_metric, + c_metric_arg, + c_global_offset) + else: + raise TypeError("dtype %s not supported" % dataset_cai.dtype) + + return (distances, indices) diff --git a/python/pylibraft/pylibraft/neighbors/common.pyx b/python/pylibraft/pylibraft/neighbors/common.pyx index a8380b589b..24c1abcf18 100644 --- a/python/pylibraft/pylibraft/neighbors/common.pyx +++ b/python/pylibraft/pylibraft/neighbors/common.pyx @@ -22,13 +22,15 @@ import warnings from pylibraft.distance.distance_type cimport DistanceType +SUPPORTED_DISTANCES = { + "sqeuclidean": DistanceType.L2Expanded, + "euclidean": DistanceType.L2SqrtExpanded, + "inner_product": DistanceType.InnerProduct, + +} + def _get_metric(metric): - SUPPORTED_DISTANCES = { - "sqeuclidean": DistanceType.L2Expanded, - "euclidean": DistanceType.L2SqrtExpanded, - "inner_product": DistanceType.InnerProduct - } if metric not in SUPPORTED_DISTANCES: if metric == "l2_expanded": warnings.warn("Using l2_expanded as a metric name is deprecated," diff --git a/python/pylibraft/pylibraft/neighbors/cpp/__init__.pxd b/python/pylibraft/pylibraft/neighbors/cpp/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/neighbors/cpp/__init__.py b/python/pylibraft/pylibraft/neighbors/cpp/__init__.py new file mode 100644 index 0000000000..a7e7b75096 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd b/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd new file mode 100644 index 0000000000..de5e0af267 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd @@ -0,0 +1,55 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +import pylibraft.common.handle + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t +from libcpp cimport bool, nullptr +from libcpp.string cimport string +from libcpp.vector cimport vector + +from rmm._lib.memory_resource cimport device_memory_resource + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.distance.distance_type cimport DistanceType + + +cdef extern from "raft_runtime/neighbors/brute_force.hpp" \ + namespace "raft::runtime::neighbors::brute_force" nogil: + + cdef void knn(const device_resources & handle, + device_matrix_view[float, int64_t, row_major] index, + device_matrix_view[float, int64_t, row_major] search, + device_matrix_view[int64_t, int64_t, row_major] indices, + device_matrix_view[float, int64_t, row_major] distances, + DistanceType metric, + optional[float] metric_arg, + optional[int64_t] global_id_offset) except + diff --git a/python/pylibraft/pylibraft/test/test_brue_force.py b/python/pylibraft/pylibraft/test/test_brue_force.py new file mode 100644 index 0000000000..f349be892d --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_brue_force.py @@ -0,0 +1,99 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.spatial.distance import cdist + +from pylibraft.common import DeviceResources, Stream, device_ndarray +from pylibraft.neighbors.brute_force import knn + + +@pytest.mark.parametrize("n_index_rows", [32, 100]) +@pytest.mark.parametrize("n_query_rows", [32, 100]) +@pytest.mark.parametrize("n_cols", [40, 100]) +@pytest.mark.parametrize("k", [1, 5, 32]) +@pytest.mark.parametrize( + "metric", + [ + "euclidean", + "cityblock", + "chebyshev", + "canberra", + "correlation", + "russellrao", + "cosine", + "sqeuclidean", + # "inner_product", + ], +) +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("order", ["F", "C"]) +@pytest.mark.parametrize("dtype", [np.float32]) +def test_knn( + n_index_rows, n_query_rows, n_cols, k, inplace, metric, order, dtype +): + index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype) + queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype) + + # RussellRao expects boolean arrays + if metric == "russellrao": + index[index < 0.5] = 0.0 + index[index >= 0.5] = 1.0 + queries[queries < 0.5] = 0.0 + queries[queries >= 0.5] = 1.0 + + indices = np.zeros((n_query_rows, k), dtype="int64") + distances = np.zeros((n_query_rows, k), dtype=dtype) + + index_device = device_ndarray(index) + + queries_device = device_ndarray(queries) + indices_device = device_ndarray(indices) + distances_device = device_ndarray(distances) + + s2 = Stream() + handle = DeviceResources(stream=s2) + ret_distances, ret_indices = knn( + index_device, + queries_device, + k, + indices=indices_device, + distances=distances_device, + metric=metric, + handle=handle, + ) + handle.sync() + + pw_dists = cdist(queries, index, metric=metric) + + distances_device = ret_distances if not inplace else distances_device + + actual_distances = distances_device.copy_to_host() + + actual_distances[actual_distances <= 1e-5] = 0.0 + argsort = np.argsort(pw_dists, axis=1) + + for i in range(pw_dists.shape[0]): + expected_indices = argsort[i] + gpu_dists = actual_distances[i] + + if metric == "correlation" or metric == "cosine": + gpu_dists = gpu_dists[::-1] + + cpu_ordered = pw_dists[i, expected_indices] + np.testing.assert_allclose( + cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4 + ) diff --git a/python/pylibraft/pylibraft/test/test_doctests.py b/python/pylibraft/pylibraft/test/test_doctests.py index 3276ca115f..34be6c55f5 100644 --- a/python/pylibraft/pylibraft/test/test_doctests.py +++ b/python/pylibraft/pylibraft/test/test_doctests.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -96,6 +96,7 @@ def _find_doctests_in_obj(obj, finder=None, criteria=None): DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.distance)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.ivf_pq)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.brute_force)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.random))