From 6e34d8ad297aee4359552d685d0d2ae5d10e942b Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 14 Mar 2023 12:04:30 +0100 Subject: [PATCH] addressing reviews --- cpp/bench/neighbors/knn.cuh | 10 +- cpp/include/raft/neighbors/ivf_flat.cuh | 410 +++++++++--------- cpp/include/raft/neighbors/ivf_pq.cuh | 124 ++---- .../raft/neighbors/specializations/ivf_pq.cuh | 55 ++- .../raft/spatial/knn/detail/ann_quantized.cuh | 9 +- cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 52 +-- cpp/src/distance/neighbors/ivfpq_build.cu | 56 +-- .../neighbors/ivfpq_search_float_int64_t.cu | 6 +- .../neighbors/ivfpq_search_int8_t_int64_t.cu | 6 +- .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 6 +- .../ivfpq_build_float_int64_t.cu | 10 +- .../ivfpq_build_int8_t_int64_t.cu | 10 +- .../ivfpq_build_uint8_t_int64_t.cu | 10 +- .../ivfpq_extend_float_int64_t.cu | 8 +- .../ivfpq_extend_int8_t_int64_t.cu | 8 +- .../ivfpq_extend_uint8_t_int64_t.cu | 8 +- .../ivfpq_search_float_int64_t.cu | 6 +- .../ivfpq_search_int8_t_int64_t.cu | 6 +- .../ivfpq_search_uint8_t_int64_t.cu | 6 +- cpp/test/neighbors/ann_ivf_flat.cuh | 38 +- cpp/test/neighbors/ann_ivf_pq.cuh | 27 +- .../neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 36 +- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 36 +- 23 files changed, 444 insertions(+), 499 deletions(-) diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index 37d4471852..fe8c2c10d8 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -178,8 +178,8 @@ struct ivf_pq_knn { { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; - index.emplace(raft::neighbors::ivf_pq::build( - handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); + auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); + index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); } void search(const raft::device_resources& handle, @@ -188,8 +188,12 @@ struct ivf_pq_knn { IdxT* out_idxs) { search_params.n_probes = 20; + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto dists_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); raft::neighbors::ivf_pq::search( - handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists); + handle, search_params, *index, queries_view, idxs_view, dists_view); } }; diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index 34080038f5..dd16813737 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -28,6 +28,13 @@ namespace raft::neighbors::ivf_flat { +namespace detail = raft::spatial::knn::ivf_flat::detail; + +/** + * @defgroup ivf_flat IVF Flat Algorithm + * @{ + */ + /** * @brief Build the index from the dataset for efficient search. * @@ -42,11 +49,11 @@ namespace raft::neighbors::ivf_flat { * // use default index parameters * ivf_flat::index_params index_params; * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * auto index = ivf_flat::build(handle, dataset, index_params); * // use default search parameters * ivf_flat::search_params search_params; * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k); * @endcode * * @tparam T data element type @@ -55,78 +62,68 @@ namespace raft::neighbors::ivf_flat { * @param[in] handle * @param[in] params configure the index building * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * @param[in] n_rows the number of samples - * @param[in] dim the dimensionality of the data * * @return the constructed ivf-flat index */ template auto build(raft::device_resources const& handle, const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index + raft::device_matrix_view dataset) -> index { - return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); + IdxT n_rows = dataset.extent(0); + IdxT dim = dataset.extent(1); + return detail::build(handle, params, dataset.data_handle(), n_rows, dim); } /** - * @defgroup ivf_flat IVF Flat Algorithm - * @{ - */ - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct + * @brief Extend the index in-place with the new data. * * Usage example: * @code{.cpp} * using namespace raft::neighbors; - * // use default index parameters * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, dataset, index_params); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k); + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); + * // fill the index with the data + * ivf_flat::extend(handle, index_empty, dataset); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * - * @return the constructed ivf-flat index + * @param[inout] idx + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. */ -template -auto build(raft::device_resources const& handle, - raft::device_matrix_view dataset, - const index_params& params) -> index +template +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index& idx) { - return raft::spatial::knn::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} + ASSERT(new_vectors.extent(1) == idx.dim(), + "new_vectors should have the same dimension as the index"); -/** @} */ + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + return detail::extend(handle, + idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); +} /** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. + * @brief Extend the index in-place with the new data. * * Usage example: * @code{.cpp} @@ -135,92 +132,158 @@ auto build(raft::device_resources const& handle, * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * ivf_flat::extend(handle, index_empty, dataset); * @endcode * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * * @param[in] handle - * @param[in] orig_index original index + * @param[inout] idx * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows number of rows in `new_vectors` - * - * @return the constructed extended ivf-flat index */ template -auto extend(raft::device_resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) { - return raft::spatial::knn::ivf_flat::detail::extend( - handle, orig_index, new_vectors, new_indices, n_rows); + ASSERT(new_vectors.extent(1) == idx->dim(), + "new_vectors should have the same dimension as the index"); + + IdxT n_rows = new_vectors.extent(0); + if (new_indices.has_value()) { + ASSERT(n_rows == new_indices.value().extent(0), + "new_vectors and new_indices have different number of rows"); + } + + *idx = detail::extend(handle, + *idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } /** - * @ingroup ivf_flat - * @{ + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... + * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, index, queries1, out_inds1, out_dists1, search_params); + * ivf_flat::search(handle, index, queries2, out_inds2, out_dists2, search_params); + * ivf_flat::search(handle, index, queries3, out_inds3, out_dists3, search_params); + * ... + * @endcode + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] params configure the search + * @param[in] idx ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] */ +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + RAFT_EXPECTS( + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == idx.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + IdxT n_queries = queries.extent(0); + uint32_t k = neighbors.extent(1); + return detail::search(handle, + params, + idx, + queries.data_handle(), + n_queries, + k, + neighbors.data_handle(), + distances.data_handle(), + handle.get_workspace_resource()); +} + +/** @} */ /** - * @brief Build a new index containing the data of the original plus new extra vectors. + * @brief Build the index from the dataset for efficient search. * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct * * Usage example: * @code{.cpp} * using namespace raft::neighbors; + * // use default index parameters * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); - * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset); + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * * @param[in] handle - * @param[in] orig_index original index - * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. + * @param[in] params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data * - * @return the constructed extended ivf-flat index + * @return the constructed ivf-flat index */ -template -auto extend(raft::device_resources const& handle, - const index& orig_index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) - -> index +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index { - return extend( - handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); + return detail::build(handle, params, dataset, n_rows, dim); } -/** @} */ +/** @} */ // end group ivf_flat /** - * @brief Extend the index in-place with the new data. + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. * * Usage example: * @code{.cpp} @@ -231,35 +294,32 @@ auto extend(raft::device_resources const& handle, * // train the index from a [N, D] dataset * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); * @endcode * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param[inout] index + * @param[in] handle + * @param[in] idx original index * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples + * @param[in] n_rows number of rows in `new_vectors` + * + * @return the constructed extended ivf-flat index */ template -void extend(raft::device_resources const& handle, - index* index, +auto extend(raft::device_resources const& handle, + const index& idx, const T* new_vectors, const IdxT* new_indices, - IdxT n_rows) + IdxT n_rows) -> index { - *index = extend(handle, *index, new_vectors, new_indices, n_rows); + return detail::extend(handle, idx, new_vectors, new_indices, n_rows); } -/** - * @ingroup ivf_flat - * @{ - */ - /** * @brief Extend the index in-place with the new data. * @@ -270,38 +330,32 @@ void extend(raft::device_resources const& handle, * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset); + * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * - * @param[in] handle - * @param[inout] index + * @param handle + * @param[inout] idx * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows the number of samples */ -template +template void extend(raft::device_resources const& handle, - index* index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) + index& idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) { - *index = extend(handle, - *index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); + idx = detail::extend(handle, idx, new_vectors, new_indices, n_rows); } -/** @} */ - /** * @brief Search ANN using the constructed index. * @@ -332,22 +386,22 @@ void extend(raft::device_resources const& handle, * @tparam T data element type * @tparam IdxT type of the indices * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index + * @param handle + * @param params configure the search + * @param idx ivf-pq constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. + * @param n_queries the batch size + * @param k the number of neighbors to find for each query. * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). + * @param mr an optional memory resource to use across the searches (you can provide a large enough + * memory pool here to avoid memory allocations within search). */ template void search(raft::device_resources const& handle, const search_params& params, - const index& index, + const index& idx, const T* queries, uint32_t n_queries, uint32_t k, @@ -355,79 +409,7 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return raft::spatial::knn::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr); + return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); } -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, index, queries1, out_inds1, out_dists1, search_params); - * ivf_flat::search(handle, index, queries2, out_inds2, out_dists2, search_params); - * ivf_flat::search(handle, index, queries3, out_inds3, out_dists3, search_params); - * ... - * @endcode - * - * @tparam value_t data element type - * @tparam idx_t type of the indices - * @tparam int_t precision / type of integral arguments - * @tparam matrix_idx_t matrix indexing type - * - * @param[in] handle - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] params configure the search - */ -template -void search(raft::device_resources const& handle, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - const search_params& params) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - std::uint32_t k = neighbors.extent(1); - return search(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - k, - neighbors.data_handle(), - distances.data_handle(), - handle.get_workspace_resource()); -} - -/** @} */ - } // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 549bf606a4..fd293672de 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -46,15 +46,15 @@ namespace raft::neighbors::ivf_pq { * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] * @param params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] * * @return the constructed ivf-pq index */ template auto build(raft::device_resources const& handle, - raft::device_matrix_view dataset, - const index_params& params) -> index + const index_params& params, + raft::device_matrix_view dataset) -> index { IdxT n_rows = dataset.extent(0); IdxT dim = dataset.extent(1); @@ -62,61 +62,25 @@ auto build(raft::device_resources const& handle, } /** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param[inout] index - * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] - * @param params configure the index building - * - * @return the constructed ivf-pq index - */ -template -void build(raft::device_resources const& handle, - index* index, - raft::device_matrix_view dataset, - const index_params& params) -{ - IdxT n_rows = dataset.extent(0); - IdxT dim = dataset.extent(1); - *index = detail::build(handle, params, dataset.data_handle(), n_rows, dim); -} - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are unchanged. - * + * @brief Extend the index with the new data. + * * * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param orig_index original index + * @param[inout] idx * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * - * @return the constructed extended ivf-pq index */ template -auto extend(raft::device_resources const& handle, - const index& orig_index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = - std::nullopt) -> index +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& idx) { - ASSERT(new_vectors.extent(1) == orig_index.dim(), + ASSERT(new_vectors.extent(1) == idx.dim(), "new_vectors should have the same dimension as the index"); IdxT n_rows = new_vectors.extent(0); @@ -126,7 +90,7 @@ auto extend(raft::device_resources const& handle, } return detail::extend(handle, - orig_index, + idx, new_vectors.data_handle(), new_indices.has_value() ? new_indices.value().data_handle() : nullptr, n_rows); @@ -139,20 +103,19 @@ auto extend(raft::device_resources const& handle, * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param[inout] index + * @param[inout] idx * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. */ template -void extend( - raft::device_resources const& handle, - index* index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) { - ASSERT(new_vectors.extent(1) == index->dim(), + ASSERT(new_vectors.extent(1) == idx->dim(), "new_vectors should have the same dimension as the index"); IdxT n_rows = new_vectors.extent(0); @@ -161,11 +124,11 @@ void extend( "new_vectors and new_indices have different number of rows"); } - *index = extend(handle, - *index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - n_rows); + *idx = detail::extend(handle, + *idx, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + n_rows); } /** @@ -185,21 +148,21 @@ void extend( * @tparam IdxT type of the indices * * @param handle - * @param index ivf-pq constructed index + * @param params configure the search + * @param idx ivf-pq constructed index * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, * k] - * @param params configure the search */ template void search(raft::device_resources const& handle, - const index& index, + const search_params& params, + const index& idx, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - const search_params& params) + raft::device_matrix_view distances) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), @@ -208,13 +171,13 @@ void search(raft::device_resources const& handle, RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == index.dim(), + RAFT_EXPECTS(queries.extent(1) == idx.dim(), "Number of query dimensions should equal number of dimensions in the index."); std::uint32_t k = neighbors.extent(1); return detail::search(handle, params, - index, + idx, queries.data_handle(), static_cast(queries.extent(0)), k, @@ -290,10 +253,10 @@ auto build(raft::device_resources const& handle, * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param orig_index original index + * @param idx original index * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. * @param n_rows the number of samples * @@ -301,12 +264,12 @@ auto build(raft::device_resources const& handle, */ template auto extend(raft::device_resources const& handle, - const index& orig_index, + const index& idx, const T* new_vectors, const IdxT* new_indices, IdxT n_rows) -> index { - return detail::extend(handle, orig_index, new_vectors, new_indices, n_rows); + return detail::extend(handle, idx, new_vectors, new_indices, n_rows); } /** @@ -316,21 +279,22 @@ auto extend(raft::device_resources const& handle, * @tparam IdxT type of the indices in the source dataset * * @param handle - * @param[inout] index + * @param[inout] idx * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * If the original index is empty (`index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. + * @param[inout] idx * @param n_rows the number of samples */ template void extend(raft::device_resources const& handle, - index* index, + index* idx, const T* new_vectors, const IdxT* new_indices, IdxT n_rows) { - detail::extend(handle, index, new_vectors, new_indices, n_rows); + detail::extend(handle, idx, new_vectors, new_indices, n_rows); } /** @@ -365,7 +329,7 @@ void extend(raft::device_resources const& handle, * * @param handle * @param params configure the search - * @param index ivf-pq constructed index + * @param idx ivf-pq constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] * @param n_queries the batch size * @param k the number of neighbors to find for each query. @@ -378,7 +342,7 @@ void extend(raft::device_resources const& handle, template void search(raft::device_resources const& handle, const search_params& params, - const index& index, + const index& idx, const T* queries, uint32_t n_queries, uint32_t k, @@ -386,7 +350,7 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr) { - return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr); + return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); } } // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh index 352c75bc89..55a7cd5858 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh @@ -35,29 +35,24 @@ namespace raft::neighbors::ivf_pq { // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ - extern template auto build(raft::device_resources const&, \ - raft::device_matrix_view, \ - const raft::neighbors::ivf_pq::index_params&) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template void build(raft::device_resources const&, \ - index*, \ - raft::device_matrix_view, \ - const raft::neighbors::ivf_pq::index_params&); \ - \ - extern template auto extend( \ - raft::device_resources const&, \ - const index&, \ - raft::device_matrix_view, \ - std::optional>) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template void extend( \ - raft::device_resources const&, \ - index*, \ - raft::device_matrix_view, \ - std::optional>); +#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ + extern template auto build(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::index_params&, \ + raft::device_matrix_view) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template auto extend( \ + raft::device_resources const&, \ + raft::device_matrix_view, \ + std::optional>, \ + const raft::neighbors::ivf_pq::index&) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template void extend( \ + raft::device_resources const&, \ + raft::device_matrix_view, \ + std::optional>, \ + raft::neighbors::ivf_pq::index*); RAFT_DECL_BUILD_EXTEND(float, int64_t) RAFT_DECL_BUILD_EXTEND(int8_t, int64_t) @@ -65,13 +60,13 @@ RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t) #undef RAFT_DECL_BUILD_EXTEND -#define RAFT_DECL_SEARCH(T, IdxT) \ - extern template void search(raft::device_resources const&, \ - const index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - const search_params&); +#define RAFT_DECL_SEARCH(T, IdxT) \ + extern template void search(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::search_params&, \ + const raft::neighbors::ivf_pq::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ + raft::device_matrix_view); RAFT_DECL_SEARCH(float, int64_t); RAFT_DECL_SEARCH(int8_t, int64_t); diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 9c511c4acf..cc95b32cee 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -83,7 +83,7 @@ void approx_knn_build_index(raft::device_resources const& handle, auto index_view = raft::make_device_matrix_view(index_array, n, D); index->ivf_pq = std::make_unique>( - neighbors::ivf_pq::build(handle, index_view, params)); + neighbors::ivf_pq::build(handle, params, index_view)); } else { RAFT_FAIL("Unrecognized index type."); } @@ -114,8 +114,13 @@ void approx_knn_search(raft::device_resources const& handle, } else if (index->ivf_pq) { neighbors::ivf_pq::search_params params; params.n_probes = index->nprobe; + + auto query_view = + raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); + auto indices_view = raft::make_device_matrix_view(indices, n, k); + auto distances_view = raft::make_device_matrix_view(distances, n, k); neighbors::ivf_pq::search( - handle, params, *index->ivf_pq, query_array, n, k, indices, distances); + handle, params, *index->ivf_pq, query_view, indices_view, distances_view); } else { RAFT_FAIL("The model is not trained"); } diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index 04664716f0..fb22d7657e 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -23,39 +23,41 @@ namespace raft::runtime::neighbors::ivf_pq { // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const raft::neighbors::ivf_pq::index_params& params); \ - \ - void build(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* index, \ - raft::device_matrix_view dataset, \ - const raft::neighbors::ivf_pq::index_params& params); \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* index, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices); +#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ + [[nodiscard]] raft::neighbors::ivf_pq::index build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + [[nodiscard]] raft::neighbors::ivf_pq::index extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); -RAFT_DECL_BUILD_EXTEND(float, int64_t) -RAFT_DECL_BUILD_EXTEND(int8_t, int64_t) -RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t) +RAFT_DECL_BUILD_EXTEND(float, int64_t); +RAFT_DECL_BUILD_EXTEND(int8_t, int64_t); +RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t); #undef RAFT_DECL_BUILD_EXTEND #define RAFT_DECL_SEARCH(T, IdxT) \ void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& index, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const raft::neighbors::ivf_pq::search_params& params); + raft::device_matrix_view distances); RAFT_DECL_SEARCH(float, int64_t); RAFT_DECL_SEARCH(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index dbd877401e..8759ca2587 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -20,33 +20,35 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const raft::neighbors::ivf_pq::index_params& params) \ - { \ - return raft::neighbors::ivf_pq::build(handle, dataset, params); \ - } \ - void build(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - raft::device_matrix_view dataset, \ - const raft::neighbors::ivf_pq::index_params& params) \ - { \ - raft::neighbors::ivf_pq::build(handle, idx, dataset, params); \ - } \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - return raft::neighbors::ivf_pq::extend(handle, orig_index, new_vectors, new_indices); \ - } \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ - { \ - raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices); \ +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + raft::neighbors::ivf_pq::index build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset) \ + { \ + return raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_pq::index* idx) \ + { \ + *idx = raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + raft::neighbors::ivf_pq::index extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx) \ + { \ + return raft::neighbors::ivf_pq::extend(handle, new_vectors, new_indices, idx); \ + } \ + void extend(raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx) \ + { \ + raft::neighbors::ivf_pq::extend(handle, new_vectors, new_indices, idx); \ } RAFT_INST_BUILD_EXTEND(float, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu index 00392be8a7..91093d3a39 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu @@ -23,13 +23,13 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_SEARCH_INST(T, IdxT) \ void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ const raft::neighbors::ivf_pq::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const raft::neighbors::ivf_pq::search_params& params) \ + raft::device_matrix_view distances) \ { \ - raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(float, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu index 01a26b78b3..e1552c0b27 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -23,13 +23,13 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_SEARCH_INST(T, IdxT) \ void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ const raft::neighbors::ivf_pq::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const raft::neighbors::ivf_pq::search_params& params) \ + raft::device_matrix_view distances) \ { \ - raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu index 5b99b0df9f..85195a7551 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -23,13 +23,13 @@ namespace raft::runtime::neighbors::ivf_pq { #define RAFT_SEARCH_INST(T, IdxT) \ void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ const raft::neighbors::ivf_pq::index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const raft::neighbors::ivf_pq::search_params& params) \ + raft::device_matrix_view distances) \ { \ - raft::neighbors::ivf_pq::search(handle, idx, queries, neighbors, distances, params); \ + raft::neighbors::ivf_pq::search(handle, params, idx, queries, neighbors, distances); \ } RAFT_SEARCH_INST(uint8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu index 6818fa665d..d559291b93 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu @@ -21,13 +21,9 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const index_params& params) \ - ->index; \ - template void build(raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view dataset, \ - const index_params& params); + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu index feee5eaba2..c8b31e1fff 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu @@ -21,13 +21,9 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const index_params& params) \ - ->index; \ - template void build(raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view dataset, \ - const index_params& params); + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu index 963cc23f57..5fc62969f0 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu @@ -21,13 +21,9 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto build(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - const index_params& params) \ - ->index; \ - template void build(raft::device_resources const& handle, \ - index* index, \ - raft::device_matrix_view dataset, \ - const index_params& params); + const index_params& params, \ + raft::device_matrix_view dataset) \ + ->index; RAFT_MAKE_INSTANCE(uint8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu index 70ef1a3acf..4cc616f32d 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu @@ -22,15 +22,15 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto extend( \ raft::device_resources const& handle, \ - const index& orig_index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ + std::optional> new_indices, \ + const index& idx) \ ->index; \ template void extend( \ raft::device_resources const& handle, \ - index* index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices); + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu index a284bec9f3..a3117aae0f 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu @@ -22,15 +22,15 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto extend( \ raft::device_resources const& handle, \ - const index& orig_index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ + std::optional> new_indices, \ + const index& idx) \ ->index; \ template void extend( \ raft::device_resources const& handle, \ - index* index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices); + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu index 2ef568885f..a5e3d68569 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu @@ -22,15 +22,15 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template auto extend( \ raft::device_resources const& handle, \ - const index& orig_index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices) \ + std::optional> new_indices, \ + const index& idx) \ ->index; \ template void extend( \ raft::device_resources const& handle, \ - index* index, \ raft::device_matrix_view new_vectors, \ - std::optional> new_indices); + std::optional> new_indices, \ + index* idx); RAFT_MAKE_INSTANCE(uint8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu index 43f2d3898e..92a4d89e6b 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu @@ -21,11 +21,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const index& index, \ + const search_params& params, \ + const index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const search_params& params); + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu index fd8c727853..62a8b48ad5 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu @@ -21,11 +21,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const index& index, \ + const search_params& params, \ + const index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const search_params& params); + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu index 0717a7462d..3bcf134a22 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu @@ -21,11 +21,11 @@ namespace raft::neighbors::ivf_pq { #define RAFT_MAKE_INSTANCE(T, IdxT) \ template void search(raft::device_resources const& handle, \ - const index& index, \ + const search_params& params, \ + const index& idx, \ raft::device_matrix_view queries, \ raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances, \ - const search_params& params); + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(uint8_t, int64_t); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index cdd6570562..b78bd872f7 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -156,7 +156,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_flat::build(handle_, database_view, index_params); + index idx = ivf_flat::build(handle_, index_params, database_view); rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), @@ -169,7 +169,8 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto half_of_data_view = raft::make_device_matrix_view( (const DataT*)database.data(), half_of_data, ps.dim); - auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); + const std::optional> no_opt = std::nullopt; + index idx_2 = ivf_flat::extend(handle_, half_of_data_view, no_opt, idx); auto new_half_of_data_view = raft::make_device_matrix_view( database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); @@ -178,10 +179,10 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); ivf_flat::extend(handle_, - &index_2, new_half_of_data_view, std::make_optional>( - new_half_of_data_indices_view)); + new_half_of_data_indices_view), + &idx_2); auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); @@ -189,47 +190,46 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { indices_ivfflat_dev.data(), ps.num_queries, ps.k); auto dists_out_view = raft::make_device_matrix_view( distances_ivfflat_dev.data(), ps.num_queries, ps.k); - raft::spatial::knn::ivf_flat::detail::serialize(handle_, "ivf_flat_index", index_2); + raft::spatial::knn::ivf_flat::detail::serialize(handle_, "ivf_flat_index", idx_2); auto index_loaded = raft::spatial::knn::ivf_flat::detail::deserialize(handle_, "ivf_flat_index"); ivf_flat::search(handle_, + search_params, index_loaded, search_queries_view, indices_out_view, - dists_out_view, - search_params); + dists_out_view); update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); handle_.sync_stream(stream_); // Test the centroid invariants - if (index_2.adaptive_centers()) { + if (idx_2.adaptive_centers()) { // The centers must be up-to-date with the corresponding data - std::vector list_sizes(index_2.n_lists()); - std::vector list_offsets(index_2.n_lists()); + std::vector list_sizes(idx_2.n_lists()); + std::vector list_offsets(idx_2.n_lists()); rmm::device_uvector centroid(ps.dim, stream_); + raft::copy(list_sizes.data(), idx_2.list_sizes().data_handle(), idx_2.n_lists(), stream_); raft::copy( - list_sizes.data(), index_2.list_sizes().data_handle(), index_2.n_lists(), stream_); - raft::copy( - list_offsets.data(), index_2.list_offsets().data_handle(), index_2.n_lists(), stream_); + list_offsets.data(), idx_2.list_offsets().data_handle(), idx_2.n_lists(), stream_); handle_.sync_stream(stream_); - for (uint32_t l = 0; l < index_2.n_lists(); l++) { + for (uint32_t l = 0; l < idx_2.n_lists(); l++) { rmm::device_uvector cluster_data(list_sizes[l] * ps.dim, stream_); raft::spatial::knn::detail::utils::copy_selected( (IdxT)list_sizes[l], (IdxT)ps.dim, database.data(), - index_2.indices().data_handle() + list_offsets[l], + idx_2.indices().data_handle() + list_offsets[l], (IdxT)ps.dim, cluster_data.data(), (IdxT)ps.dim, stream_); raft::stats::mean( centroid.data(), cluster_data.data(), ps.dim, list_sizes[l], false, true, stream_); - ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle() + ps.dim * l, + ASSERT_TRUE(raft::devArrMatch(idx_2.centers().data_handle() + ps.dim * l, centroid.data(), ps.dim, raft::CompareApprox(0.001), @@ -237,9 +237,9 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { } } else { // The centers must be immutable - ASSERT_TRUE(raft::devArrMatch(index_2.centers().data_handle(), - index.centers().data_handle(), - index_2.centers().size(), + ASSERT_TRUE(raft::devArrMatch(idx_2.centers().data_handle(), + idx.centers().data_handle(), + idx_2.centers().size(), raft::Compare(), stream_)); } diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index f07a241b95..c368192b03 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -178,17 +178,17 @@ class ivf_pq_test : public ::testing::TestWithParam { handle_.sync_stream(stream_); } - auto build_only() + index build_only() { auto ipams = ps.index_params; ipams.add_data_on_build = true; auto index_view = raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - return ivf_pq::build(handle_, index_view, ipams); + return ivf_pq::build(handle_, ipams, index_view); } - auto build_2_extends() + index build_2_extends() { rmm::device_uvector db_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), @@ -207,18 +207,21 @@ class ivf_pq_test : public ::testing::TestWithParam { auto database_view = raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_pq::build(handle_, database_view, ipams); + auto idx = ivf_pq::build(handle_, ipams, database_view); auto vecs_2_view = raft::make_device_matrix_view(vecs_2, size_2, ps.dim); auto inds_2_view = raft::make_device_matrix_view(inds_2, size_2, 1); - ivf_pq::extend(handle_, &index, vecs_2_view, inds_2_view); - - auto vecs_1_view = raft::make_device_matrix_view(vecs_1, size_1, ps.dim); - auto inds_1_view = raft::make_device_matrix_view(inds_1, size_1, 1); - return ivf_pq::extend(handle_, index, vecs_1_view, inds_1_view); + ivf_pq::extend(handle_, vecs_2_view, inds_2_view, &idx); + + auto vecs_1_view = + raft::make_device_matrix_view(vecs_1, size_1, ps.dim); + auto inds_1_view = + raft::make_device_matrix_view(inds_1, size_1, 1); + ivf_pq::extend(handle_, vecs_1_view, inds_1_view, &idx); + return idx; } - auto build_serialize() + index build_serialize() { ivf_pq::serialize(handle_, "ivf_pq_index", build_only()); return ivf_pq::deserialize(handle_, "ivf_pq_index"); @@ -227,7 +230,7 @@ class ivf_pq_test : public ::testing::TestWithParam { template void run(BuildIndex build_index) { - auto index = build_index(); + index index = build_index(); size_t queries_size = ps.num_queries * ps.k; std::vector indices_ivf_pq(queries_size); @@ -244,7 +247,7 @@ class ivf_pq_test : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_ivf_pq_dev.data(), ps.num_queries, ps.k); ivf_pq::search( - handle_, index, query_view, inds_view, dists_view, ps.search_params); + handle_, ps.search_params, index, query_view, inds_view, dists_view); update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index 7d951ae56a..d04d833f3b 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -107,63 +107,63 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ cdef void build( const device_resources& handle, - index[int64_t]* index, + const index_params& params, device_matrix_view[float, int64_t, row_major] dataset, - const index_params& params) except + + index[int64_t]* index) except + cdef void build( const device_resources& handle, - index[int64_t]* index, + const index_params& params, device_matrix_view[int8_t, int64_t, row_major] dataset, - const index_params& params) except + + index[int64_t]* index) except + cdef void build( const device_resources& handle, - index[int64_t]* index, + const index_params& params, device_matrix_view[uint8_t, int64_t, row_major] dataset, - const index_params& params) except + + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[float, int64_t, row_major] new_vectors, - optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices) except + # noqa: E501 + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[int8_t, int64_t, row_major] new_vectors, - optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices) except + # noqa: E501 + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[int64_t]* index, device_matrix_view[uint8_t, int64_t, row_major] new_vectors, - optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices) except + # noqa: E501 + optional[device_matrix_view[int64_t, int64_t, row_major]] new_indices, + index[int64_t]* index) except + cdef void search( const device_resources& handle, + const search_params& params, const index[int64_t]& index, device_matrix_view[float, int64_t, row_major] queries, device_matrix_view[int64_t, int64_t, row_major] neighbors, - device_matrix_view[float, int64_t, row_major] distances, - const search_params& params) except + + device_matrix_view[float, int64_t, row_major] distances) except + cdef void search( const device_resources& handle, + const search_params& params, const index[int64_t]& index, device_matrix_view[int8_t, int64_t, row_major] queries, device_matrix_view[int64_t, int64_t, row_major] neighbors, - device_matrix_view[float, int64_t, row_major] distances, - const search_params& params) except + + device_matrix_view[float, int64_t, row_major] distances) except + cdef void search( const device_resources& handle, + const search_params& params, const index[int64_t]& index, device_matrix_view[uint8_t, int64_t, row_major] queries, device_matrix_view[int64_t, int64_t, row_major] neighbors, - device_matrix_view[float, int64_t, row_major] distances, - const search_params& params) except + + device_matrix_view[float, int64_t, row_major] distances) except + cdef void serialize(const device_resources& handle, const string& filename, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 75de0aba82..4f4d2c75a4 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -410,23 +410,23 @@ def build(IndexParams index_params, dataset, handle=None): if dataset_dt == np.float32: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - idx.index, + index_params.params, get_dmv_float(dataset_cai, check_shape=True), - index_params.params) + idx.index) idx.trained = True elif dataset_dt == np.byte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - idx.index, + index_params.params, get_dmv_int8(dataset_cai, check_shape=True), - index_params.params) + idx.index) idx.trained = True elif dataset_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), - idx.index, + index_params.params, get_dmv_uint8(dataset_cai, check_shape=True), - index_params.params) + idx.index) idx.trained = True else: raise TypeError("dtype %s not supported" % dataset_dt) @@ -520,21 +520,21 @@ def extend(Index index, new_vectors, new_indices, handle=None): if vecs_dt == np.float32: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_float(vecs_cai, check_shape=True), - create_optional(get_dmv_int64(idx_cai, check_shape=False))) # noqa: E501 + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) elif vecs_dt == np.int8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_int8(vecs_cai, check_shape=True), - create_optional(get_dmv_int64(idx_cai, check_shape=False))) # noqa: E501 + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) elif vecs_dt == np.uint8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), - index.index, get_dmv_uint8(vecs_cai, check_shape=True), - create_optional(get_dmv_int64(idx_cai, check_shape=False))) # noqa: E501 + create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501 + index.index) else: raise TypeError("query dtype %s not supported" % vecs_dt) @@ -721,27 +721,27 @@ def search(SearchParams search_params, if queries_dt == np.float32: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), + params, deref(index.index), get_dmv_float(queries_cai, check_shape=True), get_dmv_int64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True), - params) + get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.byte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), + params, deref(index.index), get_dmv_int8(queries_cai, check_shape=True), get_dmv_int64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True), - params) + get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), + params, deref(index.index), get_dmv_uint8(queries_cai, check_shape=True), get_dmv_int64(neighbors_cai, check_shape=True), - get_dmv_float(distances_cai, check_shape=True), - params) + get_dmv_float(distances_cai, check_shape=True)) else: raise ValueError("query dtype %s not supported" % queries_dt)