diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp index d09bcf4f94..afb3eb6cd6 100644 --- a/cpp/include/raft/neighbors/ivf_pq_types.hpp +++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp @@ -23,6 +23,8 @@ #include #include +#include + #include namespace raft::neighbors::ivf_pq { @@ -183,6 +185,19 @@ struct index : ann::index { "IdxT must be able to represent all values of uint32_t"); public: + /** + * Default value filled in the `indices()` array. + * One may encounter it trying to access a record within a cluster that is outside of the + * `list_sizes()` bound (due to the record alignment `kIndexGroupSize`). + */ + constexpr static IdxT kInvalidRecord = std::numeric_limits::max() - 1; + /** + * Default value returned by `search` when the `n_probes` is too small and top-k is too large. + * One may encounter it if the combined size of probed clusters is smaller than the requested + * number of results per query. + */ + constexpr static IdxT kOutOfBoundsRecord = std::numeric_limits::max(); + /** Total length of the index. */ [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return indices_.extent(0); } /** Dimensionality of the input data. */ @@ -298,6 +313,10 @@ struct index : ann::index { { pq_dataset_ = make_device_mdarray(handle, make_pq_dataset_extents(index_size)); indices_ = make_device_mdarray(handle, make_extents(index_size)); + if (index_size > 0) { + thrust::fill_n( + handle.get_thrust_policy(), indices_.data_handle(), index_size, kInvalidRecord); + } check_consistency(); } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index f4db1d67b0..0ff659ae5d 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -327,7 +327,7 @@ __device__ auto find_db_row(IdxT& x, // NOLINT uint32_t ix_max = n_probes; do { uint32_t i = (ix_min + ix_max) / 2; - if (IdxT(chunk_indices[i]) < x) { + if (IdxT(chunk_indices[i]) <= x) { ix_min = i + 1; } else { ix_max = i; @@ -365,7 +365,7 @@ __launch_bounds__(BlockDim) __global__ clusters_to_probe + n_probes * query_ix, chunk_indices + n_probes * query_ix); } - neighbors[k] = valid ? db_indices[data_ix] : std::numeric_limits::max(); + neighbors[k] = valid ? db_indices[data_ix] : index::kOutOfBoundsRecord; } /** diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 36fbef4bd5..bb6fb30ad3 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -459,6 +459,30 @@ inline auto special_cases() -> test_cases_t x.search_params.n_probes = 50; }); + ADD_CASE({ + x.num_db_vecs = 10000; + x.dim = 16; + x.num_queries = 500; + x.k = 128; + x.index_params.metric = distance::DistanceType::L2Expanded; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 100; + x.search_params.n_probes = 100; + }); + + ADD_CASE({ + x.num_db_vecs = 10000; + x.dim = 16; + x.num_queries = 500; + x.k = 129; + x.index_params.metric = distance::DistanceType::L2Expanded; + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.index_params.pq_bits = 8; + x.index_params.n_lists = 100; + x.search_params.n_probes = 100; + }); + return xs; }