diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index e011aeb706..ed3c6db909 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -179,8 +179,9 @@ struct ivf_pq_knn { { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; - index.emplace(raft::neighbors::ivf_pq::build( - handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); + + auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); + index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); } void search(const raft::device_resources& handle, @@ -189,8 +190,13 @@ struct ivf_pq_knn { IdxT* out_idxs) { search_params.n_probes = 20; + + auto queries_view = + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto dists_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); raft::neighbors::ivf_pq::search( - handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists); + handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view); } }; diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 053fe634da..4bb617b526 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -33,6 +33,141 @@ namespace raft::neighbors::ivf_pq { * @{ */ +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param params configure the index building + * @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-pq index + */ +template +inline auto build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) -> index +{ + IdxT n_rows = dataset.extent(0); + IdxT dim = dataset.extent(1); + return detail::build(handle, params, dataset.data_handle(), n_rows, dim); +} + +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to the existing kmeans clusters; the cluster + * centers are left unchanged. + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param orig_index original index + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, orig_index.dim()] + * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass a view + * with a `nullptr` data handle and `n_rows` rows to imply a continuous range `[0...n_rows)`. 
+ * + * @return the constructed extended ivf-pq index + */ +template +inline auto extend(raft::device_resources const& handle, + const index& orig_index, + raft::device_matrix_view new_vectors, + raft::device_matrix_view new_indices) -> index +{ + IdxT n_rows = new_vectors.extent(0); + ASSERT(n_rows == new_indices.extent(0), + "new_vectors and new_indices have different number of rows"); + ASSERT(new_vectors.extent(1) == orig_index.dim(), + "new_vectors should have the same dimension as the index"); + return detail::extend( + handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows); +} + +/** + * @brief Extend the index with the new data. + * + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset + * + * @param handle + * @param[inout] index + * @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index->dim()] + * @param[in] new_indices a device matrix view to a vector of indices [n_rows]. + * If the index is empty (`index->size() == 0`), you can pass a view with a `nullptr` + * data handle and `n_rows` rows to imply a continuous range `[0...n_rows)`. + */ +template +inline void extend(raft::device_resources const& handle, + index* index, + raft::device_matrix_view new_vectors, + raft::device_matrix_view new_indices) +{ + *index = extend(handle, *index, new_vectors, new_indices); +} + +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * set a pool memory resource or a large enough pre-allocated memory resource as the workspace + * resource of `handle` to reduce or eliminate entirely allocations happening within `search`. + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. 
However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of `search`. + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param handle + * @param params configure the search + * @param index ivf-pq constructed index + * @param[in] queries a device matrix view to a row-major matrix [n_queries, index.dim()] + * @param k the number of neighbors to find for each query. + * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, + * k] + */ +template +inline void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + uint32_t k, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) +{ + IdxT n_queries = queries.extent(0); + bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0)); + ASSERT(check_n_rows, + "queries, neighbors and distances parameters have inconsistent number of rows"); + return detail::search(handle, + params, + index, + queries.data_handle(), + n_queries, + k, + neighbors.data_handle(), + distances.data_handle(), + handle.get_workspace_resource()); +} + +/** @} */ // end group ivf_pq + /** * @brief Build the index from the dataset for efficient search. 
* @@ -197,6 +332,4 @@ void search(raft::device_resources const& handle, return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr); } -/** @} */ // end group ivf_pq - } // namespace raft::neighbors::ivf_pq diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 427e812cda..066dcaaa6b 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -78,8 +78,10 @@ void approx_knn_build_index(raft::device_resources const& handle, params.pq_bits = ivf_pq_pams->n_bits; params.pq_dim = ivf_pq_pams->M; // TODO: handle ivf_pq_pams.usePrecomputedTables ? - index->ivf_pq = std::make_unique>( - neighbors::ivf_pq::build(handle, params, index_array, int64_t(n), D)); + + auto index_view = raft::make_device_matrix_view(index_array, n, D); + index->ivf_pq = std::make_unique>( + neighbors::ivf_pq::build(handle, params, index_view)); } else { RAFT_FAIL("Unrecognized index type."); } @@ -110,8 +112,13 @@ void approx_knn_search(raft::device_resources const& handle, } else if (index->ivf_pq) { neighbors::ivf_pq::search_params params; params.n_probes = index->nprobe; + + auto query_view = + raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); + auto indices_view = raft::make_device_matrix_view(indices, n, k); + auto distances_view = raft::make_device_matrix_view(distances, n, k); neighbors::ivf_pq::search( - handle, params, *index->ivf_pq, query_array, n, k, indices, distances); + handle, params, *index->ivf_pq, query_view, k, indices_view, distances_view); } else { RAFT_FAIL("The model is not trained"); } diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index 59d0b59128..e4c228effe 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -20,16 +20,14 @@ namespace 
raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_SEARCH(T, IdxT) \ - void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - const T*, \ - uint32_t, \ - uint32_t, \ - IdxT*, \ - float*, \ - rmm::mr::device_memory_resource*); +#define RAFT_INST_SEARCH(T, IdxT) \ + void search(raft::device_resources const&, \ + const raft::neighbors::ivf_pq::search_params&, \ + const raft::neighbors::ivf_pq::index&, \ + raft::device_matrix_view, \ + uint32_t, \ + raft::device_matrix_view, \ + raft::device_matrix_view); RAFT_INST_SEARCH(float, uint64_t); RAFT_INST_SEARCH(int8_t, uint64_t); @@ -40,33 +38,27 @@ RAFT_INST_SEARCH(uint8_t, uint64_t); // We define overloads for build and extend with void return type. This is used in the Cython // wrappers, where exception handling is not compatible with return type that has nontrivial // constructor. -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_pq::index; \ - \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->raft::neighbors::ivf_pq::index; \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim, \ - raft::neighbors::ivf_pq::index* idx); \ - \ - void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset) \ + 
->raft::neighbors::ivf_pq::index; \ + \ + auto extend(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& orig_index, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices) \ + ->raft::neighbors::ivf_pq::index; \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + void extend(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices); RAFT_INST_BUILD_EXTEND(float, uint64_t) RAFT_INST_BUILD_EXTEND(int8_t, uint64_t) diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index 2a595854bb..96ba349d1d 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -20,43 +20,36 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ - auto build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_pq::index \ - { \ - return raft::neighbors::ivf_pq::build(handle, params, dataset, n_rows, dim); \ - } \ - auto extend(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->raft::neighbors::ivf_pq::index \ - { \ - return raft::neighbors::ivf_pq::extend( \ - handle, orig_index, new_vectors, new_indices, n_rows); \ - } \ - \ - void build(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim, \ - raft::neighbors::ivf_pq::index* idx) \ - { \ - *idx = raft::neighbors::ivf_pq::build(handle, params, dataset, n_rows, dim); \ - } \ 
- void extend(raft::device_resources const& handle, \ - raft::neighbors::ivf_pq::index* idx, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - { \ - raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices, n_rows); \ +#define RAFT_INST_BUILD_EXTEND(T, IdxT) \ + auto build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_pq::index \ + { \ + return raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + auto extend(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& orig_index, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices) \ + ->raft::neighbors::ivf_pq::index \ + { \ + return raft::neighbors::ivf_pq::extend(handle, orig_index, new_vectors, new_indices); \ + } \ + \ + void build(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_pq::index* idx) \ + { \ + *idx = raft::neighbors::ivf_pq::build(handle, params, dataset); \ + } \ + void extend(raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices) \ + { \ + raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices); \ } RAFT_INST_BUILD_EXTEND(float, uint64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu index 3379f685c5..9bd750a2e2 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu @@ -20,19 +20,17 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const 
raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + uint32_t k, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, queries, k, neighbors, distances); \ } RAFT_SEARCH_INST(float, uint64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu index d507b0b6e3..303c7009cf 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu @@ -20,19 +20,17 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + uint32_t k, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, 
queries, k, neighbors, distances); \ } RAFT_SEARCH_INST(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu index cc5331c015..c057abd22e 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu @@ -20,19 +20,17 @@ namespace raft::runtime::neighbors::ivf_pq { -#define RAFT_SEARCH_INST(T, IdxT) \ - void search(raft::device_resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource* mr) \ - { \ - raft::neighbors::ivf_pq::search( \ - handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \ +#define RAFT_SEARCH_INST(T, IdxT) \ + void search(raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + uint32_t k, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + raft::neighbors::ivf_pq::search( \ + handle, params, idx, queries, k, neighbors, distances); \ } RAFT_SEARCH_INST(uint8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu index 841c417eb0..9563ea8a88 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu @@ -18,12 +18,10 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ +#define 
RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + raft::device_matrix_view dataset) \ ->index; RAFT_MAKE_INSTANCE(float, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu index ee681954aa..40c84d2a73 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu @@ -18,12 +18,10 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + raft::device_matrix_view dataset) \ ->index; RAFT_MAKE_INSTANCE(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu index a22917fade..8d406542e8 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu @@ -18,12 +18,10 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto build(raft::device_resources const& handle, \ + const index_params& params, \ + raft::device_matrix_view dataset) \ ->index; RAFT_MAKE_INSTANCE(uint8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu 
b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu index a43c221acb..3a0690a2f1 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu @@ -18,18 +18,17 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->index; \ - template void extend(raft::device_resources const& handle, \ - index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend(raft::device_resources const& handle, \ + const index& orig_index, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices); RAFT_MAKE_INSTANCE(float, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu index f8399c67be..83cb2d14e9 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu @@ -18,18 +18,17 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->index; \ - template void extend(raft::device_resources const& handle, \ - index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto 
extend(raft::device_resources const& handle, \ + const index& orig_index, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices); RAFT_MAKE_INSTANCE(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu index 2307400811..0b218dbc6f 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu @@ -18,18 +18,17 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - const index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->index; \ - template void extend(raft::device_resources const& handle, \ - index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template auto extend(raft::device_resources const& handle, \ + const index& orig_index, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices) \ + ->index; \ + template void extend( \ + raft::device_resources const& handle, \ + index* index, \ + raft::device_matrix_view new_vectors, \ + raft::device_matrix_view new_indices); RAFT_MAKE_INSTANCE(uint8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu index 90a9377452..f28e854554 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu @@ -18,16 +18,14 @@ namespace 
raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource*) +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void search(raft::device_resources const& handle, \ + const search_params& params, \ + const index& index, \ + raft::device_matrix_view queries, \ + uint32_t k, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(float, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu index 17e7a09f16..230001df75 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu @@ -18,16 +18,14 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource*) +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void search(raft::device_resources const& handle, \ + const search_params& params, \ + const index& index, \ + raft::device_matrix_view queries, \ + uint32_t k, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(int8_t, uint64_t); diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu index 08731067bc..c6ff5097dc 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu +++ 
b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu @@ -18,16 +18,14 @@ namespace raft::neighbors::ivf_pq { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::mr::device_memory_resource*) +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void search(raft::device_resources const& handle, \ + const search_params& params, \ + const index& index, \ + raft::device_matrix_view queries, \ + uint32_t k, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); RAFT_MAKE_INSTANCE(uint8_t, uint64_t); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 91294a859a..e2a938aef8 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -183,7 +183,9 @@ class ivf_pq_test : public ::testing::TestWithParam { auto ipams = ps.index_params; ipams.add_data_on_build = true; - return ivf_pq::build(handle_, ipams, database.data(), ps.num_db_vecs, ps.dim); + auto index_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + return ivf_pq::build(handle_, ipams, index_view); } auto build_2_extends() @@ -203,11 +205,17 @@ class ivf_pq_test : public ::testing::TestWithParam { auto ipams = ps.index_params; ipams.add_data_on_build = false; - auto index = - ivf_pq::build(handle_, ipams, database.data(), ps.num_db_vecs, ps.dim); + auto database_view = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_pq::build(handle_, ipams, database_view); - ivf_pq::extend(handle_, &index, vecs_2, inds_2, size_2); - return ivf_pq::extend(handle_, index, vecs_1, inds_1, size_1); + auto vecs_2_view = raft::make_device_matrix_view(vecs_2, size_2, ps.dim); + auto inds_2_view = raft::make_device_matrix_view(inds_2, 
size_2, 1); + ivf_pq::extend(handle_, &index, vecs_2_view, inds_2_view); + + auto vecs_1_view = raft::make_device_matrix_view(vecs_1, size_1, ps.dim); + auto inds_1_view = raft::make_device_matrix_view(inds_1, size_1, 1); + return ivf_pq::extend(handle_, index, vecs_1_view, inds_1_view); } auto build_serialize() @@ -228,14 +236,15 @@ class ivf_pq_test : public ::testing::TestWithParam { rmm::device_uvector distances_ivf_pq_dev(queries_size, stream_); rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); - ivf_pq::search(handle_, - ps.search_params, - index, - search_queries.data(), - ps.num_queries, - ps.k, - indices_ivf_pq_dev.data(), - distances_ivf_pq_dev.data()); + auto query_view = + raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); + auto inds_view = + raft::make_device_matrix_view(indices_ivf_pq_dev.data(), ps.num_queries, ps.k); + auto dists_view = + raft::make_device_matrix_view(distances_ivf_pq_dev.data(), ps.num_queries, ps.k); + + ivf_pq::search( + handle_, ps.search_params, index, query_view, ps.k, inds_view, dists_view); update_host(distances_ivf_pq.data(), distances_ivf_pq_dev.data(), queries_size, stream_); update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); diff --git a/python/pylibraft/pylibraft/common/cpp/mdspan.pxd b/python/pylibraft/pylibraft/common/cpp/mdspan.pxd index c3e5abb47e..a8c636f0b7 100644 --- a/python/pylibraft/pylibraft/common/cpp/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/cpp/mdspan.pxd @@ -19,6 +19,7 @@ # cython: embedsignature = True # cython: language_level = 3 +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t from libcpp.string cimport string from pylibraft.common.handle cimport device_resources diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd new file mode 100644 index 0000000000..2a0bdaca62 --- /dev/null +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -0,0 +1,39 @@ +# +# 
Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libc.stdint cimport int8_t, uint8_t, uint64_t +from libcpp.string cimport string + +from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major +from pylibraft.common.handle cimport device_resources + + +cdef device_matrix_view[float, uint64_t, row_major] get_dmv_float( + array, check_shape) except * + +cdef device_matrix_view[uint8_t, uint64_t, row_major] get_dmv_uint8( + array, check_shape) except * + +cdef device_matrix_view[int8_t, uint64_t, row_major] get_dmv_int8( + array, check_shape) except * + +cdef device_matrix_view[uint64_t, uint64_t, row_major] get_dmv_uint64( + array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index ec825495f4..22afda043d 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -25,11 +25,21 @@ import numpy as np from cpython.object cimport PyObject from cython.operator cimport dereference as deref from libc.stddef cimport size_t -from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t, uintptr_t +from libc.stdint cimport ( + int8_t, + int32_t, + int64_t, + uint8_t, + uint32_t, + uint64_t, + uintptr_t, +) from pylibraft.common.cpp.mdspan cimport ( col_major, + 
device_matrix_view, host_mdspan, + make_device_matrix_view, make_host_matrix_view, matrix_extent, ostream, @@ -144,3 +154,47 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): X2 = np.load(f) assert np.all(X.shape == X2.shape) assert np.all(X == X2) + + +cdef device_matrix_view[float, uint64_t, row_major] \ + get_dmv_float(cai, check_shape) except *: + if cai.dtype != np.float32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[float, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[uint8_t, uint64_t, row_major] \ + get_dmv_uint8(cai, check_shape) except *: + if cai.dtype != np.uint8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[uint8_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[int8_t, uint64_t, row_major] \ + get_dmv_int8(cai, check_shape) except *: + if cai.dtype != np.int8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[int8_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[uint64_t, uint64_t, row_major] \ + get_dmv_uint64(cai, check_shape) except *: + if cai.dtype != np.uint64: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if 
len(cai.shape) == 2 else 1) + return make_device_matrix_view[uint64_t, uint64_t, row_major]( + cai.data, shape[0], shape[1]) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index de929847e9..ca35f5b8ca 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -36,6 +36,7 @@ from libcpp.string cimport string from rmm._lib.memory_resource cimport device_memory_resource +from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major from pylibraft.common.handle cimport device_resources from pylibraft.distance.distance_type cimport DistanceType @@ -110,74 +111,68 @@ cdef extern from "raft/neighbors/ivf_pq_types.hpp" \ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ namespace "raft::runtime::neighbors::ivf_pq" nogil: - cdef void build(const device_resources& handle, - const index_params& params, - const float* dataset, - uint64_t n_rows, - uint32_t dim, - index[uint64_t]* index) except + - - cdef void build(const device_resources& handle, - const index_params& params, - const int8_t* dataset, - uint64_t n_rows, - uint32_t dim, - index[uint64_t]* index) except + - - cdef void build(const device_resources& handle, - const index_params& params, - const uint8_t* dataset, - uint64_t n_rows, - uint32_t dim, - index[uint64_t]* index) except + - - cdef void extend(const device_resources& handle, - index[uint64_t]* index, - const float* new_vectors, - const uint64_t* new_indices, - uint64_t n_rows) except + - - cdef void extend(const device_resources& handle, - index[uint64_t]* index, - const int8_t* new_vectors, - const uint64_t* new_indices, - uint64_t n_rows) except + - - cdef void extend(const device_resources& handle, - index[uint64_t]* index, - const uint8_t* new_vectors, - const uint64_t* new_indices, - uint64_t n_rows) except + - - cdef void search(const device_resources& handle, - 
const search_params& params, - const index[uint64_t]& index, - const float* queries, - uint32_t n_queries, - uint32_t k, - uint64_t* neighbors, - float* distances, - device_memory_resource* mr) except + - - cdef void search(const device_resources& handle, - const search_params& params, - const index[uint64_t]& index, - const int8_t* queries, - uint32_t n_queries, - uint32_t k, - uint64_t* neighbors, - float* distances, - device_memory_resource* mr) except + - - cdef void search(const device_resources& handle, - const search_params& params, - const index[uint64_t]& index, - const uint8_t* queries, - uint32_t n_queries, - uint32_t k, - uint64_t* neighbors, - float* distances, - device_memory_resource* mr) except + + cdef void build( + const device_resources& handle, + const index_params& params, + device_matrix_view[float, uint64_t, row_major] dataset, + index[uint64_t]* index) except + + + cdef void build( + const device_resources& handle, + const index_params& params, + device_matrix_view[int8_t, uint64_t, row_major] dataset, + index[uint64_t]* index) except + + + cdef void build( + const device_resources& handle, + const index_params& params, + device_matrix_view[uint8_t, uint64_t, row_major] dataset, + index[uint64_t]* index) except + + + cdef void extend( + const device_resources& handle, + index[uint64_t]* index, + device_matrix_view[float, uint64_t, row_major] new_vectors, + device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + + cdef void extend( + const device_resources& handle, + index[uint64_t]* index, + device_matrix_view[int8_t, uint64_t, row_major] new_vectors, + device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + + cdef void extend( + const device_resources& handle, + index[uint64_t]* index, + device_matrix_view[uint8_t, uint64_t, row_major] new_vectors, + device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + + cdef void search( + const device_resources& handle, + const search_params& 
params, + const index[uint64_t]& index, + device_matrix_view[float, uint64_t, row_major] queries, + uint32_t k, + device_matrix_view[uint64_t, uint64_t, row_major] neighbors, + device_matrix_view[float, uint64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[uint64_t]& index, + device_matrix_view[int8_t, uint64_t, row_major] queries, + uint32_t k, + device_matrix_view[uint64_t, uint64_t, row_major] neighbors, + device_matrix_view[float, uint64_t, row_major] distances) except + + + cdef void search( + const device_resources& handle, + const search_params& params, + const index[uint64_t]& index, + device_matrix_view[uint8_t, uint64_t, row_major] queries, + uint32_t k, + device_matrix_view[uint64_t, uint64_t, row_major] neighbors, + device_matrix_view[float, uint64_t, row_major] distances) except + cdef void serialize(const device_resources& handle, const string& filename, diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 703409500e..47d8e94e5f 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -23,14 +23,7 @@ import warnings import numpy as np from cython.operator cimport dereference as deref -from libc.stdint cimport ( - int8_t, - int64_t, - uint8_t, - uint32_t, - uint64_t, - uintptr_t, -) +from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t, uintptr_t from libcpp cimport bool, nullptr from libcpp.string cimport string @@ -57,6 +50,13 @@ from rmm._lib.memory_resource cimport ( ) cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq +from pylibraft.common.cpp.mdspan cimport device_matrix_view +from pylibraft.common.mdspan cimport ( + get_dmv_float, + get_dmv_int8, + get_dmv_uint8, + get_dmv_uint64, +) from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, search_params, @@ -395,7 +395,6 @@ def 
build(IndexParams index_params, dataset, handle=None): dataset_dt = dataset_cai.dtype _check_input_array(dataset_cai, [np.dtype('float32'), np.dtype('byte'), np.dtype('ubyte')]) - cdef uintptr_t dataset_ptr = dataset_cai.data cdef uint64_t n_rows = dataset_cai.shape[0] cdef uint32_t dim = dataset_cai.shape[1] @@ -411,27 +410,21 @@ def build(IndexParams index_params, dataset, handle=None): with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - dataset_ptr, - n_rows, - dim, + get_dmv_float(dataset_cai, check_shape=True), idx.index) idx.trained = True elif dataset_dt == np.byte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - dataset_ptr, - n_rows, - dim, + get_dmv_int8(dataset_cai, check_shape=True), idx.index) idx.trained = True elif dataset_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.build(deref(handle_), index_params.params, - dataset_ptr, - n_rows, - dim, + get_dmv_uint8(dataset_cai, check_shape=True), idx.index) idx.trained = True else: @@ -523,30 +516,24 @@ def extend(Index index, new_vectors, new_indices, handle=None): if len(idx_cai.shape)!=1: raise ValueError("Indices array is expected to be 1D") - cdef uintptr_t vecs_ptr = vecs_cai.data - cdef uintptr_t idx_ptr = idx_cai.data - if vecs_dt == np.float32: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - vecs_ptr, - idx_ptr, - n_rows) + get_dmv_float(vecs_cai, check_shape=True), + get_dmv_uint64(idx_cai, check_shape=False)) elif vecs_dt == np.int8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - vecs_ptr, - idx_ptr, - n_rows) + get_dmv_int8(vecs_cai, check_shape=True), + get_dmv_uint64(idx_cai, check_shape=False)) elif vecs_dt == np.uint8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, - vecs_ptr, - idx_ptr, - n_rows) + get_dmv_uint8(vecs_cai, check_shape=True), + get_dmv_uint64(idx_cai, check_shape=False)) else: raise TypeError("query dtype %s not supported" % vecs_dt) 
@@ -723,7 +710,6 @@ def search(SearchParams search_params, cdef c_ivf_pq.search_params params = search_params.params - cdef uintptr_t queries_ptr = queries_cai.data cdef uintptr_t neighbors_ptr = neighbors_cai.data cdef uintptr_t distances_ptr = distances_cai.data # TODO(tfeher) pass mr_ptr arg @@ -736,34 +722,28 @@ def search(SearchParams search_params, c_ivf_pq.search(deref(handle_), params, deref(index.index), - queries_ptr, - n_queries, + get_dmv_float(queries_cai, check_shape=True), k, - neighbors_ptr, - distances_ptr, - mr_ptr) + get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.byte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), params, deref(index.index), - queries_ptr, - n_queries, + get_dmv_int8(queries_cai, check_shape=True), k, - neighbors_ptr, - distances_ptr, - mr_ptr) + get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.ubyte: with cuda_interruptible(): c_ivf_pq.search(deref(handle_), params, deref(index.index), - queries_ptr, - n_queries, + get_dmv_uint8(queries_cai, check_shape=True), k, - neighbors_ptr, - distances_ptr, - mr_ptr) + get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True)) else: raise ValueError("query dtype %s not supported" % queries_dt) diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index 5c652f7c73..ddc6f115a3 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -21,14 +21,7 @@ import numpy as np from cython.operator cimport dereference as deref -from libc.stdint cimport ( - int8_t, - int64_t, - uint8_t, - uint32_t, - uint64_t, - uintptr_t, -) +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t from libcpp cimport bool, nullptr from pylibraft.distance.distance_type cimport DistanceType @@ 
-55,10 +48,15 @@ cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq from pylibraft.common.cpp.mdspan cimport ( device_matrix_view, host_matrix_view, - make_device_matrix_view, make_host_matrix_view, row_major, ) +from pylibraft.common.mdspan cimport ( + get_dmv_float, + get_dmv_int8, + get_dmv_uint8, + get_dmv_uint64, +) from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, search_params, @@ -125,50 +123,6 @@ cdef extern from "raft_runtime/neighbors/refine.hpp" \ DistanceType metric) except + -cdef device_matrix_view[float, uint64_t, row_major] \ - get_device_matrix_view_float(array) except *: - cai = cai_wrapper(array) - if cai.dtype != np.float32: - raise TypeError("dtype %s not supported" % cai.dtype) - if len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - return make_device_matrix_view[float, uint64_t, row_major]( - cai.data, cai.shape[0], cai.shape[1]) - - -cdef device_matrix_view[uint64_t, uint64_t, row_major] \ - get_device_matrix_view_uint64(array) except *: - cai = cai_wrapper(array) - if cai.dtype != np.uint64: - raise TypeError("dtype %s not supported" % cai.dtype) - if len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - return make_device_matrix_view[uint64_t, uint64_t, row_major]( - cai.data, cai.shape[0], cai.shape[1]) - - -cdef device_matrix_view[uint8_t, uint64_t, row_major] \ - get_device_matrix_view_uint8(array) except *: - cai = cai_wrapper(array) - if cai.dtype != np.uint8: - raise TypeError("dtype %s not supported" % cai.dtype) - if len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - return make_device_matrix_view[uint8_t, uint64_t, row_major]( - cai.data, cai.shape[0], cai.shape[1]) - - -cdef device_matrix_view[int8_t, uint64_t, row_major] \ - get_device_matrix_view_int8(array) except *: - cai = cai_wrapper(array) - if cai.dtype != np.int8: - raise TypeError("dtype %s not supported" % cai.dtype) - if 
len(cai.shape) != 2: - raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) - return make_device_matrix_view[int8_t, uint64_t, row_major]( - cai.data, cai.shape[0], cai.shape[1]) - - def _get_array_params(array_interface, check_dtype=None): dtype = np.dtype(array_interface["typestr"]) if check_dtype is None and dtype != check_dtype: @@ -309,9 +263,6 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, cdef device_resources* handle_ = \ handle.getHandle() - cdef device_matrix_view[uint64_t, uint64_t, row_major] candidates_view = \ - get_device_matrix_view_uint64(candidates) - if k is None: if indices is not None: k = cai_wrapper(indices).shape[1] @@ -321,6 +272,9 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, raise ValueError("Argument k must be specified if both indices " "and distances arg is None") + queries_cai = cai_wrapper(queries) + dataset_cai = cai_wrapper(dataset) + candidates_cai = cai_wrapper(candidates) n_queries = cai_wrapper(queries).shape[0] if indices is None: @@ -329,36 +283,37 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, if distances is None: distances = device_ndarray.empty((n_queries, k), dtype='float32') - cdef DistanceType c_metric = _get_metric(metric) + indices_cai = cai_wrapper(indices) + distances_cai = cai_wrapper(distances) - dataset_cai = cai_wrapper(dataset) + cdef DistanceType c_metric = _get_metric(metric) if dataset_cai.dtype == np.float32: with cuda_interruptible(): c_refine(deref(handle_), - get_device_matrix_view_float(dataset), - get_device_matrix_view_float(queries), - get_device_matrix_view_uint64(candidates), - get_device_matrix_view_uint64(indices), - get_device_matrix_view_float(distances), + get_dmv_float(dataset_cai, check_shape=True), + get_dmv_float(queries_cai, check_shape=True), + get_dmv_uint64(candidates_cai, check_shape=True), + get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, 
check_shape=True), c_metric) elif dataset_cai.dtype == np.int8: with cuda_interruptible(): c_refine(deref(handle_), - get_device_matrix_view_int8(dataset), - get_device_matrix_view_int8(queries), - get_device_matrix_view_uint64(candidates), - get_device_matrix_view_uint64(indices), - get_device_matrix_view_float(distances), + get_dmv_int8(dataset_cai, check_shape=True), + get_dmv_int8(queries_cai, check_shape=True), + get_dmv_uint64(candidates_cai, check_shape=True), + get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True), c_metric) elif dataset_cai.dtype == np.uint8: with cuda_interruptible(): c_refine(deref(handle_), - get_device_matrix_view_uint8(dataset), - get_device_matrix_view_uint8(queries), - get_device_matrix_view_uint64(candidates), - get_device_matrix_view_uint64(indices), - get_device_matrix_view_float(distances), + get_dmv_uint8(dataset_cai, check_shape=True), + get_dmv_uint8(queries_cai, check_shape=True), + get_dmv_uint64(candidates_cai, check_shape=True), + get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True), c_metric) else: raise TypeError("dtype %s not supported" % dataset_cai.dtype)