Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update and standardize IVF indexes API #1328

Merged
merged 9 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ struct ivf_pq_knn {
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;

auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
}
Expand All @@ -189,13 +188,12 @@ struct ivf_pq_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;

auto queries_view =
raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
auto idxs_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
auto dists_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
raft::neighbors::ivf_pq::search(
handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view);
handle, search_params, *index, queries_view, idxs_view, dists_view);
}
};

Expand Down
416 changes: 197 additions & 219 deletions cpp/include/raft/neighbors/ivf_flat.cuh

Large diffs are not rendered by default.

127 changes: 73 additions & 54 deletions cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,47 +52,48 @@ namespace raft::neighbors::ivf_pq {
* @return the constructed ivf-pq index
*/
template <typename T, typename IdxT = uint32_t>
inline auto build(raft::device_resources const& handle,
index<IdxT> build(raft::device_resources const& handle,
const index_params& params,
raft::device_matrix_view<const T, IdxT, row_major> dataset) -> index<IdxT>
raft::device_matrix_view<const T, IdxT, row_major> dataset)
{
IdxT n_rows = dataset.extent(0);
IdxT dim = dataset.extent(1);
return detail::build(handle, params, dataset.data_handle(), n_rows, dim);
}

/**
* @brief Build a new index containing the data of the original plus new extra vectors.
*
* Implementation note:
* The new data is clustered according to existing kmeans clusters, then the cluster
* centers are unchanged.
*
* @brief Extend the index with the new data.
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param[inout] idx
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
*
* @return the constructed extended ivf-pq index
*/
template <typename T, typename IdxT>
inline auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
index<IdxT> extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices) -> index<IdxT>
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
const index<IdxT>& idx)
{
IdxT n_rows = new_vectors.extent(0);
ASSERT(n_rows == new_indices.extent(0),
"new_vectors and new_indices have different number of rows");
ASSERT(new_vectors.extent(1) == orig_index.dim(),
ASSERT(new_vectors.extent(1) == idx.dim(),
"new_vectors should have the same dimension as the index");
return detail::extend(
handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows);

IdxT n_rows = new_vectors.extent(0);
if (new_indices.has_value()) {
ASSERT(n_rows == new_indices.value().extent(0),
"new_vectors and new_indices have different number of rows");
}

return detail::extend(handle,
idx,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
}

/**
Expand All @@ -102,19 +103,32 @@ inline auto extend(raft::device_resources const& handle,
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[inout] idx
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx->dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx->size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
*/
template <typename T, typename IdxT>
inline void extend(raft::device_resources const& handle,
index<IdxT>* index,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices)
void extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
index<IdxT>* idx)
{
*index = extend(handle, *index, new_vectors, new_indices);
ASSERT(new_vectors.extent(1) == idx->dim(),
"new_vectors should have the same dimension as the index");

IdxT n_rows = new_vectors.extent(0);
if (new_indices.has_value()) {
ASSERT(n_rows == new_indices.value().extent(0),
"new_vectors and new_indices have different number of rows");
}

*idx = detail::extend(handle,
*idx,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
}

/**
Expand All @@ -135,32 +149,37 @@ inline void extend(raft::device_resources const& handle,
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param idx ivf-pq constructed index
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
* @param[in] queries a device matrix view to a row-major matrix [n_queries, idx.dim()]
* @param k the number of neighbors to find for each query.
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
inline void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
raft::device_matrix_view<const T, IdxT, row_major> queries,
uint32_t k,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
{
IdxT n_queries = queries.extent(0);
bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0));
ASSERT(check_n_rows,
"queries, neighbors and distances parameters have inconsistent number of rows");
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");

RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

std::uint32_t k = neighbors.extent(1);
return detail::search(handle,
params,
index,
idx,
queries.data_handle(),
n_queries,
static_cast<std::uint32_t>(queries.extent(0)),
k,
neighbors.data_handle(),
distances.data_handle(),
Expand Down Expand Up @@ -234,23 +253,23 @@ auto build(raft::device_resources const& handle,
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param idx original index
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param n_rows the number of samples
*
* @return the constructed extended ivf-pq index
*/
template <typename T, typename IdxT>
auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
const index<IdxT>& idx,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows) -> index<IdxT>
{
return detail::extend(handle, orig_index, new_vectors, new_indices, n_rows);
return detail::extend(handle, idx, new_vectors, new_indices, n_rows);
}

/**
Expand All @@ -260,21 +279,21 @@ auto extend(raft::device_resources const& handle,
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[inout] idx
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx->dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx->size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param n_rows the number of samples
*/
template <typename T, typename IdxT>
void extend(raft::device_resources const& handle,
index<IdxT>* index,
index<IdxT>* idx,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows)
{
detail::extend(handle, index, new_vectors, new_indices, n_rows);
detail::extend(handle, idx, new_vectors, new_indices, n_rows);
}

/**
Expand Down Expand Up @@ -309,7 +328,7 @@ void extend(raft::device_resources const& handle,
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param idx ivf-pq constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, idx.dim()]
* @param n_queries the batch size
* @param k the number of neighbors to find for each query.
Expand All @@ -322,15 +341,15 @@ void extend(raft::device_resources const& handle,
template <typename T, typename IdxT>
void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
const index<IdxT>& idx,
const T* queries,
uint32_t n_queries,
uint32_t k,
IdxT* neighbors,
float* distances,
rmm::mr::device_memory_resource* mr = nullptr)
{
return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr);
return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr);
}

} // namespace raft::neighbors::ivf_pq
2 changes: 0 additions & 2 deletions cpp/include/raft/neighbors/ivf_pq_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ namespace raft::neighbors::ivf_pq {
* @param[in] os output stream
* @param[in] index IVF-PQ index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<IdxT>& index)
Expand Down Expand Up @@ -77,7 +76,6 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind
* @param[in] filename the file name for saving the index
* @param[in] index IVF-PQ index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle,
Expand Down
81 changes: 49 additions & 32 deletions cpp/include/raft/neighbors/specializations/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,54 @@

namespace raft::neighbors::ivf_pq {

#define RAFT_INST(T, IdxT) \
extern template auto build<T, IdxT>(raft::device_resources const& handle, \
const index_params& params, \
const T* dataset, \
IdxT n_rows, \
uint32_t dim) \
->index<IdxT>; \
extern template auto extend<T, IdxT>(raft::device_resources const& handle, \
const index<IdxT>& orig_index, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
->index<IdxT>; \
extern template void extend<T, IdxT>(raft::device_resources const& handle, \
index<IdxT>* index, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows); \
extern template void search<T, IdxT>(raft::device_resources const&, \
const search_params&, \
const index<IdxT>&, \
const T*, \
uint32_t, \
uint32_t, \
IdxT*, \
float*, \
rmm::mr::device_memory_resource*);
RAFT_INST(float, int64_t);
RAFT_INST(int8_t, int64_t);
RAFT_INST(uint8_t, int64_t);

#undef RAFT_INST
#ifdef RAFT_DECL_BUILD_EXTEND
#undef RAFT_DECL_BUILD_EXTEND
#endif

#ifdef RAFT_DECL_SEARCH
#undef RAFT_DECL_SEARCH
#endif

// Explicit-instantiation declarations (`extern template`) of the device_matrix_view-based
// build/extend API for one (T, IdxT) combination:
//   - build:  returns `index<IdxT>`;
//   - extend: takes the existing index by const reference and returns `index<IdxT>`;
//   - extend: void overload that takes the index by pointer.
// The original note stated that overloads with a void return type are defined for use in the
// Cython wrappers, where exception handling is not compatible with a return type that has a
// nontrivial constructor.
// NOTE(review): in this macro only `extend` has a void overload (`build` returns an index) —
// confirm whether that note belongs in this header.
#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \
extern template auto build(raft::device_resources const&, \
const raft::neighbors::ivf_pq::index_params&, \
raft::device_matrix_view<const T, IdxT, row_major>) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template auto extend( \
raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>>, \
const raft::neighbors::ivf_pq::index<IdxT>&) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template void extend( \
raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>>, \
raft::neighbors::ivf_pq::index<IdxT>*);

RAFT_DECL_BUILD_EXTEND(float, int64_t)
RAFT_DECL_BUILD_EXTEND(int8_t, int64_t)
RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t)

#undef RAFT_DECL_BUILD_EXTEND

// Explicit-instantiation declaration (`extern template`) of the device_matrix_view-based
// `search` overload for one (T, IdxT) combination. The declared signature takes the queries,
// neighbors and distances views and has no separate `k` argument.
#define RAFT_DECL_SEARCH(T, IdxT) \
extern template void search(raft::device_resources const&, \
const raft::neighbors::ivf_pq::search_params&, \
const raft::neighbors::ivf_pq::index<IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_DECL_SEARCH(float, int64_t);
RAFT_DECL_SEARCH(int8_t, int64_t);
RAFT_DECL_SEARCH(uint8_t, int64_t);

#undef RAFT_DECL_SEARCH

} // namespace raft::neighbors::ivf_pq
2 changes: 1 addition & 1 deletion cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ void approx_knn_search(raft::device_resources const& handle,
auto indices_view = raft::make_device_matrix_view<int64_t, int64_t>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, n, k);
neighbors::ivf_pq::search(
handle, params, *index->ivf_pq, query_view, k, indices_view, distances_view);
handle, params, *index->ivf_pq, query_view, indices_view, distances_view);
} else {
RAFT_FAIL("The model is not trained");
}
Expand Down
Loading