Skip to content

Commit

Permalink
Add extra checks for index invariants
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Nov 17, 2022
1 parent b28364d commit e80a345
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

#include <gtest/gtest.h>

#include <cub/cub.cuh>
#include <thrust/reduce.h>
#include <thrust/sequence.h>

#include <algorithm>
Expand Down Expand Up @@ -106,6 +108,18 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream
return os;
}

template <typename IdxT>
auto min_output_size(const handle_t& handle, const ivf_pq::index<IdxT>& index, uint32_t n_probes)
-> IdxT
{
uint32_t skip = index.n_nonempty_lists() > n_probes ? index.n_nonempty_lists() - n_probes : 0;
auto map_type = [] __device__(uint32_t x) { return IdxT(x); };
using iter = cub::TransformInputIterator<IdxT, decltype(map_type), const uint32_t*>;
iter start(index.list_sizes().data_handle() + skip, map_type);
iter end(index.list_sizes().data_handle() + index.n_nonempty_lists(), map_type);
return thrust::reduce(handle.get_thrust_policy(), start, end);
}

template <typename EvalT, typename DataT, typename IdxT>
class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
public:
Expand Down Expand Up @@ -189,7 +203,7 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
}

template <typename BuildIndex>
auto run(BuildIndex build_index)
void run(BuildIndex build_index)
{
auto index = build_index();

Expand Down Expand Up @@ -228,6 +242,29 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
ps.k,
0.001 / low_precision_factor,
min_recall));

// Test a few extra invariants
IdxT min_results = min_output_size(handle_, index, ps.search_params.n_probes);
IdxT max_oob = ps.k <= min_results ? 0 : ps.k - min_results;
IdxT found_oob = 0;
for (uint32_t query_ix = 0; query_ix < ps.num_queries; query_ix++) {
for (uint32_t k = 0; k < ps.k; k++) {
auto flat_i = query_ix * ps.k + k;
auto found_ix = indices_ivf_pq[flat_i];
if (found_ix == ivf_pq::index<IdxT>::kOutOfBoundsRecord) {
found_oob++;
continue;
}
ASSERT_NE(found_ix, ivf_pq::index<IdxT>::kInvalidRecord)
<< "got invalid record at query_ix = " << query_ix << ", k = " << k
<< " (distance = " << distances_ivf_pq[flat_i] << ")";
ASSERT_LT(found_ix, ps.num_db_vecs)
<< "got an impossible index = " << found_ix << " at query_ix = " << query_ix
<< ", k = " << k << " (distance = " << distances_ivf_pq[flat_i] << ")";
}
}
ASSERT_LE(found_oob, max_oob)
<< "got too many records out-of-bounds (see ivf_pq::index<IdxT>::kOutOfBoundsRecord).";
}

void SetUp() override // NOLINT
Expand Down

0 comments on commit e80a345

Please sign in to comment.