Skip to content

Commit

Permalink
Update and standardize IVF indexes API (#1328)
Browse files Browse the repository at this point in the history
Update and standardize IVF indexes API + edits on specializations

Authors:
  - Victor Lafargue (https://github.com/viclafargue)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: #1328
  • Loading branch information
viclafargue authored Mar 16, 2023
1 parent fb84190 commit 9f2a64f
Show file tree
Hide file tree
Showing 23 changed files with 378 additions and 289 deletions.
4 changes: 1 addition & 3 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ struct ivf_pq_knn {
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;

auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
}
Expand All @@ -189,13 +188,12 @@ struct ivf_pq_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;

auto queries_view =
raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
auto idxs_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
auto dists_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
raft::neighbors::ivf_pq::search(
handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view);
handle, search_params, *index, queries_view, idxs_view, dists_view);
}
};

Expand Down
175 changes: 97 additions & 78 deletions cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,54 +45,55 @@ namespace raft::neighbors::ivf_pq {
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param params configure the index building
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim]
*
* @return the constructed ivf-pq index
*/
template <typename T, typename IdxT = uint32_t>
inline auto build(raft::device_resources const& handle,
index<IdxT> build(raft::device_resources const& handle,
const index_params& params,
raft::device_matrix_view<const T, IdxT, row_major> dataset) -> index<IdxT>
raft::device_matrix_view<const T, IdxT, row_major> dataset)
{
IdxT n_rows = dataset.extent(0);
IdxT dim = dataset.extent(1);
return detail::build(handle, params, dataset.data_handle(), n_rows, dim);
}

/**
* @brief Build a new index containing the data of the original plus new extra vectors.
*
* Implementation note:
* The new data is clustered according to existing kmeans clusters, then the cluster
* centers are unchanged.
*
* @brief Extend the index with the new data.
* *
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()]
* @param[in] handle
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
*
* @return the constructed extended ivf-pq index
* @param[inout] idx
*/
template <typename T, typename IdxT>
inline auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
index<IdxT> extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices) -> index<IdxT>
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
const index<IdxT>& idx)
{
IdxT n_rows = new_vectors.extent(0);
ASSERT(n_rows == new_indices.extent(0),
"new_vectors and new_indices have different number of rows");
ASSERT(new_vectors.extent(1) == orig_index.dim(),
ASSERT(new_vectors.extent(1) == idx.dim(),
"new_vectors should have the same dimension as the index");
return detail::extend(
handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows);

IdxT n_rows = new_vectors.extent(0);
if (new_indices.has_value()) {
ASSERT(n_rows == new_indices.value().extent(0),
"new_vectors and new_indices have different number of rows");
}

return detail::extend(handle,
idx,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
}

/**
Expand All @@ -101,20 +102,33 @@ inline auto extend(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()]
* @param[in] handle
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
* @param[inout] idx
*/
template <typename T, typename IdxT>
inline void extend(raft::device_resources const& handle,
index<IdxT>* index,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices)
void extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
index<IdxT>* idx)
{
*index = extend(handle, *index, new_vectors, new_indices);
ASSERT(new_vectors.extent(1) == idx->dim(),
"new_vectors should have the same dimension as the index");

IdxT n_rows = new_vectors.extent(0);
if (new_indices.has_value()) {
ASSERT(n_rows == new_indices.value().extent(0),
"new_vectors and new_indices have different number of rows");
}

*idx = detail::extend(handle,
*idx,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
}

/**
Expand All @@ -133,34 +147,39 @@ inline void extend(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param[in] handle
* @param[in] params configure the search
* @param[in] idx ivf-pq constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param k the number of neighbors to find for each query.
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
inline void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
raft::device_matrix_view<const T, IdxT, row_major> queries,
uint32_t k,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
{
IdxT n_queries = queries.extent(0);
bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0));
ASSERT(check_n_rows,
"queries, neighbors and distances parameters have inconsistent number of rows");
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");

RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

std::uint32_t k = neighbors.extent(1);
return detail::search(handle,
params,
index,
idx,
queries.data_handle(),
n_queries,
static_cast<std::uint32_t>(queries.extent(0)),
k,
neighbors.data_handle(),
distances.data_handle(),
Expand Down Expand Up @@ -193,11 +212,11 @@ inline void search(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param params configure the index building
* @param[in] handle
* @param[in] params configure the index building
* @param[in] dataset a device/host pointer to a row-major matrix [n_rows, dim]
* @param n_rows the number of samples
* @param dim the dimensionality of the data
* @param[in] n_rows the number of samples
* @param[in] dim the dimensionality of the data
*
* @return the constructed ivf-pq index
*/
Expand Down Expand Up @@ -233,24 +252,24 @@ auto build(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] handle
* @param[inout] idx original index
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param n_rows the number of samples
* @param[in] n_rows the number of samples
*
* @return the constructed extended ivf-pq index
*/
template <typename T, typename IdxT>
auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
const index<IdxT>& idx,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows) -> index<IdxT>
{
return detail::extend(handle, orig_index, new_vectors, new_indices, n_rows);
return detail::extend(handle, idx, new_vectors, new_indices, n_rows);
}

/**
Expand All @@ -259,22 +278,22 @@ auto extend(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] handle
* @param[inout] idx
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param n_rows the number of samples
* @param[in] n_rows the number of samples
*/
template <typename T, typename IdxT>
void extend(raft::device_resources const& handle,
index<IdxT>* index,
index<IdxT>* idx,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows)
{
detail::extend(handle, index, new_vectors, new_indices, n_rows);
detail::extend(handle, idx, new_vectors, new_indices, n_rows);
}

/**
Expand Down Expand Up @@ -307,30 +326,30 @@ void extend(raft::device_resources const& handle,
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param[in] handle
* @param[in] params configure the search
* @param[in] idx ivf-pq constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param n_queries the batch size
* @param k the number of neighbors to find for each query.
* @param[in] n_queries the batch size
* @param[in] k the number of neighbors to find for each query.
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param mr an optional memory resource to use across the searches (you can provide a large enough
* memory pool here to avoid memory allocations within search).
* @param[in] mr an optional memory resource to use across the searches (you can provide a large
* enough memory pool here to avoid memory allocations within search).
*/
template <typename T, typename IdxT>
void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
const index<IdxT>& idx,
const T* queries,
uint32_t n_queries,
uint32_t k,
IdxT* neighbors,
float* distances,
rmm::mr::device_memory_resource* mr = nullptr)
{
return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr);
return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr);
}

} // namespace raft::neighbors::ivf_pq
2 changes: 0 additions & 2 deletions cpp/include/raft/neighbors/ivf_pq_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ namespace raft::neighbors::ivf_pq {
* @param[in] os output stream
* @param[in] index IVF-PQ index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<IdxT>& index)
Expand Down Expand Up @@ -77,7 +76,6 @@ void serialize(raft::device_resources const& handle, std::ostream& os, const ind
* @param[in] filename the file name for saving the index
* @param[in] index IVF-PQ index
*
* @return raft::neighbors::ivf_pq::index<IdxT>
*/
template <typename IdxT>
void serialize(raft::device_resources const& handle,
Expand Down
Loading

0 comments on commit 9f2a64f

Please sign in to comment.