Skip to content

Commit

Permalink
IVF-PQ: tighten the test criteria (#1135)
Browse files Browse the repository at this point in the history
Make the recall reporting a bit more verbose and try to tighten the `min_recall` for various test cases. This should help spot any regressions in the future and improve our understanding of ivf-pq performance for various inputs.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1135
  • Loading branch information
achirkin authored Jan 14, 2023
1 parent dde7c53 commit 2af2749
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 29 deletions.
122 changes: 95 additions & 27 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,18 @@
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

namespace raft::neighbors::ivf_pq {

struct ivf_pq_inputs {
uint32_t num_db_vecs = 4096;
uint32_t num_queries = 1024;
uint32_t dim = 64;
uint32_t k = 32;
uint32_t num_db_vecs = 4096;
uint32_t num_queries = 1024;
uint32_t dim = 64;
uint32_t k = 32;
std::optional<double> min_recall = std::nullopt;

ivf_pq::index_params index_params;
ivf_pq::search_params search_params;

Expand Down Expand Up @@ -91,6 +94,7 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream
PRINT_DIFF(.num_queries);
PRINT_DIFF(.dim);
PRINT_DIFF(.k);
PRINT_DIFF_V(.min_recall, p.min_recall.value_or(0));
PRINT_DIFF_V(.index_params.metric, print_metric{p.index_params.metric});
PRINT_DIFF(.index_params.metric_arg);
PRINT_DIFF(.index_params.add_data_on_build);
Expand All @@ -100,6 +104,7 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream
PRINT_DIFF(.index_params.pq_bits);
PRINT_DIFF(.index_params.pq_dim);
PRINT_DIFF(.index_params.codebook_kind);
PRINT_DIFF(.index_params.force_random_rotation);
PRINT_DIFF(.search_params.n_probes);
PRINT_DIFF_V(.search_params.lut_dtype, print_dtype{p.search_params.lut_dtype});
PRINT_DIFF_V(.search_params.internal_distance_dtype,
Expand Down Expand Up @@ -231,21 +236,26 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_);
handle_.sync_stream(stream_);

// Using very dense, small codebooks results in large errors in the distance calculation
double low_precision_factor =
static_cast<double>(index.pq_dim() * index.pq_bits()) / static_cast<double>(ps.dim * 8);
// A very conservative lower bound on recall
double min_recall = low_precision_factor * static_cast<double>(ps.search_params.n_probes) /
static_cast<double>(ps.index_params.n_lists);
double min_recall =
static_cast<double>(ps.search_params.n_probes) / static_cast<double>(ps.index_params.n_lists);
double low_precision_factor =
static_cast<double>(ps.dim * 8) / static_cast<double>(index.pq_dim() * index.pq_bits());
// Using a heuristic to lower the required recall due to code-packing errors
min_recall =
std::min(std::erfc(0.05 * low_precision_factor / std::max(min_recall, 0.5)), min_recall);
// Use explicit per-test min recall value if provided.
min_recall = ps.min_recall.value_or(min_recall);

ASSERT_TRUE(eval_neighbours(indices_ref,
indices_ivf_pq,
distances_ref,
distances_ivf_pq,
ps.num_queries,
ps.k,
0.001 / low_precision_factor,
min_recall));
0.0001 * low_precision_factor,
min_recall))
<< ps;

// Test a few extra invariants
IdxT min_results = min_output_size(handle_, index, ps.search_params.n_probes);
Expand Down Expand Up @@ -350,18 +360,28 @@ inline auto small_dims_per_cluster() -> test_cases_t

inline auto big_dims() -> test_cases_t
{
return with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144});
// return with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144, 8192, 12288,
// 16384});
// with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144, 8192, 12288, 16384});
auto xs = with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144});
return map<ivf_pq_inputs>(xs, [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
uint32_t pq_len = 2;
y.index_params.pq_dim = div_rounding_up_safe(x.dim, pq_len);
// This comes from pure experimentation; also, the recall depends a lot on pq_len.
y.min_recall = 0.48 + 0.028 * std::log2(x.dim);
return y;
});
}

/**
 * These will surely trigger no-smem-lut kernel.
 *
 * Takes every case produced by big_dims() and overrides the PQ settings:
 * 6-bit codes, a half-precision (CUDA_R_16F) lookup table, and a pq_dim derived
 * from a sub-vector length of 2. Every resulting case uses a fixed minimum
 * recall threshold of 0.69.
 */
