Skip to content

Commit

Permalink
addressing reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Mar 14, 2023
1 parent 78ea301 commit 6e34d8a
Show file tree
Hide file tree
Showing 23 changed files with 444 additions and 499 deletions.
10 changes: 7 additions & 3 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ struct ivf_pq_knn {
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
index.emplace(raft::neighbors::ivf_pq::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
auto data_view = raft::make_device_matrix_view<const ValT, IdxT>(data, ps.n_samples, ps.n_dims);
index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view));
}

void search(const raft::device_resources& handle,
Expand All @@ -188,8 +188,12 @@ struct ivf_pq_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;
auto queries_view =
raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
auto idxs_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
auto dists_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
raft::neighbors::ivf_pq::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
handle, search_params, *index, queries_view, idxs_view, dists_view);
}
};

Expand Down
410 changes: 196 additions & 214 deletions cpp/include/raft/neighbors/ivf_flat.cuh

Large diffs are not rendered by default.

124 changes: 44 additions & 80 deletions cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,77 +46,41 @@ namespace raft::neighbors::ivf_pq {
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim]
* @param params configure the index building
* @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim]
*
* @return the constructed ivf-pq index
*/
template <typename T, typename IdxT = uint32_t>
auto build(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> dataset,
const index_params& params) -> index<IdxT>
const index_params& params,
raft::device_matrix_view<const T, IdxT, row_major> dataset) -> index<IdxT>
{
IdxT n_rows = dataset.extent(0);
IdxT dim = dataset.extent(1);
return detail::build(handle, params, dataset.data_handle(), n_rows, dim);
}

/**
* @brief Build the index from the dataset for efficient search.
*
* NB: Currently, the following distance metrics are supported:
* - L2Expanded
* - L2Unexpanded
* - InnerProduct
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[in] dataset a device matrix view to a row-major matrix [n_rows, dim]
* @param params configure the index building
*
* @return the constructed ivf-pq index
*/
template <typename T, typename IdxT = uint32_t>
void build(raft::device_resources const& handle,
index<IdxT>* index,
raft::device_matrix_view<const T, IdxT, row_major> dataset,
const index_params& params)
{
IdxT n_rows = dataset.extent(0);
IdxT dim = dataset.extent(1);
*index = detail::build(handle, params, dataset.data_handle(), n_rows, dim);
}

/**
* @brief Build a new index containing the data of the original plus new extra vectors.
*
* Implementation note:
* The new data is clustered according to existing kmeans clusters, then the cluster
* centers are unchanged.
*
* @brief Extend the index with the new data.
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param[inout] idx
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
*
* @return the constructed extended ivf-pq index
*/
template <typename T, typename IdxT>
auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices =
std::nullopt) -> index<IdxT>
index<IdxT> extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
const index<IdxT>& idx)
{
ASSERT(new_vectors.extent(1) == orig_index.dim(),
ASSERT(new_vectors.extent(1) == idx.dim(),
"new_vectors should have the same dimension as the index");

IdxT n_rows = new_vectors.extent(0);
Expand All @@ -126,7 +90,7 @@ auto extend(raft::device_resources const& handle,
}

return detail::extend(handle,
orig_index,
idx,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
Expand All @@ -139,20 +103,19 @@ auto extend(raft::device_resources const& handle,
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[inout] idx
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx->dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx->size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
*/
template <typename T, typename IdxT>
void extend(
raft::device_resources const& handle,
index<IdxT>* index,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices = std::nullopt)
void extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
index<IdxT>* idx)
{
ASSERT(new_vectors.extent(1) == index->dim(),
ASSERT(new_vectors.extent(1) == idx->dim(),
"new_vectors should have the same dimension as the index");

IdxT n_rows = new_vectors.extent(0);
Expand All @@ -161,11 +124,11 @@ void extend(
"new_vectors and new_indices have different number of rows");
}

*index = extend(handle,
*index,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
*idx = detail::extend(handle,
*idx,
new_vectors.data_handle(),
new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
n_rows);
}

/**
Expand All @@ -185,21 +148,21 @@ void extend(
* @tparam IdxT type of the indices
*
* @param handle
* @param index ivf-pq constructed index
* @param params configure the search
* @param idx ivf-pq constructed index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, idx.dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param params configure the search
*/
template <typename T, typename IdxT>
void search(raft::device_resources const& handle,
const index<IdxT>& index,
const search_params& params,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances,
const search_params& params)
raft::device_matrix_view<float, IdxT, row_major> distances)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
Expand All @@ -208,13 +171,13 @@ void search(raft::device_resources const& handle,
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");

RAFT_EXPECTS(queries.extent(1) == index.dim(),
RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

std::uint32_t k = neighbors.extent(1);
return detail::search(handle,
params,
index,
idx,
queries.data_handle(),
static_cast<std::uint32_t>(queries.extent(0)),
k,
Expand Down Expand Up @@ -290,23 +253,23 @@ auto build(raft::device_resources const& handle,
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param orig_index original index
* @param idx original index
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx.size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param n_rows the number of samples
*
* @return the constructed extended ivf-pq index
*/
template <typename T, typename IdxT>
auto extend(raft::device_resources const& handle,
const index<IdxT>& orig_index,
const index<IdxT>& idx,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows) -> index<IdxT>
{
return detail::extend(handle, orig_index, new_vectors, new_indices, n_rows);
return detail::extend(handle, idx, new_vectors, new_indices, n_rows);
}

/**
Expand All @@ -316,21 +279,22 @@ auto extend(raft::device_resources const& handle,
* @tparam IdxT type of the indices in the source dataset
*
* @param handle
* @param[inout] index
* @param[inout] idx
* @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, idx->dim()]
* @param[in] new_indices a device/host pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* If the original index is empty (`idx->size() == 0`), you can pass `nullptr`
* here to imply a continuous range `[0...n_rows)`.
* @param n_rows the number of samples
*/
template <typename T, typename IdxT>
void extend(raft::device_resources const& handle,
index<IdxT>* index,
index<IdxT>* idx,
const T* new_vectors,
const IdxT* new_indices,
IdxT n_rows)
{
detail::extend(handle, index, new_vectors, new_indices, n_rows);
detail::extend(handle, idx, new_vectors, new_indices, n_rows);
}

/**
Expand Down Expand Up @@ -365,7 +329,7 @@ void extend(raft::device_resources const& handle,
*
* @param handle
* @param params configure the search
* @param index ivf-pq constructed index
* @param idx ivf-pq constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, idx.dim()]
* @param n_queries the batch size
* @param k the number of neighbors to find for each query.
Expand All @@ -378,15 +342,15 @@ void extend(raft::device_resources const& handle,
template <typename T, typename IdxT>
void search(raft::device_resources const& handle,
const search_params& params,
const index<IdxT>& index,
const index<IdxT>& idx,
const T* queries,
uint32_t n_queries,
uint32_t k,
IdxT* neighbors,
float* distances,
rmm::mr::device_memory_resource* mr = nullptr)
{
return detail::search(handle, params, index, queries, n_queries, k, neighbors, distances, mr);
return detail::search(handle, params, idx, queries, n_queries, k, neighbors, distances, mr);
}

} // namespace raft::neighbors::ivf_pq
55 changes: 25 additions & 30 deletions cpp/include/raft/neighbors/specializations/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,38 @@ namespace raft::neighbors::ivf_pq {
// We define overloads for build and extend with void return type. This is used in the Cython
// wrappers, where exception handling is not compatible with return type that has nontrivial
// constructor.
#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \
extern template auto build<T, IdxT>(raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
const raft::neighbors::ivf_pq::index_params&) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template void build<T, IdxT>(raft::device_resources const&, \
index<IdxT>*, \
raft::device_matrix_view<const T, IdxT, row_major>, \
const raft::neighbors::ivf_pq::index_params&); \
\
extern template auto extend<T, IdxT>( \
raft::device_resources const&, \
const index<IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>>) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template void extend<T, IdxT>( \
raft::device_resources const&, \
index<IdxT>*, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>>);
#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \
extern template auto build(raft::device_resources const&, \
const raft::neighbors::ivf_pq::index_params&, \
raft::device_matrix_view<const T, IdxT, row_major>) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template auto extend( \
raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>>, \
const raft::neighbors::ivf_pq::index<IdxT>&) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template void extend( \
raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>>, \
raft::neighbors::ivf_pq::index<IdxT>*);

RAFT_DECL_BUILD_EXTEND(float, int64_t)
RAFT_DECL_BUILD_EXTEND(int8_t, int64_t)
RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t)

#undef RAFT_DECL_BUILD_EXTEND

#define RAFT_DECL_SEARCH(T, IdxT) \
extern template void search<T, IdxT>(raft::device_resources const&, \
const index<IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>, \
const search_params&);
#define RAFT_DECL_SEARCH(T, IdxT) \
extern template void search(raft::device_resources const&, \
const raft::neighbors::ivf_pq::search_params&, \
const raft::neighbors::ivf_pq::index<IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_DECL_SEARCH(float, int64_t);
RAFT_DECL_SEARCH(int8_t, int64_t);
Expand Down
9 changes: 7 additions & 2 deletions cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void approx_knn_build_index(raft::device_resources const& handle,

auto index_view = raft::make_device_matrix_view<const T, int64_t>(index_array, n, D);
index->ivf_pq = std::make_unique<const neighbors::ivf_pq::index<int64_t>>(
neighbors::ivf_pq::build(handle, index_view, params));
neighbors::ivf_pq::build(handle, params, index_view));
} else {
RAFT_FAIL("Unrecognized index type.");
}
Expand Down Expand Up @@ -114,8 +114,13 @@ void approx_knn_search(raft::device_resources const& handle,
} else if (index->ivf_pq) {
neighbors::ivf_pq::search_params params;
params.n_probes = index->nprobe;

auto query_view =
raft::make_device_matrix_view<const T, int64_t>(query_array, n, index->ivf_pq->dim());
auto indices_view = raft::make_device_matrix_view<int64_t, int64_t>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, n, k);
neighbors::ivf_pq::search(
handle, params, *index->ivf_pq, query_array, n, k, indices, distances);
handle, params, *index->ivf_pq, query_view, indices_view, distances_view);
} else {
RAFT_FAIL("The model is not trained");
}
Expand Down
Loading

0 comments on commit 6e34d8a

Please sign in to comment.