Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IVF-PQ: tighten the test criteria #1135

Merged
merged 9 commits into from
Jan 14, 2023
122 changes: 95 additions & 27 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,18 @@
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

namespace raft::neighbors::ivf_pq {

struct ivf_pq_inputs {
uint32_t num_db_vecs = 4096;
uint32_t num_queries = 1024;
uint32_t dim = 64;
uint32_t k = 32;
uint32_t num_db_vecs = 4096;
uint32_t num_queries = 1024;
uint32_t dim = 64;
uint32_t k = 32;
std::optional<double> min_recall = std::nullopt;

ivf_pq::index_params index_params;
ivf_pq::search_params search_params;

Expand Down Expand Up @@ -91,6 +94,7 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream
PRINT_DIFF(.num_queries);
PRINT_DIFF(.dim);
PRINT_DIFF(.k);
PRINT_DIFF_V(.min_recall, p.min_recall.value_or(0));
PRINT_DIFF_V(.index_params.metric, print_metric{p.index_params.metric});
PRINT_DIFF(.index_params.metric_arg);
PRINT_DIFF(.index_params.add_data_on_build);
Expand All @@ -100,6 +104,7 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream
PRINT_DIFF(.index_params.pq_bits);
PRINT_DIFF(.index_params.pq_dim);
PRINT_DIFF(.index_params.codebook_kind);
PRINT_DIFF(.index_params.force_random_rotation);
PRINT_DIFF(.search_params.n_probes);
PRINT_DIFF_V(.search_params.lut_dtype, print_dtype{p.search_params.lut_dtype});
PRINT_DIFF_V(.search_params.internal_distance_dtype,
Expand Down Expand Up @@ -231,21 +236,26 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_);
handle_.sync_stream(stream_);

// Using very dense, small codebooks results in large errors in the distance calculation
double low_precision_factor =
static_cast<double>(index.pq_dim() * index.pq_bits()) / static_cast<double>(ps.dim * 8);
// A very conservative lower bound on recall
double min_recall = low_precision_factor * static_cast<double>(ps.search_params.n_probes) /
static_cast<double>(ps.index_params.n_lists);
double min_recall =
static_cast<double>(ps.search_params.n_probes) / static_cast<double>(ps.index_params.n_lists);
double low_precision_factor =
static_cast<double>(ps.dim * 8) / static_cast<double>(index.pq_dim() * index.pq_bits());
// Using a heuristic to lower the required recall due to code-packing errors
min_recall =
std::min(std::erfc(0.05 * low_precision_factor / std::max(min_recall, 0.5)), min_recall);
// Use explicit per-test min recall value if provided.
min_recall = ps.min_recall.value_or(min_recall);

ASSERT_TRUE(eval_neighbours(indices_ref,
indices_ivf_pq,
distances_ref,
distances_ivf_pq,
ps.num_queries,
ps.k,
0.001 / low_precision_factor,
min_recall));
0.0001 * low_precision_factor,
min_recall))
<< ps;

// Test a few extra invariants
IdxT min_results = min_output_size(handle_, index, ps.search_params.n_probes);
Expand Down Expand Up @@ -350,18 +360,28 @@ inline auto small_dims_per_cluster() -> test_cases_t

inline auto big_dims() -> test_cases_t
{
return with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144});
// return with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144, 8192, 12288,
// 16384});
// with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144, 8192, 12288, 16384});
auto xs = with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144});
return map<ivf_pq_inputs>(xs, [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
uint32_t pq_len = 2;
y.index_params.pq_dim = div_rounding_up_safe(x.dim, pq_len);
// This comes from pure experimentation, also the recall depens a lot on pq_len.
y.min_recall = 0.48 + 0.028 * std::log2(x.dim);
return y;
});
}

/** These will surely trigger no-smem-lut kernel. */
inline auto big_dims_moderate_lut() -> test_cases_t
{
return map<ivf_pq_inputs>(big_dims(), [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
uint32_t pq_len = 2;
y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u);
y.index_params.pq_bits = 6;
y.search_params.lut_dtype = CUDA_R_16F;
y.min_recall = 0.69;
return y;
});
}
Expand All @@ -371,9 +391,11 @@ inline auto big_dims_small_lut() -> test_cases_t
{
return map<ivf_pq_inputs>(big_dims(), [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
y.index_params.pq_dim = raft::round_up_safe(y.dim / 8u, 64u);
uint32_t pq_len = 8;
y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u);
y.index_params.pq_bits = 6;
y.search_params.lut_dtype = CUDA_R_8U;
y.min_recall = 0.21;
return y;
});
}
Expand All @@ -390,30 +412,68 @@ inline auto enum_variety() -> test_cases_t
([](ivf_pq_inputs & x) f)(xs[xs.size() - 1]); \
} while (0);