inline auto big_dims_moderate_lut() -> test_cases_t
{
  const auto adjust = [](const ivf_pq_inputs& base) {
    // Number of input dimensions packed into one PQ sub-vector.
    constexpr uint32_t kPqLen = 2;

    ivf_pq_inputs tuned(base);
    tuned.min_recall              = 0.69;
    tuned.search_params.lut_dtype = CUDA_R_16F;
    tuned.index_params.pq_bits    = 6;
    // ceil(dim / kPqLen), then presumably rounded up to a multiple of 4
    // (NOTE(review): confirm round_up_safe semantics).
    tuned.index_params.pq_dim = round_up_safe(div_rounding_up_safe(base.dim, kPqLen), 4u);
    return tuned;
  };
  return map<ivf_pq_inputs>(big_dims(), adjust);
}
Expand All @@ -371,9 +391,11 @@ inline auto big_dims_small_lut() -> test_cases_t
{
return map<ivf_pq_inputs>(big_dims(), [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
y.index_params.pq_dim = raft::round_up_safe(y.dim / 8u, 64u);
uint32_t pq_len = 8;
y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u);
y.index_params.pq_bits = 6;
y.search_params.lut_dtype = CUDA_R_8U;
y.min_recall = 0.21;
return y;
});
}
Expand All @@ -390,30 +412,68 @@ inline auto enum_variety() -> test_cases_t
([](ivf_pq_inputs & x) f)(xs[xs.size() - 1]); \
} while (0);

ADD_CASE({ x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; });
ADD_CASE({ x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; });
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
x.min_recall = 0.86;
});
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE;
x.min_recall = 0.86;
});
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
x.index_params.pq_bits = 4;
x.min_recall = 0.79;
});
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
x.index_params.pq_bits = 5;
x.min_recall = 0.83;
});

ADD_CASE({ x.index_params.pq_bits = 6; });
ADD_CASE({ x.index_params.pq_bits = 7; });
ADD_CASE({ x.index_params.pq_bits = 8; });
ADD_CASE({
x.index_params.pq_bits = 6;
x.min_recall = 0.84;
});
ADD_CASE({
x.index_params.pq_bits = 7;
x.min_recall = 0.85;
});
ADD_CASE({
x.index_params.pq_bits = 8;
x.min_recall = 0.86;
});

ADD_CASE({ x.index_params.force_random_rotation = true; });
ADD_CASE({ x.index_params.force_random_rotation = false; });
ADD_CASE({
x.index_params.force_random_rotation = true;
x.min_recall = 0.86;
});
ADD_CASE({
x.index_params.force_random_rotation = false;
x.min_recall = 0.86;
});

ADD_CASE({ x.search_params.lut_dtype = CUDA_R_32F; });
ADD_CASE({ x.search_params.lut_dtype = CUDA_R_16F; });
ADD_CASE({ x.search_params.lut_dtype = CUDA_R_8U; });
ADD_CASE({
x.search_params.lut_dtype = CUDA_R_32F;
x.min_recall = 0.86;
});
ADD_CASE({
x.search_params.lut_dtype = CUDA_R_16F;
x.min_recall = 0.86;
});
ADD_CASE({
x.search_params.lut_dtype = CUDA_R_8U;
x.min_recall = 0.85;
});

ADD_CASE({ x.search_params.internal_distance_dtype = CUDA_R_32F; });
ADD_CASE({ x.search_params.internal_distance_dtype = CUDA_R_16F; });
ADD_CASE({
x.search_params.internal_distance_dtype = CUDA_R_32F;
x.min_recall = 0.86;
});
ADD_CASE({
x.search_params.internal_distance_dtype = CUDA_R_16F;
x.min_recall = 0.86;
});

return xs;
}
Expand All @@ -431,6 +491,14 @@ inline auto enum_variety_ip() -> test_cases_t
{
return map<ivf_pq_inputs>(enum_variety(), [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
if (y.min_recall.has_value()) {
if (y.search_params.lut_dtype == CUDA_R_8U) {
// InnerProduct score is signed,
// thus we're forced to use a signed 8-bit representation,
// thus we have one bit less precision
y.min_recall = y.min_recall.value() * 0.95;
}
}
y.index_params.metric = distance::DistanceType::InnerProduct;
return y;
});
Expand Down
11 changes: 9 additions & 2 deletions cpp/test/neighbors/ann_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -232,7 +232,14 @@ auto eval_neighbours(const std::vector<T>& expected_idx,
}
}
double actual_recall = static_cast<double>(match_count) / static_cast<double>(total_count);
RAFT_LOG_INFO("Recall = %f (%zu/%zu)", actual_recall, match_count, total_count);
double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps);
RAFT_LOG_INFO("Recall = %f (%zu/%zu), the error is %2.1f%% %s the threshold (eps = %f).",
actual_recall,
match_count,
total_count,
std::abs(error_margin * 100.0),
error_margin < 0 ? "above" : "below",
eps);
if (actual_recall < min_recall - eps) {
if (actual_recall < min_recall * min_recall - eps) {
RAFT_LOG_ERROR("Recall is much lower than the minimum (%f < %f)", actual_recall, min_recall);
Expand Down

0 comments on commit 2af2749

Please sign in to comment.