diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index bb6fb30ad3..a702357511 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -35,6 +35,8 @@ #include +#include +#include #include #include @@ -106,6 +108,18 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream return os; } +template +auto min_output_size(const handle_t& handle, const ivf_pq::index& index, uint32_t n_probes) + -> IdxT +{ + uint32_t skip = index.n_nonempty_lists() > n_probes ? index.n_nonempty_lists() - n_probes : 0; + auto map_type = [] __device__(uint32_t x) { return IdxT(x); }; + using iter = cub::TransformInputIterator; + iter start(index.list_sizes().data_handle() + skip, map_type); + iter end(index.list_sizes().data_handle() + index.n_nonempty_lists(), map_type); + return thrust::reduce(handle.get_thrust_policy(), start, end); +} + template class ivf_pq_test : public ::testing::TestWithParam { public: @@ -189,7 +203,7 @@ class ivf_pq_test : public ::testing::TestWithParam { } template - auto run(BuildIndex build_index) + void run(BuildIndex build_index) { auto index = build_index(); @@ -228,6 +242,29 @@ class ivf_pq_test : public ::testing::TestWithParam { ps.k, 0.001 / low_precision_factor, min_recall)); + + // Test a few extra invariants + IdxT min_results = min_output_size(handle_, index, ps.search_params.n_probes); + IdxT max_oob = ps.k <= min_results ? 0 : ps.k - min_results; + IdxT found_oob = 0; + for (uint32_t query_ix = 0; query_ix < ps.num_queries; query_ix++) { + for (uint32_t k = 0; k < ps.k; k++) { + auto flat_i = query_ix * ps.k + k; + auto found_ix = indices_ivf_pq[flat_i]; + if (found_ix == ivf_pq::index::kOutOfBoundsRecord) { + found_oob++; + continue; + } + ASSERT_NE(found_ix, ivf_pq::index::kInvalidRecord) + << "got invalid record at query_ix = " << query_ix << ", k = " << k + << " (distance = " << distances_ivf_pq[flat_i] << ")"; + ASSERT_LT(found_ix, ps.num_db_vecs) + << "got an impossible index = " << found_ix << " at query_ix = " << query_ix + << ", k = " << k << " (distance = " << distances_ivf_pq[flat_i] << ")"; + } + } + ASSERT_LE(found_oob, max_oob) + << "got too many records out-of-bounds (see ivf_pq::index::kOutOfBoundsRecord)."; } void SetUp() override // NOLINT