Skip to content

Commit

Permalink
Tweak the sensitivity for more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Jan 13, 2023
1 parent fe3bc3b commit 0e4ee9f
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,16 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_);
handle_.sync_stream(stream_);

// Using very dense, small codebooks results in large errors in the distance calculation
double low_precision_factor =
static_cast<double>(ps.dim * 8) / static_cast<double>(index.pq_dim() * index.pq_bits());
// A very conservative lower bound on recall
double min_recall =
static_cast<double>(ps.search_params.n_probes) / static_cast<double>(ps.index_params.n_lists);
double low_precision_factor =
static_cast<double>(ps.dim * 8) / static_cast<double>(index.pq_dim() * index.pq_bits());
// Using a heuristic to lower the required recall due to code-packing errors
min_recall =
std::min(std::erfc(0.1 * low_precision_factor / std::max(min_recall, 0.5)), min_recall);
// Use explicit per-test min recall value if provided.
// TODO: investigate what prevents 100%
min_recall = std::min(1 - 0.01 * low_precision_factor, ps.min_recall.value_or(min_recall));
min_recall = ps.min_recall.value_or(min_recall);

ASSERT_TRUE(eval_neighbours(indices_ref,
indices_ivf_pq,
Expand Down Expand Up @@ -394,7 +395,7 @@ inline auto big_dims_small_lut() -> test_cases_t
y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u);
y.index_params.pq_bits = 6;
y.search_params.lut_dtype = CUDA_R_8U;
y.min_recall = 0.21;
y.min_recall = 0.2;
return y;
});
}
Expand Down Expand Up @@ -422,7 +423,7 @@ inline auto enum_variety() -> test_cases_t
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
x.index_params.pq_bits = 4;
x.min_recall = 0.74;
x.min_recall = 0.73;
});
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
Expand Down Expand Up @@ -491,8 +492,15 @@ inline auto enum_variety_ip() -> test_cases_t
return map<ivf_pq_inputs>(enum_variety(), [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
if (y.min_recall.has_value()) {
// Apparently InnerProduct tends to give higher scores for these parameter values.
y.min_recall = y.min_recall.value() * 1.15;
if (y.search_params.lut_dtype == CUDA_R_8U) {
// InnerProduct score is signed,
// thus we're forced to use a signed 8-bit representation,
// thus we have one bit less precision
y.min_recall = y.min_recall.value() * 0.95;
} else {
// Apparently InnerProduct tends to give higher scores for these parameter values.
y.min_recall = y.min_recall.value() * 1.12;
}
}
y.index_params.metric = distance::DistanceType::InnerProduct;
return y;
Expand Down

0 comments on commit 0e4ee9f

Please sign in to comment.