From 1c3891e5505762637fc6580d5c5bd6a8280575d3 Mon Sep 17 00:00:00 2001 From: achirkin Date: Sat, 20 May 2023 07:25:31 +0200 Subject: [PATCH] ivf-pq::search: Fix the indexing type of the query mdspans --- cpp/bench/prims/neighbors/knn.cuh | 7 +-- .../neighbors/detail/cagra/cagra_build.cuh | 6 +-- cpp/include/raft/neighbors/ivf_pq-ext.cuh | 48 +++++++++---------- cpp/include/raft/neighbors/ivf_pq-inl.cuh | 14 +++--- .../raft/spatial/knn/detail/ann_quantized.cuh | 6 +-- .../neighbors/ivfpq_search_float_int64_t.cu | 36 +++++++------- .../neighbors/ivfpq_search_int8_t_int64_t.cu | 36 +++++++------- .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 36 +++++++------- cpp/test/neighbors/ann_ivf_pq.cuh | 10 ++-- 9 files changed, 100 insertions(+), 99 deletions(-) diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 8cdb816dab..e580b20fdc 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -181,9 +181,10 @@ struct ivf_pq_knn { { search_params.n_probes = 20; auto queries_view = - raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); - auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); - auto dists_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); + raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); + auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); + auto dists_view = + raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); raft::neighbors::ivf_pq::search( handle, search_params, *index, queries_view, idxs_view, dists_view); } diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index d88aaa245a..693ab9029d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -140,11 +140,11 @@ void build_knn_graph(raft::resources const& res, device_memory); for (const auto& batch : vec_batches) { - auto queries_view = raft::make_device_matrix_view( + auto queries_view = raft::make_device_matrix_view( batch.data(), batch.size(), batch.row_width()); - auto neighbors_view = make_device_matrix_view( + auto neighbors_view = make_device_matrix_view( neighbors.data_handle(), batch.size(), neighbors.extent(1)); - auto distances_view = make_device_matrix_view( + auto distances_view = make_device_matrix_view( distances.data_handle(), batch.size(), distances.extent(1)); ivf_pq::search(res, *search_params, index, queries_view, neighbors_view, distances_view); diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh index f203709b1b..5b7391569b 100644 --- a/cpp/include/raft/neighbors/ivf_pq-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -49,18 +49,18 @@ template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, SampleFilterT sample_filter) RAFT_EXPLICIT; template void search(raft::resources const& handle, const search_params& params, const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; template auto build(raft::resources const& handle, @@ -164,24 +164,24 @@ instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t); #undef instantiate_raft_neighbors_ivf_pq_extend -#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ - extern template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - \ - extern template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ rmm::mr::device_memory_resource* mr) instantiate_raft_neighbors_ivf_pq_search(float, int64_t); diff --git a/cpp/include/raft/neighbors/ivf_pq-inl.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh index e2e60f0cd3..fbe2fcb30d 100644 --- a/cpp/include/raft/neighbors/ivf_pq-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_pq-inl.cuh @@ -162,9 +162,9 @@ template void search_with_filtering(raft::resources const& handle, const search_params& params, const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, SampleFilterT sample_filter = SampleFilterT()) { RAFT_EXPECTS( @@ -182,7 +182,7 @@ void search_with_filtering(raft::resources const& handle, params, idx, queries.data_handle(), - static_cast(queries.extent(0)), + queries.extent(0), k, neighbors.data_handle(), distances.data_handle(), @@ -219,9 +219,9 @@ template void search(raft::resources const& handle, const search_params& params, const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) { search_with_filtering( handle, params, idx, queries, neighbors, distances, detail::NoneSampleFilter()); diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 9f0af8c29e..964292f6cb 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -117,9 +117,9 @@ void approx_knn_search(raft::resources const& handle, params.n_probes = index->nprobe; auto query_view = - raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); - auto indices_view = raft::make_device_matrix_view(indices, n, k); - auto distances_view = raft::make_device_matrix_view(distances, n, k); + raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); + auto indices_view = raft::make_device_matrix_view(indices, n, k); + auto distances_view = raft::make_device_matrix_view(distances, n, k); neighbors::ivf_pq::search( handle, params, *index->ivf_pq, query_view, indices_view, distances_view); } else { diff --git a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu index 2bcbe22501..e56c107735 100644 --- a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu @@ -17,24 +17,24 @@ #include #include // raft::neighbors::ivf_pq::index -#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ - template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - \ - template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ rmm::mr::device_memory_resource* mr) instantiate_raft_neighbors_ivf_pq_search(float, int64_t); diff --git a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu index 74432c1963..1efe4f7fb2 100644 --- a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -17,24 +17,24 @@ #include #include // raft::neighbors::ivf_pq::index -#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ - template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - \ - template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ rmm::mr::device_memory_resource* mr) instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t); diff --git a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu index 8a05263ca0..e746391443 100644 --- a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -17,24 +17,24 @@ #include #include // raft::neighbors::ivf_pq::index -#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ - template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); \ - \ - template void raft::neighbors::ivf_pq::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_pq::search_params& params, \ - const raft::neighbors::ivf_pq::index& idx, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ rmm::mr::device_memory_resource* mr) instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t); diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 9a6e310303..de4453a034 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -405,11 +405,11 @@ class ivf_pq_test : public ::testing::TestWithParam { rmm::device_uvector indices_ivf_pq_dev(queries_size, stream_); auto query_view = - raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); - auto inds_view = - raft::make_device_matrix_view(indices_ivf_pq_dev.data(), ps.num_queries, ps.k); - auto dists_view = - raft::make_device_matrix_view(distances_ivf_pq_dev.data(), ps.num_queries, ps.k); + raft::make_device_matrix_view(search_queries.data(), ps.num_queries, ps.dim); + auto inds_view = raft::make_device_matrix_view( + indices_ivf_pq_dev.data(), ps.num_queries, ps.k); + auto dists_view = raft::make_device_matrix_view( + distances_ivf_pq_dev.data(), ps.num_queries, ps.k); ivf_pq::search( handle_, ps.search_params, index, query_view, inds_view, dists_view);