From c87af2d2e88d171a86674eddcd1a8de4609705b3 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Mon, 31 Jul 2023 11:17:48 -0500 Subject: [PATCH] CAGRA Python wrappers (#1665) First verstion of a CAGRA API in pylibraft. Todos: - [x] C++ raft_runtime instantiations and void overloads - [x] Cython API - [x] Solve issue of `cagra_types.hpp` including `#include ` that makes it need nvcc, blocking a clean C++ only cython build - [x] Check in pytests - [x] Add examples to docstrings - [x] Accommodate for parameter rename of #1676 - [x] Accomodate changes of #1664 - [x] Move out of experimental namespace Authors: - Dante Gama Dessavre (https://github.com/dantegd) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) --- cpp/CMakeLists.txt | 3 + cpp/include/raft/neighbors/cagra.cuh | 2 +- cpp/include/raft/neighbors/cagra_types.hpp | 8 +- .../neighbors/detail/cagra/cagra_build.cuh | 2 +- cpp/include/raft_runtime/neighbors/cagra.hpp | 91 ++ cpp/src/raft_runtime/neighbors/cagra_build.cu | 81 ++ .../raft_runtime/neighbors/cagra_search.cu | 39 + .../raft_runtime/neighbors/cagra_serialize.cu | 65 ++ docs/source/pylibraft_api/neighbors.rst | 14 + .../pylibraft/pylibraft/common/ai_wrapper.py | 1 + .../pylibraft/pylibraft/common/cai_wrapper.py | 1 + python/pylibraft/pylibraft/common/mdspan.pxd | 26 +- python/pylibraft/pylibraft/common/mdspan.pyx | 70 ++ .../pylibraft/neighbors/CMakeLists.txt | 1 + .../pylibraft/pylibraft/neighbors/__init__.py | 4 +- .../pylibraft/neighbors/cagra/CMakeLists.txt | 24 + .../pylibraft/neighbors/cagra/__init__.pxd | 0 .../pylibraft/neighbors/cagra/__init__.py | 26 + .../pylibraft/neighbors/cagra/cagra.pyx | 841 ++++++++++++++++++ .../neighbors/cagra/cpp/__init__.pxd | 0 .../pylibraft/neighbors/cagra/cpp/__init__.py | 14 + .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 202 +++++ .../pylibraft/neighbors/ivf_flat/ivf_flat.pyx | 22 +- .../pylibraft/neighbors/ivf_pq/ivf_pq.pxd | 25 + .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 2 - python/pylibraft/pylibraft/test/test_cagra.py | 296 ++++++ .../pylibraft/pylibraft/test/test_doctests.py | 5 +- 27 files changed, 1832 insertions(+), 33 deletions(-) create mode 100644 cpp/include/raft_runtime/neighbors/cagra.hpp create mode 100644 cpp/src/raft_runtime/neighbors/cagra_build.cu create mode 100644 cpp/src/raft_runtime/neighbors/cagra_search.cu create mode 100644 cpp/src/raft_runtime/neighbors/cagra_serialize.cu create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/CMakeLists.txt create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/__init__.pxd create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/__init__.py create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/cpp/__init__.pxd create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/cpp/__init__.py create mode 100644 python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd create mode 100644 python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pxd create mode 100644 python/pylibraft/pylibraft/test/test_cagra.py diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 81c8192fb2..7ee8293c5d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -390,6 +390,9 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/distance/pairwise_distance.cu src/raft_runtime/matrix/select_k_float_int64_t.cu src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu + src/raft_runtime/neighbors/cagra_build.cu + src/raft_runtime/neighbors/cagra_search.cu + src/raft_runtime/neighbors/cagra_serialize.cu src/raft_runtime/neighbors/ivf_flat_build.cu src/raft_runtime/neighbors/ivf_flat_search.cu src/raft_runtime/neighbors/ivf_flat_serialize.cu diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 6f73a35742..6bb7beca55 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -64,7 +64,7 @@ namespace raft::neighbors::cagra { * optimized_graph.view()); * @endcode * - * @tparam T data element type + * @tparam DataT data element type * @tparam IdxT type of the dataset vector indices * * @param[in] res raft resources diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 2583afdaa9..01d6a92235 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -26,7 +26,6 @@ #include #include #include -#include #include #include @@ -113,7 +112,6 @@ static_assert(std::is_aggregate_v); */ template struct index : ann::index { - using AlignDim = raft::Pow2<16 / sizeof(T)>; static_assert(!raft::is_narrowing_v, "IdxT must be able to represent all values of uint32_t"); @@ -252,7 +250,7 @@ struct index : ann::index { void update_dataset(raft::resources const& res, raft::device_matrix_view dataset) { - if (dataset.extent(1) % AlignDim::Value != 0) { + if (dataset.extent(1) * sizeof(T) % 16 != 0) { RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory"); copy_padded(res, dataset); } else { @@ -308,8 +306,8 @@ struct index : ann::index { void copy_padded(raft::resources const& res, mdspan, row_major, data_accessor> dataset) { - dataset_ = - make_device_matrix(res, dataset.extent(0), AlignDim::roundUp(dataset.extent(1))); + size_t padded_dim = round_up_safe(dataset.extent(1) * sizeof(T), 16) / sizeof(T); + dataset_ = make_device_matrix(res, dataset.extent(0), padded_dim); if (dataset_.extent(1) == dataset.extent(1)) { raft::copy(dataset_.data_handle(), dataset.data_handle(), diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 738be62e48..d19d7e7904 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -222,7 +222,7 @@ void build_knn_graph(raft::resources const& res, 1e-6; const auto throughput = num_queries_done / time; - RAFT_LOG_INFO( + RAFT_LOG_DEBUG( "# Search %12lu / %12lu (%3.2f %%), %e queries/sec, %.2f minutes ETA, self included = " "%3.2f %% \r", num_queries_done, diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp new file mode 100644 index 0000000000..6f56302776 --- /dev/null +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace raft::runtime::neighbors::cagra { + +// Using device and host_matrix_view avoids needing to typedef mutltiple mdspans based on accessors +#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::cagra::index; \ + \ + void build_device(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void build_host(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset, \ + raft::neighbors::cagra::index& idx); \ + \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra::search_params const& params, \ + const raft::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index); \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index); \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::cagra::index* index); + +RAFT_INST_CAGRA_FUNCS(float, uint32_t); +RAFT_INST_CAGRA_FUNCS(int8_t, uint32_t); +RAFT_INST_CAGRA_FUNCS(uint8_t, uint32_t); + +#undef RAFT_INST_CAGRA_FUNCS + +#define RAFT_INST_CAGRA_OPTIMIZE(IdxT) \ + void optimize_device(raft::resources const& res, \ + raft::device_matrix_view knn_graph, \ + raft::host_matrix_view new_graph); \ + \ + void optimize_host(raft::resources const& res, \ + raft::host_matrix_view knn_graph, \ + raft::host_matrix_view new_graph); + +RAFT_INST_CAGRA_OPTIMIZE(uint32_t); + +#undef RAFT_INST_CAGRA_OPTIMIZE + +} // namespace raft::runtime::neighbors::cagra diff --git a/cpp/src/raft_runtime/neighbors/cagra_build.cu b/cpp/src/raft_runtime/neighbors/cagra_build.cu new file mode 100644 index 0000000000..225d645e4e --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/cagra_build.cu @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +namespace raft::runtime::neighbors::cagra { + +#define RAFT_INST_CAGRA_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::cagra::index \ + { \ + return raft::neighbors::cagra::build(handle, params, dataset); \ + } \ + \ + auto build(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset) \ + ->raft::neighbors::cagra::index \ + { \ + return raft::neighbors::cagra::build(handle, params, dataset); \ + } \ + \ + void build_device(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::cagra::index& idx) \ + { \ + idx = build(handle, params, dataset); \ + } \ + \ + void build_host(raft::resources const& handle, \ + const raft::neighbors::cagra::index_params& params, \ + raft::host_matrix_view dataset, \ + raft::neighbors::cagra::index& idx) \ + { \ + idx = build(handle, params, dataset); \ + } + +RAFT_INST_CAGRA_BUILD(float, uint32_t); +RAFT_INST_CAGRA_BUILD(int8_t, uint32_t); +RAFT_INST_CAGRA_BUILD(uint8_t, uint32_t); + +#undef RAFT_INST_CAGRA_BUILD + +#define RAFT_INST_CAGRA_OPTIMIZE(IdxT) \ + void optimize_device(raft::resources const& handle, \ + raft::device_matrix_view knn_graph, \ + raft::host_matrix_view new_graph) \ + { \ + raft::neighbors::cagra::optimize(handle, knn_graph, new_graph); \ + } \ + void optimize_host(raft::resources const& handle, \ + raft::host_matrix_view knn_graph, \ + raft::host_matrix_view new_graph) \ + { \ + raft::neighbors::cagra::optimize(handle, knn_graph, new_graph); \ + } + +RAFT_INST_CAGRA_OPTIMIZE(uint32_t); + +#undef RAFT_INST_CAGRA_OPTIMIZE + +} // namespace raft::runtime::neighbors::cagra diff --git a/cpp/src/raft_runtime/neighbors/cagra_search.cu b/cpp/src/raft_runtime/neighbors/cagra_search.cu new file mode 100644 index 0000000000..149ae01392 --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/cagra_search.cu @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors::cagra { + +#define RAFT_INST_CAGRA_SEARCH(T, IdxT) \ + void search(raft::resources const& handle, \ + raft::neighbors::cagra::search_params const& params, \ + const raft::neighbors::cagra::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::cagra::search(handle, params, index, queries, neighbors, distances); \ + } + +RAFT_INST_CAGRA_SEARCH(float, uint32_t); +RAFT_INST_CAGRA_SEARCH(int8_t, uint32_t); +RAFT_INST_CAGRA_SEARCH(uint8_t, uint32_t); + +#undef RAFT_INST_CAGRA_SEARCH + +} // namespace raft::runtime::neighbors::cagra diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu new file mode 100644 index 0000000000..be9788562a --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include +#include + +namespace raft::runtime::neighbors::cagra { + +#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::cagra::index& index) \ + { \ + raft::neighbors::cagra::serialize(handle, filename, index); \ + }; \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::cagra::index* index) \ + { \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, filename); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::cagra::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::cagra::serialize(handle, os, index); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::cagra::index* index) \ + { \ + std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::cagra::deserialize(handle, is); \ + } + +RAFT_INST_CAGRA_SERIALIZE(float); +RAFT_INST_CAGRA_SERIALIZE(int8_t); +RAFT_INST_CAGRA_SERIALIZE(uint8_t); + +#undef RAFT_INST_CAGRA_SERIALIZE +} // namespace raft::runtime::neighbors::cagra diff --git a/docs/source/pylibraft_api/neighbors.rst b/docs/source/pylibraft_api/neighbors.rst index c314f1c84d..ca89c25ed4 100644 --- a/docs/source/pylibraft_api/neighbors.rst +++ b/docs/source/pylibraft_api/neighbors.rst @@ -14,6 +14,20 @@ Brute Force .. autofunction:: pylibraft.neighbors.brute_force.knn +CAGRA +##### + +.. autoclass:: pylibraft.neighbors.cagra.IndexParams + :members: + +.. autofunction:: pylibraft.neighbors.cagra.build + +.. autoclass:: pylibraft.neighbors.cagra.SearchParams + :members: + +.. autofunction:: pylibraft.neighbors.cagra.search + + IVF-Flat ######## diff --git a/python/pylibraft/pylibraft/common/ai_wrapper.py b/python/pylibraft/pylibraft/common/ai_wrapper.py index b6b1f02187..b2b5935ede 100644 --- a/python/pylibraft/pylibraft/common/ai_wrapper.py +++ b/python/pylibraft/pylibraft/common/ai_wrapper.py @@ -34,6 +34,7 @@ def __init__(self, ai_arr): ai_arr : array interface array """ self.ai_ = ai_arr.__array_interface__ + self.from_cai = False @property def dtype(self): diff --git a/python/pylibraft/pylibraft/common/cai_wrapper.py b/python/pylibraft/pylibraft/common/cai_wrapper.py index cf11ea29ce..8a77a9b1b6 100644 --- a/python/pylibraft/pylibraft/common/cai_wrapper.py +++ b/python/pylibraft/pylibraft/common/cai_wrapper.py @@ -37,6 +37,7 @@ def __init__(self, cai_arr): __array_interface__=cai_arr.__cuda_array_interface__ ) super().__init__(helper) + self.from_cai = True def wrap_array(array): diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 3be8d5e1a6..6b202c2b69 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -19,10 +19,14 @@ # cython: embedsignature = True # cython: language_level = 3 -from libc.stdint cimport int8_t, int64_t, uint8_t +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t from libcpp.string cimport string -from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + host_matrix_view, + row_major, +) from pylibraft.common.handle cimport device_resources from pylibraft.common.optional cimport make_optional, optional @@ -41,3 +45,21 @@ cdef device_matrix_view[int64_t, int64_t, row_major] get_dmv_int64( cdef optional[device_matrix_view[int64_t, int64_t, row_major]] make_optional_view_int64( # noqa: E501 device_matrix_view[int64_t, int64_t, row_major]& dmv) except * + +cdef device_matrix_view[uint32_t, int64_t, row_major] get_dmv_uint32( + array, check_shape) except * + +cdef host_matrix_view[float, int64_t, row_major] get_hmv_float( + array, check_shape) except * + +cdef host_matrix_view[uint8_t, int64_t, row_major] get_hmv_uint8( + array, check_shape) except * + +cdef host_matrix_view[int8_t, int64_t, row_major] get_hmv_int8( + array, check_shape) except * + +cdef host_matrix_view[int64_t, int64_t, row_major] get_hmv_int64( + array, check_shape) except * + +cdef host_matrix_view[uint32_t, int64_t, row_major] get_hmv_uint32( + array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index f35a94bb9c..1219b1612d 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -30,6 +30,7 @@ from libc.stdint cimport int8_t, int32_t, int64_t, uint8_t, uint32_t, uintptr_t from pylibraft.common.cpp.mdspan cimport ( col_major, device_matrix_view, + host_matrix_view, host_mdspan, make_device_matrix_view, make_host_matrix_view, @@ -195,3 +196,72 @@ cdef device_matrix_view[int64_t, int64_t, row_major] \ cdef optional[device_matrix_view[int64_t, int64_t, row_major]] \ make_optional_view_int64(device_matrix_view[int64_t, int64_t, row_major]& dmv) except *: # noqa: E501 return make_optional[device_matrix_view[int64_t, int64_t, row_major]](dmv) + + +# todo(dantegd): we can unify and simplify this functions a little bit +# defining extra functions as-is is the quickest way to get what we need for +# cagra.pyx +cdef device_matrix_view[uint32_t, int64_t, row_major] \ + get_dmv_uint32(cai, check_shape) except *: + if cai.dtype != np.uint32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[uint32_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[float, int64_t, row_major] \ + get_hmv_float(cai, check_shape) except *: + if cai.dtype != np.float32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[float, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[uint8_t, int64_t, row_major] \ + get_hmv_uint8(cai, check_shape) except *: + if cai.dtype != np.uint8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[uint8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[int8_t, int64_t, row_major] \ + get_hmv_int8(cai, check_shape) except *: + if cai.dtype != np.int8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[int8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[int64_t, int64_t, row_major] \ + get_hmv_int64(cai, check_shape) except *: + if cai.dtype != np.int64: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[int64_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[uint32_t, int64_t, row_major] \ + get_hmv_uint32(cai, check_shape) except *: + if cai.dtype != np.int64: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[uint32_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index 7b9c1591c1..45cd9f74e6 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -23,5 +23,6 @@ rapids_cython_create_modules( LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX neighbors_ ) +add_subdirectory(cagra) add_subdirectory(ivf_flat) add_subdirectory(ivf_pq) diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index a50b6f21a7..325ea5842e 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -13,8 +13,8 @@ # limitations under the License. # -from pylibraft.neighbors import brute_force +from pylibraft.neighbors import brute_force, cagra, ivf_flat, ivf_pq from .refine import refine -__all__ = ["common", "refine", "brute_force"] +__all__ = ["common", "refine", "brute_force", "ivf_flat", "ivf_pq", "cagra"] diff --git a/python/pylibraft/pylibraft/neighbors/cagra/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/cagra/CMakeLists.txt new file mode 100644 index 0000000000..441bb0b311 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra/CMakeLists.txt @@ -0,0 +1,24 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +# Set the list of Cython files to build +set(cython_sources cagra.pyx) +set(linked_libraries raft::raft raft::compiled) + +# Build all of the Cython targets +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS raft MODULE_PREFIX neighbors_cagra_ +) diff --git a/python/pylibraft/pylibraft/neighbors/cagra/__init__.pxd b/python/pylibraft/pylibraft/neighbors/cagra/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/neighbors/cagra/__init__.py b/python/pylibraft/pylibraft/neighbors/cagra/__init__.py new file mode 100644 index 0000000000..b2a872fc89 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .cagra import Index, IndexParams, SearchParams, build, load, save, search + +__all__ = [ + "Index", + "IndexParams", + "SearchParams", + "build", + "load", + "save", + "search", +] diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx new file mode 100644 index 0000000000..7d758a32ef --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -0,0 +1,841 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import warnings + +import numpy as np + +from cython.operator cimport dereference as deref +from libc.stdint cimport ( + int8_t, + int32_t, + int64_t, + uint8_t, + uint32_t, + uint64_t, + uintptr_t, +) +from libcpp cimport bool, nullptr +from libcpp.string cimport string + +from pylibraft.distance.distance_type cimport DistanceType + +from pylibraft.common import ( + DeviceResources, + ai_wrapper, + auto_convert_output, + cai_wrapper, + device_ndarray, +) +from pylibraft.common.cai_wrapper import wrap_array +from pylibraft.common.interruptible import cuda_interruptible + +from pylibraft.common.handle cimport device_resources + +from pylibraft.common.handle import auto_sync_handle +from pylibraft.common.input_validation import is_c_contiguous + +from rmm._lib.memory_resource cimport ( + DeviceMemoryResource, + device_memory_resource, +) + +cimport pylibraft.neighbors.cagra.cpp.c_cagra as c_cagra +from pylibraft.common.optional cimport make_optional, optional + +from pylibraft.neighbors.common import _check_input_array, _get_metric + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + device_vector_view, + make_device_vector_view, + row_major, +) +from pylibraft.common.mdspan cimport ( + get_dmv_float, + get_dmv_int8, + get_dmv_int64, + get_dmv_uint8, + get_dmv_uint32, + get_hmv_float, + get_hmv_int8, + get_hmv_int64, + get_hmv_uint8, + get_hmv_uint32, + make_optional_view_int64, +) +from pylibraft.neighbors.common cimport _get_metric_string + + +cdef class IndexParams: + cdef c_cagra.index_params params + + def __init__(self, *, + metric="sqeuclidean", + intermediate_graph_degree=128, + graph_degree=64, + add_data_on_build=True): + """" + Parameters to build index for CAGRA nearest neighbor search + + Parameters + ---------- + metric : string denoting the metric type, default="sqeuclidean" + Valid values for metric: ["sqeuclidean", "inner_product", + "euclidean"], where + - sqeuclidean is the euclidean distance without the square root + operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, + - euclidean is the euclidean distance + - inner product distance is defined as + distance(a, b) = \\sum_i a_i * b_i. + intermediate_graph_degree : int, default = 128 + + graph_degree : int, default = 64 + + add_data_on_build : bool, default = True + After training the coarse and fine quantizers, we will populate + the index with the dataset if add_data_on_build == True, otherwise + the index is left empty, and the extend method can be used + to add new vectors to the index. + """ + self.params.metric = _get_metric(metric) + self.params.metric_arg = 0 + self.params.intermediate_graph_degree = intermediate_graph_degree + self.params.graph_degree = graph_degree + self.params.add_data_on_build = add_data_on_build + + @property + def metric(self): + return self.params.metric + + @property + def intermediate_graph_degree(self): + return self.params.intermediate_graph_degree + + @property + def graph_degree(self): + return self.params.graph_degree + + @property + def add_data_on_build(self): + return self.params.add_data_on_build + + +cdef class Index: + cdef readonly bool trained + cdef str active_index_type + + def __cinit__(self): + self.trained = False + self.active_index_type = None + + +cdef class IndexFloat(Index): + cdef c_cagra.index[float, uint32_t] * index + + def __cinit__(self, handle=None): + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + self.index = new c_cagra.index[float, uint32_t]( + deref(handle_)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["metric", "dim", "graph_degree"]] + attr_str = m_str + attr_str + return "Index(type=CAGRA, " + (", ".join(attr_str)) + ")" + + @property + def metric(self): + return self.index[0].metric() + + @property + def size(self): + return self.index[0].size() + + @property + def dim(self): + return self.index[0].dim() + + @property + def graph_degree(self): + return self.index[0].graph_degree() + + def __dealloc__(self): + if self.index is not NULL: + del self.index + + +cdef class IndexInt8(Index): + cdef c_cagra.index[int8_t, uint32_t] * index + + def __cinit__(self, handle=None): + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + self.index = new c_cagra.index[int8_t, uint32_t]( + deref(handle_)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["metric", "dim", "graph_degree"]] + attr_str = m_str + attr_str + return "Index(type=CAGRA, " + (", ".join(attr_str)) + ")" + + @property + def metric(self): + return self.index[0].metric() + + @property + def size(self): + return self.index[0].size() + + @property + def dim(self): + return self.index[0].dim() + + @property + def graph_degree(self): + return self.index[0].graph_degree() + + def __dealloc__(self): + if self.index is not NULL: + del self.index + + +cdef class IndexUint8(Index): + cdef c_cagra.index[uint8_t, uint32_t] * index + + def __cinit__(self, handle=None): + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + self.index = new c_cagra.index[uint8_t, uint32_t]( + deref(handle_)) + + def __repr__(self): + m_str = "metric=" + _get_metric_string(self.index.metric()) + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["metric", "dim", "graph_degree"]] + attr_str = m_str + attr_str + return "Index(type=CAGRA, " + (", ".join(attr_str)) + ")" + + @property + def metric(self): + return self.index[0].metric() + + @property + def size(self): + return self.index[0].size() + + @property + def dim(self): + return self.index[0].dim() + + @property + def graph_degree(self): + return self.index[0].graph_degree() + + def __dealloc__(self): + if self.index is not NULL: + del self.index + + +@auto_sync_handle +@auto_convert_output +def build(IndexParams index_params, dataset, handle=None): + """ + Build the CAGRA index from the dataset for efficient search. + + The build performs two different steps- first an intermediate knn-graph is + constructed, then it's optimized it to create the final graph. The + index_params object controls the node degree of these graphs. + + It is required that both the dataset and the optimized graph fit the + GPU memory. + + The following distance metrics are supported: + - L2 + + Parameters + ---------- + index_params : IndexParams object + dataset : CUDA array interface compliant matrix shape (n_samples, dim) + Supported dtype [float, int8, uint8] + {handle_docstring} + + Returns + ------- + index: cagra.Index + + Examples + -------- + + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> k = 10 + + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> handle = DeviceResources() + >>> build_params = cagra.IndexParams(metric="sqeuclidean") + + >>> index = cagra.build(build_params, dataset, handle=handle) + + >>> distances, neighbors = cagra.search(cagra.SearchParams(), + ... index, dataset, + ... k, handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + >>> distances = cp.asarray(distances) + >>> neighbors = cp.asarray(neighbors) + """ + dataset_ai = wrap_array(dataset) + dataset_dt = dataset_ai.dtype + _check_input_array(dataset_ai, [np.dtype('float32'), np.dtype('byte'), + np.dtype('ubyte')]) + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + if dataset_ai.from_cai: + if dataset_dt == np.float32: + idx_float = IndexFloat(handle) + idx_float.active_index_type = "float32" + with cuda_interruptible(): + c_cagra.build_device( + deref(handle_), + index_params.params, + get_dmv_float(dataset_ai, check_shape=True), + deref(idx_float.index)) + idx_float.trained = True + return idx_float + elif dataset_dt == np.byte: + idx_int8 = IndexInt8(handle) + idx_int8.active_index_type = "byte" + with cuda_interruptible(): + c_cagra.build_device( + deref(handle_), + index_params.params, + get_dmv_int8(dataset_ai, check_shape=True), + deref(idx_int8.index)) + idx_int8.trained = True + return idx_int8 + elif dataset_dt == np.ubyte: + idx_uint8 = IndexUint8(handle) + idx_uint8.active_index_type = "ubyte" + with cuda_interruptible(): + c_cagra.build_device( + deref(handle_), + index_params.params, + get_dmv_uint8(dataset_ai, check_shape=True), + deref(idx_uint8.index)) + idx_uint8.trained = True + return idx_uint8 + else: + raise TypeError("dtype %s not supported" % dataset_dt) + else: + if dataset_dt == np.float32: + idx_float = IndexFloat(handle) + idx_float.active_index_type = "float32" + with cuda_interruptible(): + c_cagra.build_host( + deref(handle_), + index_params.params, + get_hmv_float(dataset_ai, check_shape=True), + deref(idx_float.index)) + idx_float.trained = True + return idx_float + elif dataset_dt == np.byte: + idx_int8 = IndexInt8(handle) + idx_int8.active_index_type = "byte" + with cuda_interruptible(): + c_cagra.build_host( + deref(handle_), + index_params.params, + get_hmv_int8(dataset_ai, check_shape=True), + deref(idx_int8.index)) + idx_int8.trained = True + return idx_int8 + elif dataset_dt == np.ubyte: + idx_uint8 = IndexUint8(handle) + idx_uint8.active_index_type = "ubyte" + with cuda_interruptible(): + c_cagra.build_host( + deref(handle_), + index_params.params, + get_hmv_uint8(dataset_ai, check_shape=True), + deref(idx_uint8.index)) + idx_uint8.trained = True + return idx_uint8 + else: + raise TypeError("dtype %s not supported" % dataset_dt) + + +cdef class SearchParams: + cdef c_cagra.search_params params + + def __init__(self, *, + max_queries=0, + itopk_size=64, + max_iterations=0, + algo="auto", + team_size=0, + search_width=1, + min_iterations=0, + thread_block_size=0, + hashmap_mode="auto", + hashmap_min_bitlen=0, + hashmap_max_fill_rate=0.5, + num_random_samplings=1, + rand_xor_mask=0x128394): + """ + CAGRA search parameters + + Parameters + ---------- + max_queries: int, default = 0 + Maximum number of queries to search at the same time (batch size). + Auto select when 0. + itopk_size: int, default = 64 + Number of intermediate search results retained during the search. + This is the main knob to adjust trade off between accuracy and + search speed. Higher values improve the search accuracy. + max_iterations: int, default = 0 + Upper limit of search iterations. Auto select when 0. + algo: string denoting the search algorithm to use, default = "auto" + Valid values for algo: ["auto", "single_cta", "multi_cta"], where + - auto will automatically select the best value based on query size + - single_cta is better when query contains larger number of + vectors (e.g >10) + - multi_cta is better when query contains only a few vectors + team_size: int, default = 0 + Number of threads used to calculate a single distance. 4, 8, 16, + or 32. + search_width: int, default = 1 + Number of graph nodes to select as the starting point for the + search in each iteration. + min_iterations: int, default = 0 + Lower limit of search iterations. + thread_block_size: int, default = 0 + Thread block size. 0, 64, 128, 256, 512, 1024. + Auto selection when 0. + hashmap_mode: string denoting the type of hash map to use. It's + usually better to allow the algorithm to select this value., + default = "auto" + Valid values for hashmap_mode: ["auto", "small", "hash"], where + - auto will automatically select the best value based on algo + - small will use the small shared memory hash table with resetting. + - hash will use a single hash table in global memory. + hashmap_min_bitlen: int, default = 0 + Upper limit of hashmap fill rate. More than 0.1, less than 0.9. + hashmap_max_fill_rate: float, default = 0.5 + Upper limit of hashmap fill rate. More than 0.1, less than 0.9. + num_random_samplings: int, default = 1 + Number of iterations of initial random seed node selection. 1 or + more. + rand_xor_mask: int, default = 0x128394 + Bit mask used for initial random seed node selection. + + + """ + self.params.max_queries = max_queries + self.params.itopk_size = itopk_size + self.params.max_iterations = max_iterations + if algo == "single_cta": + self.params.algo = c_cagra.search_algo.SINGLE_CTA + elif algo == "multi_cta": + self.params.algo = c_cagra.search_algo.MULTI_CTA + elif algo == "multi_kernel": + self.params.algo = c_cagra.search_algo.MULTI_KERNEL + elif algo == "auto": + self.params.algo = c_cagra.search_algo.AUTO + else: + raise ValueError("`algo` value not supported.") + + self.params.team_size = team_size + self.params.search_width = search_width + self.params.min_iterations = min_iterations + self.params.thread_block_size = thread_block_size + if hashmap_mode == "hash": + self.params.hashmap_mode = c_cagra.hash_mode.HASH + elif hashmap_mode == "small": + self.params.hashmap_mode = c_cagra.hash_mode.SMALL + elif hashmap_mode == "auto": + self.params.hashmap_mode = c_cagra.hash_mode.AUTO + else: + raise ValueError("`hashmap_mode` value not supported.") + + self.params.hashmap_min_bitlen = hashmap_min_bitlen + self.params.hashmap_max_fill_rate = hashmap_max_fill_rate + self.params.num_random_samplings = num_random_samplings + self.params.rand_xor_mask = rand_xor_mask + + def __repr__(self): + # todo(dantegd): add all relevant attrs + attr_str = [attr + "=" + str(getattr(self, attr)) + for attr in ["max_queries"]] + return "SearchParams(type=CAGRA, " + (", ".join(attr_str)) + ")" + + @property + def max_queries(self): + return self.params.max_queries + + @property + def itopk_size(self): + return self.params.itopk_size + + @property + def max_iterations(self): + return self.params.max_iterations + + @property + def algo(self): + return self.params.algo + + @property + def team_size(self): + return self.params.team_size + + @property + def search_width(self): + return self.params.search_width + + @property + def min_iterations(self): + return self.params.min_iterations + + @property + def thread_block_size(self): + return self.params.thread_block_size + + @property + def hashmap_mode(self): + return self.params.hashmap_mode + + @property + def hashmap_min_bitlen(self): + return self.params.hashmap_min_bitlen + + @property + def hashmap_max_fill_rate(self): + return self.params.hashmap_max_fill_rate + + @property + def num_random_samplings(self): + return self.params.num_random_samplings + + @property + def rand_xor_mask(self): + return self.params.rand_xor_mask + + +@auto_sync_handle +@auto_convert_output +def search(SearchParams search_params, + Index index, + queries, + k, + neighbors=None, + distances=None, + handle=None): + """ + Find the k nearest neighbors for each query. + + Parameters + ---------- + search_params : SearchParams + index : Index + Trained CAGRA index. + queries : CUDA array interface compliant matrix shape (n_samples, dim) + Supported dtype [float, int8, uint8] + k : int + The number of neighbors. + neighbors : Optional CUDA array interface compliant matrix shape + (n_queries, k), dtype int64_t. If supplied, neighbor + indices will be written here in-place. (default None) + distances : Optional CUDA array interface compliant matrix shape + (n_queries, k) If supplied, the distances to the + neighbors will be written here in-place. (default None) + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 10 + >>> search_params = cagra.SearchParams( + ... max_queries=100, + ... itopk_size=64 + ... ) + + >>> # Using a pooling allocator reduces overhead of temporary array + >>> # creation during search. This is useful if multiple searches + >>> # are performad with same query size. + >>> distances, neighbors = cagra.search(search_params, index, queries, + ... k, handle=handle) + + >>> # pylibraft functions are often asynchronous so the + >>> # handle needs to be explicitly synchronized + >>> handle.sync() + + >>> neighbors = cp.asarray(neighbors) + >>> distances = cp.asarray(distances) + """ + + if not index.trained: + raise ValueError("Index need to be built before calling search.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + queries_cai = cai_wrapper(queries) + queries_dt = queries_cai.dtype + cdef uint32_t n_queries = queries_cai.shape[0] + + _check_input_array(queries_cai, [np.dtype('float32'), np.dtype('byte'), + np.dtype('ubyte')], + exp_cols=index.dim) + + if neighbors is None: + neighbors = device_ndarray.empty((n_queries, k), dtype='uint32') + + neighbors_cai = cai_wrapper(neighbors) + _check_input_array(neighbors_cai, [np.dtype('uint32')], + exp_rows=n_queries, exp_cols=k) + + if distances is None: + distances = device_ndarray.empty((n_queries, k), dtype='float32') + + distances_cai = cai_wrapper(distances) + _check_input_array(distances_cai, [np.dtype('float32')], + exp_rows=n_queries, exp_cols=k) + + cdef c_cagra.search_params params = search_params.params + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + if queries_dt == np.float32: + idx_float = index + with cuda_interruptible(): + c_cagra.search(deref(handle_), + params, + deref(idx_float.index), + get_dmv_float(queries_cai, check_shape=True), + get_dmv_uint32(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) + elif queries_dt == np.byte: + idx_int8 = index + with cuda_interruptible(): + c_cagra.search(deref(handle_), + params, + deref(idx_int8.index), + get_dmv_int8(queries_cai, check_shape=True), + get_dmv_uint32(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) + elif queries_dt == np.ubyte: + idx_uint8 = index + with cuda_interruptible(): + c_cagra.search(deref(handle_), + params, + deref(idx_uint8.index), + get_dmv_uint8(queries_cai, check_shape=True), + get_dmv_uint32(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) + else: + raise ValueError("query dtype %s not supported" % queries_dt) + + return (distances, neighbors) + + +@auto_sync_handle +def save(filename, Index index, handle=None): + """ + Saves the index to file. + + Saving / loading the index is. The serialization format is + subject to change. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained CAGRA index. + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> # Build index + >>> handle = DeviceResources() + >>> index = cagra.build(cagra.IndexParams(), dataset, handle=handle) + >>> cagra.save("my_index.bin", index, handle=handle) + """ + if not index.trained: + raise ValueError("Index need to be built before saving it.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + if index.active_index_type == "float32": + idx_float = index + c_cagra.serialize_file( + deref(handle_), c_filename, deref(idx_float.index)) + elif index.active_index_type == "byte": + idx_int8 = index + c_cagra.serialize_file( + deref(handle_), c_filename, deref(idx_int8.index)) + elif index.active_index_type == "ubyte": + idx_uint8 = index + c_cagra.serialize_file( + deref(handle_), c_filename, deref(idx_uint8.index)) + else: + raise ValueError( + "Index dtype %s not supported" % index.active_index_type) + + +@auto_sync_handle +def load(filename, handle=None): + """ + Loads index from file. + + Saving / loading the index is. The serialization format is + subject to change, therefore loading an index saved with a previous + version of raft is not guaranteed to work. + + Parameters + ---------- + filename : string + Name of the file. + {handle_docstring} + + Returns + ------- + index : Index + + Examples + -------- + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import cagra + + """ + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + # we extract the dtype from the arrai interfaces in the file + with open(filename, 'rb') as f: + type_str = f.read(700).decode("utf-8", errors='ignore') + + dataset_dt = np.dtype(type_str[673:676]) + + if dataset_dt == np.float32: + idx_float = IndexFloat(handle) + c_cagra.deserialize_file( + deref(handle_), c_filename, idx_float.index) + idx_float.trained = True + idx_float.active_index_type = 'float32' + return idx_float + elif dataset_dt == np.byte: + idx_int8 = IndexInt8(handle) + c_cagra.deserialize_file( + deref(handle_), c_filename, idx_int8.index) + idx_int8.trained = True + idx_int8.active_index_type = 'byte' + return idx_int8 + elif dataset_dt == np.ubyte: + idx_uint8 = IndexUint8(handle) + c_cagra.deserialize_file( + deref(handle_), c_filename, idx_uint8.index) + idx_uint8.trained = True + idx_uint8.active_index_type = 'ubyte' + return idx_uint8 + else: + raise ValueError("Dataset dtype %s not supported" % dataset_dt) diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/__init__.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/__init__.py b/python/pylibraft/pylibraft/neighbors/cagra/cpp/__init__.py new file mode 100644 index 0000000000..8f2cc34855 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd new file mode 100644 index 0000000000..284c75b771 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -0,0 +1,202 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +import pylibraft.common.handle + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uint64_t +from libcpp cimport bool, nullptr +from libcpp.string cimport string + +from rmm._lib.memory_resource cimport device_memory_resource + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + device_vector_view, + host_matrix_view, + row_major, +) +from pylibraft.common.handle cimport device_resources +from pylibraft.common.optional cimport optional +from pylibraft.distance.distance_type cimport DistanceType +from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( + ann_index, + ann_index_params, + ann_search_params, + index_params as ivfpq_ip, + search_params as ivfpq_sp, +) + + +cdef extern from "raft/neighbors/cagra_types.hpp" \ + namespace "raft::neighbors::cagra" nogil: + + cpdef cppclass index_params(ann_index_params): + size_t intermediate_graph_degree + size_t graph_degree + + ctypedef enum search_algo: + SINGLE_CTA "raft::neighbors::cagra::search_algo::SINGLE_CTA", + MULTI_CTA "raft::neighbors::cagra::search_algo::MULTI_CTA", + MULTI_KERNEL "raft::neighbors::cagra::search_algo::MULTI_KERNEL", + AUTO "raft::neighbors::cagra::search_algo::AUTO" + + ctypedef enum hash_mode: + HASH "raft::neighbors::cagra::hash_mode::HASH", + SMALL "raft::neighbors::cagra::hash_mode::SMALL", + AUTO "raft::neighbors::cagra::hash_mode::AUTO" + + cpdef cppclass search_params(ann_search_params): + size_t max_queries + size_t itopk_size + size_t max_iterations + search_algo algo + size_t team_size + size_t search_width + size_t min_iterations + size_t thread_block_size + hash_mode hashmap_mode + size_t hashmap_min_bitlen + float hashmap_max_fill_rate + uint32_t num_random_samplings + uint64_t rand_xor_mask + + cdef cppclass index[T, IdxT](ann_index): + index(const device_resources&) + + DistanceType metric() + IdxT size() + uint32_t dim() + uint32_t graph_degree() + device_matrix_view[T, IdxT, row_major] dataset() + device_matrix_view[T, IdxT, row_major] graph() + +cdef extern from "raft_runtime/neighbors/cagra.hpp" \ + namespace "raft::runtime::neighbors::cagra" nogil: + + cdef void build_device( + const device_resources& handle, + const index_params& params, + device_matrix_view[float, int64_t, row_major] dataset, + index[float, uint32_t]& index) except + + + cdef void build_device( + const device_resources& handle, + const index_params& params, + device_matrix_view[int8_t, int64_t, row_major] dataset, + index[int8_t, uint32_t]& index) except + + + cdef void build_device( + const device_resources& handle, + const index_params& params, + device_matrix_view[uint8_t, int64_t, row_major] dataset, + index[uint8_t, uint32_t]& index) except + + + cdef void build_host( + const device_resources& handle, + const index_params& params, + host_matrix_view[float, int64_t, row_major] dataset, + index[float, uint32_t]& index) except + + + cdef void build_host( + const device_resources& handle, + const index_params& params, + host_matrix_view[int8_t, int64_t, row_major] dataset, + index[int8_t, uint32_t]& index) except + + + cdef void build_host( + const device_resources& handle, + const index_params& params, + host_matrix_view[uint8_t, int64_t, row_major] dataset, + index[uint8_t, uint32_t]& index) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[float, uint32_t]& index, + device_matrix_view[float, int64_t, row_major] queries, + device_matrix_view[uint32_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[int8_t, uint32_t]& index, + device_matrix_view[int8_t, int64_t, row_major] queries, + device_matrix_view[uint32_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[uint8_t, uint32_t]& index, + device_matrix_view[uint8_t, int64_t, row_major] queries, + device_matrix_view[uint32_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + + + cdef void serialize(const device_resources& handle, + string& str, + const index[float, uint32_t]& index) except + + + cdef void deserialize(const device_resources& handle, + const string& str, + index[float, uint32_t]* index) except + + + cdef void serialize(const device_resources& handle, + string& str, + const index[uint8_t, uint32_t]& index) except + + + cdef void deserialize(const device_resources& handle, + const string& str, + index[uint8_t, uint32_t]* index) except + + + cdef void serialize(const device_resources& handle, + string& str, + const index[int8_t, uint32_t]& index) except + + + cdef void deserialize(const device_resources& handle, + const string& str, + index[int8_t, uint32_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[float, uint32_t]& index) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[float, uint32_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[uint8_t, uint32_t]& index) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[uint8_t, uint32_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[int8_t, uint32_t]& index) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[int8_t, uint32_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx index 0e550547d3..e265bee23b 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx @@ -614,26 +614,10 @@ def search(SearchParams search_params, ... dtype=cp.float32) >>> k = 10 >>> search_params = ivf_flat.SearchParams( - ... n_probes=20, - ... lut_dtype=cp.float16, - ... internal_distance_dtype=cp.float32 - ... ) - - # TODO update example to set default pool allocator - # (instead of passing an mr) - - >>> # Using a pooling allocator reduces overhead of temporary array - >>> # creation during search. This is useful if multiple searches - >>> # are performad with same query size. - >>> import rmm - >>> mr = rmm.mr.PoolMemoryResource( - ... rmm.mr.CudaMemoryResource(), - ... initial_pool_size=2**29, - ... maximum_pool_size=2**31 + ... n_probes=20 ... ) >>> distances, neighbors = ivf_flat.search(search_params, index, queries, - ... k, memory_resource=mr, - ... handle=handle) + ... k, handle=handle) >>> # pylibraft functions are often asynchronous so the >>> # handle needs to be explicitly synchronized @@ -817,7 +801,7 @@ def load(filename, handle=None): >>> handle = DeviceResources() >>> index = ivf_flat.load("my_index.bin", handle=handle) - >>> distances, neighbors = ivf_flat.search(ivf_pq.SearchParams(), index, + >>> distances, neighbors = ivf_flat.search(ivf_flat.SearchParams(), index, ... queries, k=10, handle=handle) """ if handle is None: diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pxd new file mode 100644 index 0000000000..1b99da1fd7 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pxd @@ -0,0 +1,25 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# distutils: language = c++ + +cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq + + +cdef class IndexParams: + cdef c_ivf_pq.index_params params + +cdef class SearchParams: + cdef c_ivf_pq.search_params params diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index b89e5dd44d..413a9a1d4b 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -95,7 +95,6 @@ cdef _get_dtype_string(dtype): cdef class IndexParams: - cdef c_ivf_pq.index_params params def __init__(self, *, n_lists=1024, @@ -521,7 +520,6 @@ def extend(Index index, new_vectors, new_indices, handle=None): cdef class SearchParams: - cdef c_ivf_pq.search_params params def __init__(self, *, n_probes=20, lut_dtype=np.float32, diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py new file mode 100644 index 0000000000..435b2878a2 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -0,0 +1,296 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# h ttp://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import normalize + +from pylibraft.common import device_ndarray +from pylibraft.neighbors import cagra + + +# todo (dantegd): consolidate helper utils of ann methods +def generate_data(shape, dtype): + if dtype == np.byte: + x = np.random.randint(-127, 128, size=shape, dtype=np.byte) + elif dtype == np.ubyte: + x = np.random.randint(0, 255, size=shape, dtype=np.ubyte) + else: + x = np.random.random_sample(shape).astype(dtype) + + return x + + +def calc_recall(ann_idx, true_nn_idx): + assert ann_idx.shape == true_nn_idx.shape + n = 0 + for i in range(ann_idx.shape[0]): + n += np.intersect1d(ann_idx[i, :], true_nn_idx[i, :]).size + recall = n / ann_idx.size + return recall + + +def run_cagra_build_search_test( + n_rows=10000, + n_cols=10, + n_queries=100, + k=10, + dtype=np.float32, + metric="euclidean", + intermediate_graph_degree=128, + graph_degree=64, + array_type="device", + compare=True, + inplace=True, + add_data_on_build=True, + search_params={}, +): + dataset = generate_data((n_rows, n_cols), dtype) + if metric == "inner_product": + dataset = normalize(dataset, norm="l2", axis=1) + dataset_device = device_ndarray(dataset) + + build_params = cagra.IndexParams( + metric=metric, + intermediate_graph_degree=intermediate_graph_degree, + graph_degree=graph_degree, + ) + + if array_type == "device": + index = cagra.build(build_params, dataset_device) + else: + index = cagra.build(build_params, dataset) + + assert index.trained + + if not add_data_on_build: + dataset_1 = dataset[: n_rows // 2, :] + dataset_2 = dataset[n_rows // 2 :, :] + indices_1 = np.arange(n_rows // 2, dtype=np.uint32) + indices_2 = np.arange(n_rows // 2, n_rows, dtype=np.uint32) + if array_type == "device": + dataset_1_device = device_ndarray(dataset_1) + dataset_2_device = device_ndarray(dataset_2) + indices_1_device = device_ndarray(indices_1) + indices_2_device = device_ndarray(indices_2) + index = cagra.extend(index, dataset_1_device, indices_1_device) + index = cagra.extend(index, dataset_2_device, indices_2_device) + else: + index = cagra.extend(index, dataset_1, indices_1) + index = cagra.extend(index, dataset_2, indices_2) + + queries = generate_data((n_queries, n_cols), dtype) + out_idx = np.zeros((n_queries, k), dtype=np.uint32) + out_dist = np.zeros((n_queries, k), dtype=np.float32) + + queries_device = device_ndarray(queries) + out_idx_device = device_ndarray(out_idx) if inplace else None + out_dist_device = device_ndarray(out_dist) if inplace else None + + search_params = cagra.SearchParams(**search_params) + + ret_output = cagra.search( + search_params, + index, + queries_device, + k, + neighbors=out_idx_device, + distances=out_dist_device, + ) + + if not inplace: + out_dist_device, out_idx_device = ret_output + + if not compare: + return + + out_idx = out_idx_device.copy_to_host() + out_dist = out_dist_device.copy_to_host() + + # Calculate reference values with sklearn + skl_metric = { + "sqeuclidean": "sqeuclidean", + "inner_product": "cosine", + "euclidean": "euclidean", + }[metric] + nn_skl = NearestNeighbors( + n_neighbors=k, algorithm="brute", metric=skl_metric + ) + nn_skl.fit(dataset) + skl_idx = nn_skl.kneighbors(queries, return_distance=False) + + recall = calc_recall(out_idx, skl_idx) + assert recall > 0.7 + + +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("array_type", ["device", "host"]) +def test_cagra_dataset_dtype_host_device(dtype, array_type, inplace): + # Note that inner_product tests use normalized input which we cannot + # represent in int8, therefore we test only sqeuclidean metric here. + run_cagra_build_search_test( + dtype=dtype, + inplace=inplace, + array_type=array_type, + ) + + +@pytest.mark.parametrize( + "params", + [ + { + "intermediate_graph_degree": 64, + "graph_degree": 32, + "add_data_on_build": True, + "k": 1, + "metric": "euclidean", + }, + { + "intermediate_graph_degree": 32, + "graph_degree": 16, + "add_data_on_build": False, + "k": 5, + "metric": "sqeuclidean", + }, + { + "intermediate_graph_degree": 128, + "graph_degree": 32, + "add_data_on_build": True, + "k": 10, + "metric": "inner_product", + }, + ], +) +def test_cagra_index_params(params): + # Note that inner_product tests use normalized input which we cannot + # represent in int8, therefore we test only sqeuclidean metric here. + run_cagra_build_search_test( + k=params["k"], + metric=params["metric"], + graph_degree=params["graph_degree"], + intermediate_graph_degree=params["intermediate_graph_degree"], + compare=False, + ) + + +@pytest.mark.parametrize( + "params", + [ + { + "max_queries": 100, + "itopk_size": 32, + "max_iterations": 100, + "algo": "single_cta", + "team_size": 0, + "search_width": 1, + "min_iterations": 1, + "thread_block_size": 64, + "hashmap_mode": "hash", + "hashmap_min_bitlen": 0.2, + "hashmap_max_fill_rate": 0.5, + "num_random_samplings": 1, + }, + { + "max_queries": 10, + "itopk_size": 128, + "max_iterations": 0, + "algo": "multi_cta", + "team_size": 8, + "search_width": 2, + "min_iterations": 10, + "thread_block_size": 0, + "hashmap_mode": "auto", + "hashmap_min_bitlen": 0.9, + "hashmap_max_fill_rate": 0.5, + "num_random_samplings": 10, + }, + { + "max_queries": 0, + "itopk_size": 64, + "max_iterations": 0, + "algo": "multi_kernel", + "team_size": 16, + "search_width": 1, + "min_iterations": 0, + "thread_block_size": 0, + "hashmap_mode": "auto", + "hashmap_min_bitlen": 0, + "hashmap_max_fill_rate": 0.5, + "num_random_samplings": 1, + }, + { + "max_queries": 0, + "itopk_size": 64, + "max_iterations": 0, + "algo": "auto", + "team_size": 32, + "search_width": 4, + "min_iterations": 0, + "thread_block_size": 0, + "hashmap_mode": "small", + "hashmap_min_bitlen": 0, + "hashmap_max_fill_rate": 0.5, + "num_random_samplings": 1, + }, + ], +) +def test_cagra_search_params(params): + # Note that inner_product tests use normalized input which we cannot + # represent in int8, therefore we test only sqeuclidean metric here. + run_cagra_build_search_test(search_params=params) + + +@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.ubyte]) +def test_save_load(dtype): + n_rows = 10000 + n_cols = 50 + n_queries = 1000 + + dataset = generate_data((n_rows, n_cols), dtype) + dataset_device = device_ndarray(dataset) + + build_params = cagra.IndexParams() + index = cagra.build(build_params, dataset_device) + + assert index.trained + filename = "my_index.bin" + cagra.save(filename, index) + loaded_index = cagra.load(filename) + + queries = generate_data((n_queries, n_cols), dtype) + + queries_device = device_ndarray(queries) + search_params = cagra.SearchParams() + k = 10 + + distance_dev, neighbors_dev = cagra.search( + search_params, index, queries_device, k + ) + + neighbors = neighbors_dev.copy_to_host() + dist = distance_dev.copy_to_host() + del index + + distance_dev, neighbors_dev = cagra.search( + search_params, loaded_index, queries_device, k + ) + + neighbors2 = neighbors_dev.copy_to_host() + dist2 = distance_dev.copy_to_host() + + assert np.all(neighbors == neighbors2) + assert np.allclose(dist, dist2, rtol=1e-6) diff --git a/python/pylibraft/pylibraft/test/test_doctests.py b/python/pylibraft/pylibraft/test/test_doctests.py index 19e5c5c22f..c75f565236 100644 --- a/python/pylibraft/pylibraft/test/test_doctests.py +++ b/python/pylibraft/pylibraft/test/test_doctests.py @@ -97,8 +97,11 @@ def _find_doctests_in_obj(obj, finder=None, criteria=None): DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.distance)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.matrix.select_k)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors)) -DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.ivf_pq)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.brute_force)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.cagra)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.ivf_flat)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.ivf_pq)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.refine)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.random))