Skip to content

Commit

Permalink
mdspan view for IVF-PQ API (#1236)
Browse files Browse the repository at this point in the history
Answers #1209

Authors:
  - Victor Lafargue (https://github.com/viclafargue)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1236
  • Loading branch information
viclafargue authored Mar 9, 2023
1 parent cf083e4 commit 8a22373
Show file tree
Hide file tree
Showing 24 changed files with 549 additions and 406 deletions.
12 changes: 9 additions & 3 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ struct ivf_pq_knn {
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index.emplace(raft::neighbors::ivf_pq::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));

auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
}

void search(const raft::device_resources& handle,
Expand All @@ -189,8 +190,13 @@ struct ivf_pq_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;

auto queries_view =
raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
auto idxs_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
auto dists_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
raft::neighbors::ivf_pq::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
handle, search_params, *index, queries_view, ps.k, idxs_view, dists_view);
}
};

Expand Down
137 changes: 135 additions & 2 deletions cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,141 @@ namespace raft::neighbors::ivf_pq {
* @{
*/

/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param params configure the index building
* @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim]
*
* @return the constructed ivf-pq index
*/
template <typename T, typename IdxT = uint32_t>
inline auto build(raft::device_resources const& handle,
const index_params& params,
raft::device_matrix_view<const T, IdxT, row_major> dataset) -> index<IdxT>
{
IdxT n_rows = dataset.extent(0);
IdxT dim = dataset.extent(1);
return detail::build(handle, params, dataset.data_handle(), n_rows, dim);
}

/**
* @brief Build a new index containing the data of the original plus new extra vectors.
*
* Implementation note:
* The new data is clustered according to existing kmeans clusters, then the cluster
* centers are unchanged.
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
*
* @return the constructed extended ivf-pq index
*/
template <typename T, typename IdxT>
inline auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices) -> index<IdxT>
{
IdxT n_rows = new_vectors.extent(0);
ASSERT(n_rows == new_indices.extent(0),
"new_vectors and new_indices have different number of rows");
ASSERT(new_vectors.extent(1) == orig_index.dim(),
"new_vectors should have the same dimension as the index");
return detail::extend(
handle, orig_index, new_vectors.data_handle(), new_indices.data_handle(), n_rows);
}

/**
* @brief Extend the index with the new data.
* *
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
*/
template <typename T, typename IdxT>
inline void extend(raft::device_resources const& handle,
index<IdxT>* index,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices)
{
*index = extend(handle, *index, new_vectors, new_indices);
}

/**
* @brief Search ANN using the constructed index.
*
* See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
* pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
* eliminate entirely allocations happening within `search`.
* The exact size of the temporary buffer depends on multiple factors and is an implementation
* detail. However, you can safely specify a small initial size for the memory pool, so that only a
* few allocations happen to grow it during the first invocations of the `search`.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param k the number of neighbors to find for each query.
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
inline void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
raft::device_matrix_view<const T, IdxT, row_major> queries,
uint32_t k,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
{
IdxT n_queries = queries.extent(0);
bool check_n_rows = (n_queries == neighbors.extent(0)) && (n_queries == distances.extent(0));
ASSERT(check_n_rows,
"queries, neighbors and distances parameters have inconsistent number of rows");
return detail::search(handle,
params,
index,
queries.data_handle(),
n_queries,
k,
neighbors.data_handle(),
distances.data_handle(),
handle.get_workspace_resource());
}

/** @} */ // end group ivf_pq

/**
* @brief Build the index from the dataset for efficient search.
*
Expand Down Expand Up @@ -197,6 +332,4 @@ void search(raft::device_resources const& handle,
return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr);
}

/** @} */ // end group ivf_pq

} // namespace raft::neighbors::ivf_pq
13 changes: 10 additions & 3 deletions cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ void approx_knn_build_index(raft::device_resources const& handle,
params.pq_bits = ivf_pq_pams->n_bits;
params.pq_dim = ivf_pq_pams->M;
// TODO: handle ivf_pq_pams.usePrecomputedTables ?
index->ivf_pq = std::make_unique<const neighbors::ivf_pq::index<int64_t>>(
neighbors::ivf_pq::build(handle, params, index_array, int64_t(n), D));

auto index_view = raft::make_device_matrix_view<const T, IntType>(index_array, n, D);
index->ivf_pq = std::make_unique<const neighbors::ivf_pq::index<int64_t>>(
neighbors::ivf_pq::build(handle, params, index_view));
} else {
RAFT_FAIL("Unrecognized index type.");
}
Expand Down Expand Up @@ -110,8 +112,13 @@ void approx_knn_search(raft::device_resources const& handle,
} else if (index->ivf_pq) {
neighbors::ivf_pq::search_params params;
params.n_probes = index->nprobe;

auto query_view =
raft::make_device_matrix_view<const T, IntType>(query_array, n, index->ivf_pq->dim());
auto indices_view = raft::make_device_matrix_view<IntType, IntType>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, IntType>(distances, n, k);
neighbors::ivf_pq::search(
handle, params, *index->ivf_pq, query_array, n, k, indices, distances);
handle, params, *index->ivf_pq, query_view, k, indices_view, distances_view);
} else {
RAFT_FAIL("The model is not trained");
}
Expand Down
66 changes: 29 additions & 37 deletions cpp/include/raft_runtime/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@

namespace raft::runtime::neighbors::ivf_pq {

#define RAFT_INST_SEARCH(T, IdxT) \
void search(raft::device_resources const&, \
const raft::neighbors::ivf_pq::search_params&, \
const raft::neighbors::ivf_pq::index<IdxT>&, \
const T*, \
uint32_t, \
uint32_t, \
IdxT*, \
float*, \
rmm::mr::device_memory_resource*);
#define RAFT_INST_SEARCH(T, IdxT) \
void search(raft::device_resources const&, \
const raft::neighbors::ivf_pq::search_params&, \
const raft::neighbors::ivf_pq::index<IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
uint32_t, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_INST_SEARCH(float, uint64_t);
RAFT_INST_SEARCH(int8_t, uint64_t);
Expand All @@ -40,33 +38,27 @@ RAFT_INST_SEARCH(uint8_t, uint64_t);
// We define overloads for build and extend with void return type. This is used in the Cython
// wrappers, where exception handling is not compatible with return type that has nontrivial
// constructor.
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
const T* dataset, \
IdxT n_rows, \
uint32_t dim) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
auto extend(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index<IdxT>& orig_index, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
const T* dataset, \
IdxT n_rows, \
uint32_t dim, \
raft::neighbors::ivf_pq::index<IdxT>* idx); \
\
void extend(raft::device_resources const& handle, \
raft::neighbors::ivf_pq::index<IdxT>* idx, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows);
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
auto extend(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index<IdxT>& orig_index, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_pq::index<IdxT>* idx); \
\
void extend(raft::device_resources const& handle, \
raft::neighbors::ivf_pq::index<IdxT>* idx, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices);

RAFT_INST_BUILD_EXTEND(float, uint64_t)
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t)
Expand Down
Loading

0 comments on commit 8a22373

Please sign in to comment.