ADD_CASE({ x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; });
ADD_CASE({ x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; });
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
x.min_recall = 0.86;
});
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE;
x.min_recall = 0.86;
});
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
x.index_params.pq_bits = 4;
x.min_recall = 0.79;
});
ADD_CASE({
x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER;
x.index_params.pq_bits = 5;
x.min_recall = 0.83;
});

ADD_CASE({ x.index_params.pq_bits = 6; });
ADD_CASE({ x.index_params.pq_bits = 7; });
ADD_CASE({ x.index_params.pq_bits = 8; });
ADD_CASE({
x.index_params.pq_bits = 6;
x.min_recall = 0.84;
});
ADD_CASE({
x.index_params.pq_bits = 7;
x.min_recall = 0.85;
});
ADD_CASE({
x.index_params.pq_bits = 8;
x.min_recall = 0.86;
});

ADD_CASE({ x.index_params.force_random_rotation = true; });
ADD_CASE({ x.index_params.force_random_rotation = false; });
ADD_CASE({
x.index_params.force_random_rotation = true;
x.min_recall = 0.86;
});
ADD_CASE({
x.index_params.force_random_rotation = false;
x.min_recall = 0.86;
});

ADD_CASE({ x.search_params.lut_dtype = CUDA_R_32F; });
ADD_CASE({ x.search_params.lut_dtype = CUDA_R_16F; });
ADD_CASE({ x.search_params.lut_dtype = CUDA_R_8U; });
ADD_CASE({
x.search_params.lut_dtype = CUDA_R_32F;
x.min_recall = 0.86;
});
ADD_CASE({
x.search_params.lut_dtype = CUDA_R_16F;
x.min_recall = 0.86;
});
ADD_CASE({
x.search_params.lut_dtype = CUDA_R_8U;
x.min_recall = 0.85;
});

ADD_CASE({ x.search_params.internal_distance_dtype = CUDA_R_32F; });
ADD_CASE({ x.search_params.internal_distance_dtype = CUDA_R_16F; });
ADD_CASE({
x.search_params.internal_distance_dtype = CUDA_R_32F;
x.min_recall = 0.86;
});
ADD_CASE({
x.search_params.internal_distance_dtype = CUDA_R_16F;
x.min_recall = 0.86;
});

return xs;
}
Expand All @@ -431,6 +491,14 @@ inline auto enum_variety_ip() -> test_cases_t
{
return map<ivf_pq_inputs>(enum_variety(), [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
if (y.min_recall.has_value()) {
if (y.search_params.lut_dtype == CUDA_R_8U) {
// InnerProduct score is signed,
// thus we're forced to used signed 8-bit representation,
// thus we have one bit less precision
y.min_recall = y.min_recall.value() * 0.95;
}
}
y.index_params.metric = distance::DistanceType::InnerProduct;
return y;
});
Expand Down
11 changes: 9 additions & 2 deletions cpp/test/neighbors/ann_utils.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -232,7 +232,14 @@ auto eval_neighbours(const std::vector<T>& expected_idx,
}
}
double actual_recall = static_cast<double>(match_count) / static_cast<double>(total_count);
RAFT_LOG_INFO("Recall = %f (%zu/%zu)", actual_recall, match_count, total_count);
double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps);
RAFT_LOG_INFO("Recall = %f (%zu/%zu), the error is %2.1f%% %s the threshold (eps = %f).",
actual_recall,
match_count,
total_count,
std::abs(error_margin * 100.0),
error_margin < 0 ? "above" : "below",
eps);
if (actual_recall < min_recall - eps) {
if (actual_recall < min_recall * min_recall - eps) {
RAFT_LOG_ERROR("Recall is much lower than the minimum (%f < %f)", actual_recall, min_recall);
Expand Down