diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh index 87400a9b93..5317f406e1 100644 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -121,7 +121,7 @@ auto build(const handle_t& handle, * * Usage example: * @code{.cpp} - * using namespace raft::spatial::knn; + * using namespace raft::neighbors; * ivf_flat::index_params index_params; * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training @@ -164,7 +164,7 @@ auto extend(const handle_t& handle, * * Usage example: * @code{.cpp} - * using namespace raft::spatial::knn; + * using namespace raft::neighbors; * ivf_flat::index_params index_params; * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training @@ -195,7 +195,7 @@ auto extend(const handle_t& handle, std::optional> new_indices = std::nullopt) -> index { - return raft::spatial::knn::ivf_flat::detail::extend( + return extend( handle, orig_index, new_vectors.data_handle(), @@ -208,7 +208,7 @@ auto extend(const handle_t& handle, * * Usage example: * @code{.cpp} - * using namespace raft::spatial::knn; + * using namespace raft::neighbors; * ivf_flat::index_params index_params; * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training @@ -244,7 +244,7 @@ void extend(const handle_t& handle, * * Usage example: * @code{.cpp} - * using namespace raft::spatial::knn; + * using namespace raft::neighbors; * ivf_flat::index_params index_params; * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training @@ -376,7 +376,7 @@ void search(const handle_t& handle, const index& index, raft::device_matrix_view queries, raft::device_matrix_view neighbors, - raft::device_matrix_view distances, + raft::device_matrix_view distances, const search_params& params, int_t k) { @@ -391,15 +391,15 @@ void search(const handle_t& handle, RAFT_EXPECTS(queries.extent(1) == index.dim(), "Number of query dimensions should equal number of dimensions in the index."); - return raft::spatial::knn::ivf_flat::detail::search(handle, - params, - index, - queries.data_handle(), - queries.extent(0), - k, - neighbors.data_handle(), - distances.data_handle(), - nullptr); + return search(handle, + params, + index, + queries.data_handle(), + static_cast(queries.extent(0)), + static_cast(k), + neighbors.data_handle(), + distances.data_handle(), + nullptr); } } // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 5d619c5bec..207e298947 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -76,7 +76,7 @@ inline auto build( * * Usage example: * @code{.cpp} - * using namespace raft::spatial::knn; + * using namespace raft::neighbors; * ivf_pq::index_params index_params; * index_params.add_data_on_build = false; // don't populate index on build * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index 65b6f5ed4b..92fe49be98 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -36,4 +36,4 @@ using raft::neighbors::ivf_flat::build; using raft::neighbors::ivf_flat::extend; using raft::neighbors::ivf_flat::search; -} // namespace raft::spatial::knn::ivf_flat +}; // namespace raft::spatial::knn::ivf_flat diff --git a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp index 2db29eeb58..75d777573f 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp +++ b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp @@ -37,4 +37,4 @@ using raft::neighbors::ivf_flat::index_params; using raft::neighbors::ivf_flat::kIndexGroupSize; using raft::neighbors::ivf_flat::search_params; -} // namespace raft::spatial::knn::ivf_flat +}; // namespace raft::spatial::knn::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 3a5daff4bb..9a430e14f2 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -169,20 +169,31 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); + auto new_half_of_data_view = raft::make_device_matrix_view( + database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); + + auto new_half_of_data_indices_view = raft::make_device_vector_view( + vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); + ivf_flat::extend(handle_, &index_2, - database.data() + half_of_data * ps.dim, - vector_indices.data() + half_of_data, - IdxT(ps.num_db_vecs) - half_of_data); - + new_half_of_data_view, + std::make_optional>( + new_half_of_data_indices_view)); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = raft::make_device_matrix_view( + indices_ivfflat_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_ivfflat_dev.data(), ps.num_queries, ps.k); ivf_flat::search(handle_, - search_params, index_2, - search_queries.data(), - ps.num_queries, - ps.k, - indices_ivfflat_dev.data(), - distances_ivfflat_dev.data()); + search_queries_view, + indices_out_view, + dists_out_view, + search_params, + ps.k); update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_);