From 9f2a64f4be479b3343a42a89c98cfca3827ae7c0 Mon Sep 17 00:00:00 2001 From: Victor Lafargue Date: Thu, 16 Mar 2023 10:37:11 +0100 Subject: [PATCH] Update and standardize IVF indexes API (#1328) Update and standardize IVF indexes API + edits on specializations Authors: - Victor Lafargue (https://github.com/viclafargue) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Artem M. Chirkin (https://github.com/achirkin) URL: https://github.com/rapidsai/raft/pull/1328 --- cpp/bench/neighbors/knn.cuh | 4 +- cpp/include/raft/neighbors/ivf_pq.cuh | 175 ++++++++++-------- .../raft/neighbors/ivf_pq_serialize.cuh | 2 - .../raft/neighbors/specializations/ivf_pq.cuh | 81 ++++---- .../raft/spatial/knn/detail/ann_quantized.cuh | 2 +- cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 79 ++++---- cpp/src/distance/neighbors/ivfpq_build.cu | 59 +++--- .../neighbors/ivfpq_search_float_int64_t.cu | 20 +- .../neighbors/ivfpq_search_int8_t_int64_t.cu | 20 +- .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 20 +- .../ivfpq_extend_float_int64_t.cu | 23 +-- .../ivfpq_extend_int8_t_int64_t.cu | 23 +-- .../ivfpq_extend_uint8_t_int64_t.cu | 23 +-- .../ivfpq_search_float_int64_t.cu | 3 +- .../ivfpq_search_int8_t_int64_t.cu | 3 +- .../ivfpq_search_uint8_t_int64_t.cu | 3 +- cpp/test/neighbors/ann_ivf_flat.cuh | 2 +- cpp/test/neighbors/ann_ivf_pq.cuh | 25 +-- python/pylibraft/pylibraft/common/mdspan.pxd | 4 + python/pylibraft/pylibraft/common/mdspan.pyx | 6 + .../pylibraft/pylibraft/common/optional.pxd | 56 ++++++ .../neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 16 +- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 18 +- 23 files changed, 378 insertions(+), 289 deletions(-) create mode 100644 python/pylibraft/pylibraft/common/optional.pxd diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index 259d39d8f7..fe8c2c10d8 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -178,7 +178,6 @@ struct ivf_pq_knn { { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; - auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); } @@ -189,13 +188,12 @@ struct ivf_pq_knn { IdxT* out_idxs) { search_params.n_probes = 20; - auto queries_view = raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); auto dists_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); raft::neighbors::ivf_pq::search( - handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view); + handle, search_params, *index, queries_view, idxs_view, dists_view); } }; diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index db60af847a..4a12ca72a4 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -45,16 +45,16 @@ namespace raft::neighbors::ivf_pq { * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param params configure the index building + * @param[in] handle + * @param[in] params configure the index building * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] * * @return the constructed ivf-pq index */ template -inline auto build(raft::device_resources const& handle, +index build(raft::device_resources const& handle, const index_params& params, - raft::device_matrix_view dataset) -> index + raft::device_matrix_view dataset) { IdxT n_rows = dataset.extent(0); IdxT dim = dataset.extent(1); @@ -62,37 +62,38 @@ inline auto build(raft::device_resources const& handle, } /** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are unchanged. - * + * @brief Extend the index with the new data. + * * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param orig_index original index - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. - * - * @return the constructed extended ivf-pq index + * @param[inout] idx */ template -inline auto extend(raft::device_resources const& handle, - const index& orig_index, +index extend(raft::device_resources const& handle, raft::device_matrix_view new_vectors, - raft::device_matrix_view new_indices) -> index + std::optional> new_indices, + const index& idx) { - IdxT n_rows = new_vectors.extent(0); - ASSERT(n_rows == new_indices.extent(0), - "new_vectors and new_indices have different number of rows"); - ASSERT(new_vectors.extent(1) == orig_index.dim(), + ASSERT(new_vectors.extent(1) == idx.dim(), "new_vectors should have the same dimension as the index"); - return detail::extend( - handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + return detail::extend(handle, + idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } /** @@ -101,20 +102,33 @@ inline auto extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param[inout] index - * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] + * @param[in] handle + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()] * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx */ template -inline void extend(raft::device_resources const& handle, - index* index, - raft::device_matrix_view new_vectors, - raft::device_matrix_view new_indices) +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) { - *index = extend(handle, *index, new_vectors, new_indices); + ASSERT(new_vectors.extent(1) == idx->dim(), + "new_vectors should have the same dimension as the index"); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + *idx = detail::extend(handle, + *idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } /** @@ -133,34 +147,39 @@ inline void extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * - * @param handle - * @param params configure the search - * @param index ivf-pq constructed index + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param k the number of neighbors to find for each query. * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] */ template -inline void search(raft::device_resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - uint32_t k, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) { - IdxT n_queries = queries.extent(0); - bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0)); - ASSERT(check_n_rows, - "queries, neighbors and distances parameters have inconsistent number of rows"); + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + std::uint32_t k = neighbors.extent(1); return detail::search(handle, params, - index, + idx, queries.data_handle(), - n_queries, + static_cast(queries.extent(0)), k, neighbors.data_handle(), distances.data_handle(), @@ -193,11 +212,11 @@ inline void search(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param params configure the index building + * @param[in] handle + * @param[in] params configure the index building * @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim] - * @param n_rows the number of samples - * @param dim the dimensionality of the data + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data * * @return the constructed ivf-pq index */ @@ -233,24 +252,24 @@ auto build(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param orig_index original index - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] handle + * @param[inout] idx original index + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples + * @param[in] n_rows the number of samples * * @return the constructed extended ivf-pq index */ template auto extend(raft::device_resources const& handle, - const index& orig_index, + const index& idx, const T* new_vectors, const IdxT* new_indices, IdxT n_rows) -> index { - return detail::extend(handle, orig_index, new_vectors, new_indices, n_rows); + return detail::extend(handle, idx, new_vectors, new_indices, n_rows); } /** @@ -259,22 +278,22 @@ auto extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param[inout] index - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] handle + * @param[inout] idx + * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()] * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`idx.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples + * @param[in] n_rows the number of samples */ template void extend(raft::device_resources const& handle, - index* index, + index* idx, const T* new_vectors, const IdxT* new_indices, IdxT n_rows) { - detail::extend(handle, index, new_vectors, new_indices, n_rows); + detail::extend(handle, idx, new_vectors, new_indices, n_rows); } /** @@ -307,22 +326,22 @@ void extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * - * @param handle - * @param params configure the search - * @param index ivf-pq constructed index + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-pq constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param n_queries the batch size - * @param k the number of neighbors to find for each query. + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param mr an optional memory resource to use across the searches (you can provide a large enough - * memory pool here to avoid memory allocations within search). + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). */ template void search(raft::device_resources const& handle, const search_params& params, - const index& index, + const index& idx, const T* queries, uint32_t n_queries, uint32_t k, @@ -330,7 +349,7 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr); + return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); } } // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/neighbors/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh index 98b59fd5e1..2dd9d39d73 100644 --- a/cpp/include/raft/neighbors/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh @@ -47,7 +47,6 @@ namespace raft::neighbors::ivf_pq { * @param[in] os output stream * @param[in] index IVF-PQ index * - * @return raft::neighbors::ivf_pq::index */ template void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) @@ -77,7 +76,6 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind * @param[in] filename the file name for saving the index * @param[in] index IVF-PQ index * - * @return raft::neighbors::ivf_pq::index */ template void serialize(raft::device_resources const& handle, diff --git a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh index a651be90db..55a7cd5858 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh @@ -24,37 +24,54 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_INST(T, IdxT) \ - extern template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->index; \ - extern template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->index; \ - extern template void extend(raft::device_resources const& handle, \ - index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); \ - extern template void search(raft::device_resources const&, \ - const search_params&, \ - const index&, \ - const T*, \ - uint32_t, \ - uint32_t, \ - IdxT*, \ - float*, \ - rmm::mr::device_memory_resource*); -RAFT_INST(float, int64_t); -RAFT_INST(int8_t, int64_t); -RAFT_INST(uint8_t, int64_t); - -#undef RAFT_INST +#ifdef RAFT_DECL_BUILD_EXTEND +#undef RAFT_DECL_BUILD_EXTEND +#endif + +#ifdef RAFT_DECL_SEARCH +#undef RAFT_DECL_SEARCH +#endif + +// We define overloads for build and extend with void return type. This is used in the Cython +// wrappers, where exception handling is not compatible with return type that has nontrivial +// constructor. +#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ + extern template auto build(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::index_params&, \ + raft::device_matrix_view) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template auto extend( \ + raft::device_resources const&, \ + raft::device_matrix_view, \ + std::optional>, \ + const raft::neighbors::ivf_pq::index&) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template void extend( \ + raft::device_resources const&, \ + raft::device_matrix_view, \ + std::optional>, \ + raft::neighbors::ivf_pq::index*); + +RAFT_DECL_BUILD_EXTEND(float, int64_t) +RAFT_DECL_BUILD_EXTEND(int8_t, int64_t) +RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t) + +#undef RAFT_DECL_BUILD_EXTEND + +#define RAFT_DECL_SEARCH(T, IdxT) \ + extern template void search(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::search_params&, \ + const raft::neighbors::ivf_pq::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view); + +RAFT_DECL_SEARCH(float, int64_t); +RAFT_DECL_SEARCH(int8_t, int64_t); +RAFT_DECL_SEARCH(uint8_t, int64_t); + +#undef RAFT_DECL_SEARCH } // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 8238c99065..cc95b32cee 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -120,7 +120,7 @@ void approx_knn_search(raft::device_resources const& handle, auto indices_view = raft::make_device_matrix_view(indices, n, k); auto distances_view = raft::make_device_matrix_view(distances, n, k); neighbors::ivf_pq::search( - handle, params, *index->ivf_pq, query_view, k, indices_view, distances_view); + handle, params, *index->ivf_pq, query_view, indices_view, distances_view); } else { RAFT_FAIL("The model is not trained"); } diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index 00a97931fb..fb22d7657e 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -20,51 +20,50 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_SEARCH(T, IdxT) \ - void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - raft::device_matrix_view, \ - uint32_t, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_INST_SEARCH(float, int64_t); -RAFT_INST_SEARCH(int8_t, int64_t); -RAFT_INST_SEARCH(uint8_t, int64_t); - -#undef RAFT_INST_SEARCH - // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::ivf_pq::index; \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->raft::neighbors::ivf_pq::index; \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::ivf_pq::index* idx); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices); +#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ + [[nodiscard]] raft::neighbors::ivf_pq::index build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + [[nodiscard]] raft::neighbors::ivf_pq::index extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); + +RAFT_DECL_BUILD_EXTEND(float, int64_t); +RAFT_DECL_BUILD_EXTEND(int8_t, int64_t); +RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t); + +#undef RAFT_DECL_BUILD_EXTEND + +#define RAFT_DECL_SEARCH(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); -RAFT_INST_BUILD_EXTEND(float, int64_t) -RAFT_INST_BUILD_EXTEND(int8_t, int64_t) -RAFT_INST_BUILD_EXTEND(uint8_t, int64_t) +RAFT_DECL_SEARCH(float, int64_t); +RAFT_DECL_SEARCH(int8_t, int64_t); +RAFT_DECL_SEARCH(uint8_t, int64_t); -#undef RAFT_INST_BUILD_EXTEND +#undef RAFT_DECL_SEARCH /** * Save the index to file. diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index 4a4d16cbac..8759ca2587 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -20,36 +20,35 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::ivf_pq::index \ - { \ - return raft::neighbors::ivf_pq::build(handle, params, dataset); \ - } \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->raft::neighbors::ivf_pq::index \ - { \ - return raft::neighbors::ivf_pq::extend(handle, orig_index, new_vectors, new_indices); \ - } \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::ivf_pq::index* idx) \ - { \ - *idx = raft::neighbors::ivf_pq::build(handle, params, dataset); \ - } \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - { \ - raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices); \ +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + raft::neighbors::ivf_pq::index build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset) \ + { \ + return raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_pq::index* idx) \ + { \ + *idx = raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + raft::neighbors::ivf_pq::index extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx) \ + { \ + return raft::neighbors::ivf_pq::extend(handle, new_vectors, new_indices, idx); \ + } \ + void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx) \ + { \ + raft::neighbors::ivf_pq::extend(handle, new_vectors, new_indices, idx); \ } RAFT_INST_BUILD_EXTEND(float, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu index 47dbc48e44..91093d3a39 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu @@ -21,17 +21,15 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - uint32_t k, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, k, neighbors, distances); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(float, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu index 45218b215c..e1552c0b27 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -21,17 +21,15 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - uint32_t k, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, k, neighbors, distances); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu index b7f028002f..85195a7551 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -21,17 +21,15 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - uint32_t k, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, k, neighbors, distances); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(uint8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu index ccd773832d..4cc616f32d 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu @@ -19,17 +19,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const index& idx) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu index ea182da5f0..a3117aae0f 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu @@ -19,17 +19,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const index& idx) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu index 7c5ef6af8f..a5e3d68569 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu @@ -19,17 +19,18 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view new_vectors, \ - raft::device_matrix_view new_indices); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const index& idx) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(uint8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu index fd8031f269..92a4d89e6b 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu @@ -22,9 +22,8 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ const search_params& params, \ - const index& index, \ + const index& idx, \ raft::device_matrix_view queries, \ - uint32_t k, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu index 43c5953c25..62a8b48ad5 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu @@ -22,9 +22,8 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ const search_params& params, \ - const index& index, \ + const index& idx, \ raft::device_matrix_view queries, \ - uint32_t k, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu index 2e365cc164..3bcf134a22 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu @@ -22,9 +22,8 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ const search_params& params, \ - const index& index, \ + const index& idx, \ raft::device_matrix_view queries, \ - uint32_t k, \ raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index eb0f84d104..002e4f07d2 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -357,4 +357,4 @@ const std::vector> inputs = { raft::distance::DistanceType::InnerProduct, false}}; -} // namespace raft::neighbors::ivf_flat +} // namespace raft::neighbors::ivf_flat \ No newline at end of file diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 37b3e91434..c368192b03 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -178,7 +178,7 @@ class ivf_pq_test : public ::testing::TestWithParam { handle_.sync_stream(stream_); } - auto build_only() + index build_only() { auto ipams = ps.index_params; ipams.add_data_on_build = true; @@ -188,7 +188,7 @@ class ivf_pq_test : public ::testing::TestWithParam { return ivf_pq::build(handle_, ipams, index_view); } - auto build_2_extends() + index build_2_extends() { rmm::device_uvector db_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), @@ -207,18 +207,21 @@ class ivf_pq_test : public ::testing::TestWithParam { auto database_view = raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_pq::build(handle_, ipams, database_view); + auto idx = ivf_pq::build(handle_, ipams, database_view); auto vecs_2_view = raft::make_device_matrix_view(vecs_2, size_2, ps.dim); auto inds_2_view = raft::make_device_matrix_view(inds_2, size_2, 1); - ivf_pq::extend(handle_, &index, vecs_2_view, inds_2_view); - - auto vecs_1_view = raft::make_device_matrix_view(vecs_1, size_1, ps.dim); - auto inds_1_view = raft::make_device_matrix_view(inds_1, size_1, 1); - return ivf_pq::extend(handle_, index, vecs_1_view, inds_1_view); + ivf_pq::extend(handle_, vecs_2_view, inds_2_view, &idx); + + auto vecs_1_view = + raft::make_device_matrix_view(vecs_1, size_1, ps.dim); + auto inds_1_view = + raft::make_device_matrix_view(inds_1, size_1, 1); + ivf_pq::extend(handle_, vecs_1_view, inds_1_view, &idx); + return idx; } - auto build_serialize() + index build_serialize() { ivf_pq::serialize(handle_, "ivf_pq_index", build_only()); return ivf_pq::deserialize(handle_, "ivf_pq_index"); @@ -227,7 +230,7 @@ class ivf_pq_test : public ::testing::TestWithParam { template void run(BuildIndex build_index) { - auto index = build_index(); + index index = build_index(); size_t queries_size = ps.num_queries * ps.k; std::vector indices_ivf_pq(queries_size); @@ -244,7 +247,7 @@ class ivf_pq_test : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_ivf_pq_dev.data(), ps.num_queries, ps.k); ivf_pq::search( - handle_, ps.search_params, index, query_view, ps.k, inds_view, dists_view); + handle_, ps.search_params, index, query_view, inds_view, dists_view); update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 98521e48fa..970b59dda3 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -24,6 +24,7 @@ from libcpp.string cimport string from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major from pylibraft.common.handle cimport device_resources +from pylibraft.common.optional cimport make_optional, optional cdef device_matrix_view[float, int64_t, row_major] get_dmv_float( @@ -37,3 +38,6 @@ cdef device_matrix_view[int8_t, int64_t, row_major] get_dmv_int8( cdef device_matrix_view[int64_t, int64_t, row_major] get_dmv_int64( array, check_shape) except * + +cdef optional[device_matrix_view[int64_t, int64_t, row_major]] create_optional( + device_matrix_view[int64_t, int64_t, row_major]& dmv) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index 9f04545a0f..74c6722cab 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -40,6 +40,7 @@ from pylibraft.common.cpp.mdspan cimport ( serialize_mdspan, ) from pylibraft.common.handle cimport device_resources +from pylibraft.common.optional cimport make_optional, optional from pylibraft.common import DeviceResources @@ -190,3 +191,8 @@ cdef device_matrix_view[int64_t, int64_t, row_major] \ shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) return make_device_matrix_view[int64_t, int64_t, row_major]( cai.data, shape[0], shape[1]) + + +cdef optional[device_matrix_view[int64_t, int64_t, row_major]] \ + create_optional(device_matrix_view[int64_t, int64_t, row_major]& dmv) except *: # noqa: E501 + return make_optional[device_matrix_view[int64_t, int64_t, row_major]](dmv) diff --git a/python/pylibraft/pylibraft/common/optional.pxd b/python/pylibraft/pylibraft/common/optional.pxd new file mode 100644 index 0000000000..f29f435693 --- /dev/null +++ b/python/pylibraft/pylibraft/common/optional.pxd @@ -0,0 +1,56 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +# Code from Cython libcpp + +from libcpp cimport bool + + +cdef extern from "" namespace "std" nogil: + cdef cppclass nullopt_t: + nullopt_t() + + cdef nullopt_t nullopt + + cdef cppclass optional[T]: + ctypedef T value_type + optional() + optional(nullopt_t) + optional(optional&) except + + optional(T&) except + + bool has_value() + T& value() + T& value_or[U](U& default_value) + void swap(optional&) + void reset() + T& emplace(...) + T& operator*() + optional& operator=(optional&) + optional& operator=[U](U&) + bool operator bool() + bool operator!() + bool operator==[U](optional&, U&) + bool operator!=[U](optional&, U&) + bool operator<[U](optional&, U&) + bool operator>[U](optional&, U&) + bool operator<=[U](optional&, U&) + bool operator>=[U](optional&, U&) + + optional[T] make_optional[T](...) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index dcc0371421..d04d833f3b 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -31,6 +31,7 @@ from rmm._lib.memory_resource cimport device_memory_resource from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major from pylibraft.common.handle cimport device_resources +from pylibraft.common.optional cimport optional from pylibraft.distance.distance_type cimport DistanceType @@ -124,28 +125,27 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[float, int64_t, row_major] new_vectors, - device_matrix_view[int64_t, int64_t, row_major] new_indices) except + + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[int8_t, int64_t, row_major] new_vectors, - device_matrix_view[int64_t, int64_t, row_major] new_indices) except + + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[uint8_t, int64_t, row_major] new_vectors, - device_matrix_view[int64_t, int64_t, row_major] new_indices) except + + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void search( const device_resources& handle, const search_params& params, const index[int64_t]& index, device_matrix_view[float, int64_t, row_major] queries, - uint32_t k, device_matrix_view[int64_t, int64_t, row_major] neighbors, device_matrix_view[float, int64_t, row_major] distances) except + @@ -154,7 +154,6 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ const search_params& params, const index[int64_t]& index, device_matrix_view[int8_t, int64_t, row_major] queries, - uint32_t k, device_matrix_view[int64_t, int64_t, row_major] neighbors, device_matrix_view[float, int64_t, row_major] distances) except + @@ -163,7 +162,6 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ const search_params& params, const index[int64_t]& index, device_matrix_view[uint8_t, int64_t, row_major] queries, - uint32_t k, device_matrix_view[int64_t, int64_t, row_major] neighbors, device_matrix_view[float, int64_t, row_major] distances) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 860a1ea27c..4f4d2c75a4 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -50,8 +50,9 @@ from rmm._lib.memory_resource cimport ( ) cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq -from pylibraft.common.cpp.mdspan cimport device_matrix_view +from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major from pylibraft.common.mdspan cimport ( + create_optional, get_dmv_float, get_dmv_int8, get_dmv_int64, @@ -519,21 +520,21 @@ def extend(Index index, new_vectors, new_indices, handle=None): if vecs_dt == np.float32: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_float(vecs_cai, check_shape=True), - get_dmv_int64(idx_cai, check_shape=False)) + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) elif vecs_dt == np.int8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_int8(vecs_cai, check_shape=True), - get_dmv_int64(idx_cai, check_shape=False)) + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) elif vecs_dt == np.uint8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_uint8(vecs_cai, check_shape=True), - get_dmv_int64(idx_cai, check_shape=False)) + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) else: raise TypeError("query dtype %s not supported" % vecs_dt) @@ -723,7 +724,6 @@ def search(SearchParams search_params, params, deref(index.index), get_dmv_float(queries_cai, check_shape=True), - k, get_dmv_int64(neighbors_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.byte: @@ -732,7 +732,6 @@ def search(SearchParams search_params, params, deref(index.index), get_dmv_int8(queries_cai, check_shape=True), - k, get_dmv_int64(neighbors_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.ubyte: @@ -741,7 +740,6 @@ def search(SearchParams search_params, params, deref(index.index), get_dmv_uint8(queries_cai, check_shape=True), - k, get_dmv_int64(neighbors_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True)) else: