Skip to content

Commit

Permalink
ivf-pq::search: fix the indexing type of the query-related mdspan arg…
Browse files Browse the repository at this point in the history
…uments (#1539)

closes #1357

breaking change: the type of argument mdspans has slightly changed (second template parameter fixed to `uint32_t`)

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1539
  • Loading branch information
achirkin authored May 20, 2023
1 parent cdf107b commit a196645
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 99 deletions.
7 changes: 4 additions & 3 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,10 @@ struct ivf_pq_knn {
{
search_params.n_probes = 20;
auto queries_view =
raft::make_device_matrix_view<const ValT, IdxT>(search_items, ps.n_queries, ps.n_dims);
auto idxs_view = raft::make_device_matrix_view<IdxT, IdxT>(out_idxs, ps.n_queries, ps.k);
auto dists_view = raft::make_device_matrix_view<dist_t, IdxT>(out_dists, ps.n_queries, ps.k);
raft::make_device_matrix_view<const ValT, uint32_t>(search_items, ps.n_queries, ps.n_dims);
auto idxs_view = raft::make_device_matrix_view<IdxT, uint32_t>(out_idxs, ps.n_queries, ps.k);
auto dists_view =
raft::make_device_matrix_view<dist_t, uint32_t>(out_dists, ps.n_queries, ps.k);
raft::neighbors::ivf_pq::search(
handle, search_params, *index, queries_view, idxs_view, dists_view);
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ void build_knn_graph(raft::resources const& res,
device_memory);

for (const auto& batch : vec_batches) {
auto queries_view = raft::make_device_matrix_view<const DataT, int64_t>(
auto queries_view = raft::make_device_matrix_view<const DataT, uint32_t>(
batch.data(), batch.size(), batch.row_width());
auto neighbors_view = make_device_matrix_view<int64_t, int64_t>(
auto neighbors_view = make_device_matrix_view<int64_t, uint32_t>(
neighbors.data_handle(), batch.size(), neighbors.extent(1));
auto distances_view = make_device_matrix_view<float, int64_t>(
auto distances_view = make_device_matrix_view<float, uint32_t>(
distances.data_handle(), batch.size(), distances.extent(1));

ivf_pq::search(res, *search_params, index, queries_view, neighbors_view, distances_view);
Expand Down
48 changes: 24 additions & 24 deletions cpp/include/raft/neighbors/ivf_pq-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ template <typename T, typename IdxT, typename SampleFilterT>
void search_with_filtering(raft::resources const& handle,
const search_params& params,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances,
raft::device_matrix_view<const T, uint32_t, row_major> queries,
raft::device_matrix_view<IdxT, uint32_t, row_major> neighbors,
raft::device_matrix_view<float, uint32_t, row_major> distances,
SampleFilterT sample_filter) RAFT_EXPLICIT;

template <typename T, typename IdxT>
void search(raft::resources const& handle,
const search_params& params,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances) RAFT_EXPLICIT;
raft::device_matrix_view<const T, uint32_t, row_major> queries,
raft::device_matrix_view<IdxT, uint32_t, row_major> neighbors,
raft::device_matrix_view<float, uint32_t, row_major> distances) RAFT_EXPLICIT;

template <typename T, typename IdxT = uint32_t>
auto build(raft::resources const& handle,
Expand Down Expand Up @@ -164,24 +164,24 @@ instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t);

#undef instantiate_raft_neighbors_ivf_pq_extend

#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \
extern template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
raft::device_matrix_view<const T, IdxT, row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors, \
raft::device_matrix_view<float, IdxT, row_major> distances); \
\
extern template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
const T* queries, \
uint32_t n_queries, \
uint32_t k, \
IdxT* neighbors, \
float* distances, \
#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \
extern template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
raft::device_matrix_view<const T, uint32_t, row_major> queries, \
raft::device_matrix_view<IdxT, uint32_t, row_major> neighbors, \
raft::device_matrix_view<float, uint32_t, row_major> distances); \
\
extern template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
const T* queries, \
uint32_t n_queries, \
uint32_t k, \
IdxT* neighbors, \
float* distances, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_neighbors_ivf_pq_search(float, int64_t);
Expand Down
14 changes: 7 additions & 7 deletions cpp/include/raft/neighbors/ivf_pq-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ template <typename T, typename IdxT, typename SampleFilterT>
void search_with_filtering(raft::resources const& handle,
const search_params& params,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances,
raft::device_matrix_view<const T, uint32_t, row_major> queries,
raft::device_matrix_view<IdxT, uint32_t, row_major> neighbors,
raft::device_matrix_view<float, uint32_t, row_major> distances,
SampleFilterT sample_filter = SampleFilterT())
{
RAFT_EXPECTS(
Expand All @@ -182,7 +182,7 @@ void search_with_filtering(raft::resources const& handle,
params,
idx,
queries.data_handle(),
static_cast<std::uint32_t>(queries.extent(0)),
queries.extent(0),
k,
neighbors.data_handle(),
distances.data_handle(),
Expand Down Expand Up @@ -219,9 +219,9 @@ template <typename T, typename IdxT>
void search(raft::resources const& handle,
const search_params& params,
const index<IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
raft::device_matrix_view<const T, uint32_t, row_major> queries,
raft::device_matrix_view<IdxT, uint32_t, row_major> neighbors,
raft::device_matrix_view<float, uint32_t, row_major> distances)
{
search_with_filtering(
handle, params, idx, queries, neighbors, distances, detail::NoneSampleFilter());
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ void approx_knn_search(raft::resources const& handle,
params.n_probes = index->nprobe;

auto query_view =
raft::make_device_matrix_view<const T, int64_t>(query_array, n, index->ivf_pq->dim());
auto indices_view = raft::make_device_matrix_view<int64_t, int64_t>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, n, k);
raft::make_device_matrix_view<const T, uint32_t>(query_array, n, index->ivf_pq->dim());
auto indices_view = raft::make_device_matrix_view<int64_t, uint32_t>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, uint32_t>(distances, n, k);
neighbors::ivf_pq::search(
handle, params, *index->ivf_pq, query_view, indices_view, distances_view);
} else {
Expand Down
36 changes: 18 additions & 18 deletions cpp/src/neighbors/ivfpq_search_float_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@
#include <raft/neighbors/ivf_pq-inl.cuh>
#include <raft/neighbors/ivf_pq_types.hpp> // raft::neighbors::ivf_pq::index

#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
raft::device_matrix_view<const T, IdxT, row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors, \
raft::device_matrix_view<float, IdxT, row_major> distances); \
\
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
const T* queries, \
uint32_t n_queries, \
uint32_t k, \
IdxT* neighbors, \
float* distances, \
#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
raft::device_matrix_view<const T, uint32_t, row_major> queries, \
raft::device_matrix_view<IdxT, uint32_t, row_major> neighbors, \
raft::device_matrix_view<float, uint32_t, row_major> distances); \
\
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
const T* queries, \
uint32_t n_queries, \
uint32_t k, \
IdxT* neighbors, \
float* distances, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_neighbors_ivf_pq_search(float, int64_t);
Expand Down
36 changes: 18 additions & 18 deletions cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@
#include <raft/neighbors/ivf_pq-inl.cuh>
#include <raft/neighbors/ivf_pq_types.hpp> // raft::neighbors::ivf_pq::index

#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
raft::device_matrix_view<const T, IdxT, row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors, \
raft::device_matrix_view<float, IdxT, row_major> distances); \
\
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
const T* queries, \
uint32_t n_queries, \
uint32_t k, \
IdxT* neighbors, \
float* distances, \
#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
raft::device_matrix_view<const T, uint32_t, row_major> queries, \
raft::device_matrix_view<IdxT, uint32_t, row_major> neighbors, \
raft::device_matrix_view<float, uint32_t, row_major> distances); \
\
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
const T* queries, \
uint32_t n_queries, \
uint32_t k, \
IdxT* neighbors, \
float* distances, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t);
Expand Down
36 changes: 18 additions & 18 deletions cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@
#include <raft/neighbors/ivf_pq-inl.cuh>
#include <raft/neighbors/ivf_pq_types.hpp> // raft::neighbors::ivf_pq::index

#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
raft::device_matrix_view<const T, IdxT, row_major> queries, \
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors, \
raft::device_matrix_view<float, IdxT, row_major> distances); \
\
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
const T* queries, \
uint32_t n_queries, \
uint32_t k, \
IdxT* neighbors, \
float* distances, \
#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
raft::device_matrix_view<const T, uint32_t, row_major> queries, \
raft::device_matrix_view<IdxT, uint32_t, row_major> neighbors, \
raft::device_matrix_view<float, uint32_t, row_major> distances); \
\
template void raft::neighbors::ivf_pq::search<T, IdxT>( \
raft::resources const& handle, \
const raft::neighbors::ivf_pq::search_params& params, \
const raft::neighbors::ivf_pq::index<IdxT>& idx, \
const T* queries, \
uint32_t n_queries, \
uint32_t k, \
IdxT* neighbors, \
float* distances, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t);
Expand Down
10 changes: 5 additions & 5 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,11 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
rmm::device_uvector<IdxT> indices_ivf_pq_dev(queries_size, stream_);

auto query_view =
raft::make_device_matrix_view<DataT, IdxT>(search_queries.data(), ps.num_queries, ps.dim);
auto inds_view =
raft::make_device_matrix_view<IdxT, IdxT>(indices_ivf_pq_dev.data(), ps.num_queries, ps.k);
auto dists_view =
raft::make_device_matrix_view<EvalT, IdxT>(distances_ivf_pq_dev.data(), ps.num_queries, ps.k);
raft::make_device_matrix_view<DataT, uint32_t>(search_queries.data(), ps.num_queries, ps.dim);
auto inds_view = raft::make_device_matrix_view<IdxT, uint32_t>(
indices_ivf_pq_dev.data(), ps.num_queries, ps.k);
auto dists_view = raft::make_device_matrix_view<EvalT, uint32_t>(
distances_ivf_pq_dev.data(), ps.num_queries, ps.k);

ivf_pq::search<DataT, IdxT>(
handle_, ps.search_params, index, query_view, inds_view, dists_view);
Expand Down

0 comments on commit a196645

Please sign in to comment.