diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 94777aedd1..b5671b74b0 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -42,15 +42,18 @@ #include #include #include +#include #include namespace raft::neighbors::ivf_pq { struct ivf_pq_inputs { - uint32_t num_db_vecs = 4096; - uint32_t num_queries = 1024; - uint32_t dim = 64; - uint32_t k = 32; + uint32_t num_db_vecs = 4096; + uint32_t num_queries = 1024; + uint32_t dim = 64; + uint32_t k = 32; + std::optional min_recall = std::nullopt; + ivf_pq::index_params index_params; ivf_pq::search_params search_params; @@ -91,6 +94,7 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream PRINT_DIFF(.num_queries); PRINT_DIFF(.dim); PRINT_DIFF(.k); + PRINT_DIFF_V(.min_recall, p.min_recall.value_or(0)); PRINT_DIFF_V(.index_params.metric, print_metric{p.index_params.metric}); PRINT_DIFF(.index_params.metric_arg); PRINT_DIFF(.index_params.add_data_on_build); @@ -100,6 +104,7 @@ inline auto operator<<(std::ostream& os, const ivf_pq_inputs& p) -> std::ostream PRINT_DIFF(.index_params.pq_bits); PRINT_DIFF(.index_params.pq_dim); PRINT_DIFF(.index_params.codebook_kind); + PRINT_DIFF(.index_params.force_random_rotation); PRINT_DIFF(.search_params.n_probes); PRINT_DIFF_V(.search_params.lut_dtype, print_dtype{p.search_params.lut_dtype}); PRINT_DIFF_V(.search_params.internal_distance_dtype, @@ -231,12 +236,16 @@ class ivf_pq_test : public ::testing::TestWithParam { update_host(indices_ivf_pq.data(), indices_ivf_pq_dev.data(), queries_size, stream_); handle_.sync_stream(stream_); - // Using very dense, small codebooks results in large errors in the distance calculation - double low_precision_factor = - static_cast(index.pq_dim() * index.pq_bits()) / static_cast(ps.dim * 8); // A very conservative lower bound on recall - double min_recall = low_precision_factor * static_cast(ps.search_params.n_probes) / - 
static_cast(ps.index_params.n_lists); + double min_recall = + static_cast(ps.search_params.n_probes) / static_cast(ps.index_params.n_lists); + double low_precision_factor = + static_cast(ps.dim * 8) / static_cast(index.pq_dim() * index.pq_bits()); + // Using a heuristic to lower the required recall due to code-packing errors + min_recall = + std::min(std::erfc(0.05 * low_precision_factor / std::max(min_recall, 0.5)), min_recall); + // Use explicit per-test min recall value if provided. + min_recall = ps.min_recall.value_or(min_recall); ASSERT_TRUE(eval_neighbours(indices_ref, indices_ivf_pq, @@ -244,8 +253,9 @@ class ivf_pq_test : public ::testing::TestWithParam { distances_ivf_pq, ps.num_queries, ps.k, - 0.001 / low_precision_factor, - min_recall)); + 0.0001 * low_precision_factor, + min_recall)) + << ps; // Test a few extra invariants IdxT min_results = min_output_size(handle_, index, ps.search_params.n_probes); @@ -350,9 +360,16 @@ inline auto small_dims_per_cluster() -> test_cases_t inline auto big_dims() -> test_cases_t { - return with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144}); - // return with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144, 8192, 12288, - // 16384}); + // with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144, 8192, 12288, 16384}); + auto xs = with_dims({512, 513, 1023, 1024, 1025, 2048, 2049, 2050, 2053, 6144}); + return map(xs, [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + uint32_t pq_len = 2; + y.index_params.pq_dim = div_rounding_up_safe(x.dim, pq_len); + // This comes from pure experimentation, also the recall depends a lot on pq_len. + y.min_recall = 0.48 + 0.028 * std::log2(x.dim); + return y; + }); } /** These will surely trigger no-smem-lut kernel. 
*/ @@ -360,8 +377,11 @@ inline auto big_dims_moderate_lut() -> test_cases_t { return map(big_dims(), [](const ivf_pq_inputs& x) { ivf_pq_inputs y(x); + uint32_t pq_len = 2; + y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u); y.index_params.pq_bits = 6; y.search_params.lut_dtype = CUDA_R_16F; + y.min_recall = 0.69; return y; }); } @@ -371,9 +391,11 @@ inline auto big_dims_small_lut() -> test_cases_t { return map(big_dims(), [](const ivf_pq_inputs& x) { ivf_pq_inputs y(x); - y.index_params.pq_dim = raft::round_up_safe(y.dim / 8u, 64u); + uint32_t pq_len = 8; + y.index_params.pq_dim = round_up_safe(div_rounding_up_safe(x.dim, pq_len), 4u); y.index_params.pq_bits = 6; y.search_params.lut_dtype = CUDA_R_8U; + y.min_recall = 0.21; return y; }); } @@ -390,30 +412,68 @@ inline auto enum_variety() -> test_cases_t ([](ivf_pq_inputs & x) f)(xs[xs.size() - 1]); \ } while (0); - ADD_CASE({ x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; }); - ADD_CASE({ x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; }); + ADD_CASE({ + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_SUBSPACE; + x.min_recall = 0.86; + }); ADD_CASE({ x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; x.index_params.pq_bits = 4; + x.min_recall = 0.79; }); ADD_CASE({ x.index_params.codebook_kind = ivf_pq::codebook_gen::PER_CLUSTER; x.index_params.pq_bits = 5; + x.min_recall = 0.83; }); - ADD_CASE({ x.index_params.pq_bits = 6; }); - ADD_CASE({ x.index_params.pq_bits = 7; }); - ADD_CASE({ x.index_params.pq_bits = 8; }); + ADD_CASE({ + x.index_params.pq_bits = 6; + x.min_recall = 0.84; + }); + ADD_CASE({ + x.index_params.pq_bits = 7; + x.min_recall = 0.85; + }); + ADD_CASE({ + x.index_params.pq_bits = 8; + x.min_recall = 0.86; + }); - ADD_CASE({ x.index_params.force_random_rotation = true; }); - ADD_CASE({ 
x.index_params.force_random_rotation = false; }); + ADD_CASE({ + x.index_params.force_random_rotation = true; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.index_params.force_random_rotation = false; + x.min_recall = 0.86; + }); - ADD_CASE({ x.search_params.lut_dtype = CUDA_R_32F; }); - ADD_CASE({ x.search_params.lut_dtype = CUDA_R_16F; }); - ADD_CASE({ x.search_params.lut_dtype = CUDA_R_8U; }); + ADD_CASE({ + x.search_params.lut_dtype = CUDA_R_32F; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.search_params.lut_dtype = CUDA_R_16F; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.search_params.lut_dtype = CUDA_R_8U; + x.min_recall = 0.85; + }); - ADD_CASE({ x.search_params.internal_distance_dtype = CUDA_R_32F; }); - ADD_CASE({ x.search_params.internal_distance_dtype = CUDA_R_16F; }); + ADD_CASE({ + x.search_params.internal_distance_dtype = CUDA_R_32F; + x.min_recall = 0.86; + }); + ADD_CASE({ + x.search_params.internal_distance_dtype = CUDA_R_16F; + x.min_recall = 0.86; + }); return xs; } @@ -431,6 +491,14 @@ inline auto enum_variety_ip() -> test_cases_t { return map(enum_variety(), [](const ivf_pq_inputs& x) { ivf_pq_inputs y(x); + if (y.min_recall.has_value()) { + if (y.search_params.lut_dtype == CUDA_R_8U) { + // InnerProduct score is signed, + // thus we're forced to use a signed 8-bit representation, + // thus we have one bit less precision + y.min_recall = y.min_recall.value() * 0.95; + } + } y.index_params.metric = distance::DistanceType::InnerProduct; return y; }); diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 05fe6ab92d..b88b6abd9e 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -232,7 +232,14 @@ auto eval_neighbours(const std::vector& expected_idx, } } double actual_recall = static_cast(match_count) / static_cast(total_count); - RAFT_LOG_INFO("Recall = %f (%zu/%zu)", actual_recall, match_count, total_count); + double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); + RAFT_LOG_INFO("Recall = %f (%zu/%zu), the error is %2.1f%% %s the threshold (eps = %f).", + actual_recall, + match_count, + total_count, + std::abs(error_margin * 100.0), + error_margin < 0 ? "above" : "below", + eps); if (actual_recall < min_recall - eps) { if (actual_recall < min_recall * min_recall - eps) { RAFT_LOG_ERROR("Recall is much lower than the minimum (%f < %f)", actual_recall, min_recall);