Skip to content

Commit

Permalink
Making ivf flat gtest invoke mdspanified APIs (#955)
Browse files Browse the repository at this point in the history
I recalled having made this change initially and I'm wondering if it accidentally got reverted since there's been so many hands in the IVF flat code recently. 

For proper end-to-end testing, we need ivf flat testing code to invoke the mdspan APIs (which in turn invoke the non-mdspan APIs).

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #955
  • Loading branch information
cjnolet authored Oct 27, 2022
1 parent af05bcc commit 72a38c6
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 28 deletions.
30 changes: 15 additions & 15 deletions cpp/include/raft/neighbors/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ auto build(const handle_t& handle,
*
* Usage example:
* @code{.cpp}
* using namespace raft::spatial::knn;
* using namespace raft::neighbors;
* ivf_flat::index_params index_params;
* index_params.add_data_on_build = false; // don't populate index on build
* index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
Expand Down Expand Up @@ -164,7 +164,7 @@ auto extend(const handle_t& handle,
*
* Usage example:
* @code{.cpp}
* using namespace raft::spatial::knn;
* using namespace raft::neighbors;
* ivf_flat::index_params index_params;
* index_params.add_data_on_build = false; // don't populate index on build
* index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
Expand Down Expand Up @@ -195,7 +195,7 @@ auto extend(const handle_t& handle,
std::optional<raft::device_vector_view<const idx_t, idx_t>> new_indices = std::nullopt)
-> index<value_t, idx_t>
{
return raft::spatial::knn::ivf_flat::detail::extend<value_t, idx_t>(
return extend<value_t, idx_t>(
handle,
orig_index,
new_vectors.data_handle(),
Expand All @@ -208,7 +208,7 @@ auto extend(const handle_t& handle,
*
* Usage example:
* @code{.cpp}
* using namespace raft::spatial::knn;
* using namespace raft::neighbors;
* ivf_flat::index_params index_params;
* index_params.add_data_on_build = false; // don't populate index on build
* index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
Expand Down Expand Up @@ -244,7 +244,7 @@ void extend(const handle_t& handle,
*
* Usage example:
* @code{.cpp}
* using namespace raft::spatial::knn;
* using namespace raft::neighbors;
* ivf_flat::index_params index_params;
* index_params.add_data_on_build = false; // don't populate index on build
* index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
Expand Down Expand Up @@ -376,7 +376,7 @@ void search(const handle_t& handle,
const index<value_t, idx_t>& index,
raft::device_matrix_view<const value_t, idx_t, row_major> queries,
raft::device_matrix_view<idx_t, idx_t, row_major> neighbors,
raft::device_matrix_view<idx_t, idx_t, float> distances,
raft::device_matrix_view<float, idx_t, row_major> distances,
const search_params& params,
int_t k)
{
Expand All @@ -391,15 +391,15 @@ void search(const handle_t& handle,
RAFT_EXPECTS(queries.extent(1) == index.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

return raft::spatial::knn::ivf_flat::detail::search(handle,
params,
index,
queries.data_handle(),
queries.extent(0),
k,
neighbors.data_handle(),
distances.data_handle(),
nullptr);
return search(handle,
params,
index,
queries.data_handle(),
static_cast<std::uint32_t>(queries.extent(0)),
static_cast<std::uint32_t>(k),
neighbors.data_handle(),
distances.data_handle(),
nullptr);
}

} // namespace raft::neighbors::ivf_flat
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ inline auto build(
*
* Usage example:
* @code{.cpp}
* using namespace raft::spatial::knn;
* using namespace raft::neighbors;
* ivf_pq::index_params index_params;
* index_params.add_data_on_build = false; // don't populate index on build
* index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/spatial/knn/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ using raft::neighbors::ivf_flat::build;
using raft::neighbors::ivf_flat::extend;
using raft::neighbors::ivf_flat::search;

} // namespace raft::spatial::knn::ivf_flat
}; // namespace raft::spatial::knn::ivf_flat
2 changes: 1 addition & 1 deletion cpp/include/raft/spatial/knn/ivf_flat_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ using raft::neighbors::ivf_flat::index_params;
using raft::neighbors::ivf_flat::kIndexGroupSize;
using raft::neighbors::ivf_flat::search_params;

} // namespace raft::spatial::knn::ivf_flat
}; // namespace raft::spatial::knn::ivf_flat
31 changes: 21 additions & 10 deletions cpp/test/neighbors/ann_ivf_flat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,31 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs<IdxT>> {

auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view);

auto new_half_of_data_view = raft::make_device_matrix_view<const DataT, IdxT>(
database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim);

auto new_half_of_data_indices_view = raft::make_device_vector_view<const IdxT, IdxT>(
vector_indices.data() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data);

ivf_flat::extend(handle_,
&index_2,
database.data() + half_of_data * ps.dim,
vector_indices.data() + half_of_data,
IdxT(ps.num_db_vecs) - half_of_data);

new_half_of_data_view,
std::make_optional<raft::device_vector_view<const IdxT, IdxT>>(
new_half_of_data_indices_view));

auto search_queries_view = raft::make_device_matrix_view<const DataT, IdxT>(
search_queries.data(), ps.num_queries, ps.dim);
auto indices_out_view = raft::make_device_matrix_view<IdxT, IdxT>(
indices_ivfflat_dev.data(), ps.num_queries, ps.k);
auto dists_out_view = raft::make_device_matrix_view<T, IdxT>(
distances_ivfflat_dev.data(), ps.num_queries, ps.k);
ivf_flat::search(handle_,
search_params,
index_2,
search_queries.data(),
ps.num_queries,
ps.k,
indices_ivfflat_dev.data(),
distances_ivfflat_dev.data());
search_queries_view,
indices_out_view,
dists_out_view,
search_params,
ps.k);

update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_);
update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_);
Expand Down

0 comments on commit 72a38c6

Please sign in to comment.