Skip to content

Commit

Permalink
Fix incorrect lookup of the DB record when the query result index is …
Browse files Browse the repository at this point in the history
…exactly the size of a cluster
  • Loading branch information
achirkin committed Nov 17, 2022
1 parent d64799f commit b28364d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
19 changes: 19 additions & 0 deletions cpp/include/raft/neighbors/ivf_pq_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <raft/distance/distance_types.hpp>
#include <raft/util/integer_utils.hpp>

#include <thrust/fill.h>

#include <type_traits>

namespace raft::neighbors::ivf_pq {
Expand Down Expand Up @@ -183,6 +185,19 @@ struct index : ann::index {
"IdxT must be able to represent all values of uint32_t");

public:
/**
* Default value filled in the `indices()` array.
* One may encounter it trying to access a record within a cluster that is outside of the
* `list_sizes()` bound (due to the record alignment `kIndexGroupSize`).
*/
constexpr static IdxT kInvalidRecord = std::numeric_limits<IdxT>::max() - 1;
/**
* Default value returned by `search` when the `n_probes` is too small and top-k is too large.
* One may encounter it if the combined size of probed clusters is smaller than the requested
* number of results per query.
*/
constexpr static IdxT kOutOfBoundsRecord = std::numeric_limits<IdxT>::max();

/** Total length of the index. */
[[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return indices_.extent(0); }
/** Dimensionality of the input data. */
Expand Down Expand Up @@ -298,6 +313,10 @@ struct index : ann::index {
{
pq_dataset_ = make_device_mdarray<uint8_t>(handle, make_pq_dataset_extents(index_size));
indices_ = make_device_mdarray<IdxT>(handle, make_extents<IdxT>(index_size));
if (index_size > 0) {
thrust::fill_n(
handle.get_thrust_policy(), indices_.data_handle(), index_size, kInvalidRecord);
}
check_consistency();
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ __device__ auto find_db_row(IdxT& x, // NOLINT
uint32_t ix_max = n_probes;
do {
uint32_t i = (ix_min + ix_max) / 2;
if (IdxT(chunk_indices[i]) < x) {
if (IdxT(chunk_indices[i]) <= x) {
ix_min = i + 1;
} else {
ix_max = i;
Expand Down Expand Up @@ -365,7 +365,7 @@ __launch_bounds__(BlockDim) __global__
clusters_to_probe + n_probes * query_ix,
chunk_indices + n_probes * query_ix);
}
neighbors[k] = valid ? db_indices[data_ix] : std::numeric_limits<IdxT>::max();
neighbors[k] = valid ? db_indices[data_ix] : index<IdxT>::kOutOfBoundsRecord;
}

/**
Expand Down
24 changes: 24 additions & 0 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,30 @@ inline auto special_cases() -> test_cases_t
x.search_params.n_probes = 50;
});

ADD_CASE({
x.num_db_vecs = 10000;
x.dim = 16;
x.num_queries = 500;
x.k = 128;
x.index_params.metric = distance::DistanceType::L2Expanded;
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE;
x.index_params.pq_bits = 8;
x.index_params.n_lists = 100;
x.search_params.n_probes = 100;
});

ADD_CASE({
x.num_db_vecs = 10000;
x.dim = 16;
x.num_queries = 500;
x.k = 129;
x.index_params.metric = distance::DistanceType::L2Expanded;
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE;
x.index_params.pq_bits = 8;
x.index_params.n_lists = 100;
x.search_params.n_probes = 100;
});

return xs;
}

Expand Down

0 comments on commit b28364d

Please sign in to comment.