Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ANN tests: make the min_recall check strict #1156

Merged
Merged
13 changes: 8 additions & 5 deletions cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {
protected:
void gen_data()
{
database.resize(ps.num_db_vecs * ps.dim, stream_);
search_queries.resize(ps.num_queries * ps.dim, stream_);
database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_);
search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_);

raft::random::Rng r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
Expand All @@ -155,7 +155,7 @@ class ivf_pq_test : public ::testing::TestWithParam<ivf_pq_inputs> {

void calc_ref()
{
size_t queries_size = ps.num_queries * ps.k;
size_t queries_size = size_t{ps.num_queries} * size_t{ps.k};
rmm::device_uvector<EvalT> distances_naive_dev(queries_size, stream_);
rmm::device_uvector<IdxT> indices_naive_dev(queries_size, stream_);
naiveBfKnn<EvalT, DataT, IdxT>(distances_naive_dev.data(),
Expand Down Expand Up @@ -463,7 +463,7 @@ inline auto enum_variety() -> test_cases_t
});
ADD_CASE({
x.search_params.lut_dtype = CUDA_R_8U;
x.min_recall = 0.85;
x.min_recall = 0.84;
});

ADD_CASE({
Expand Down Expand Up @@ -496,7 +496,10 @@ inline auto enum_variety_ip() -> test_cases_t
// InnerProduct score is signed,
// thus we're forced to used signed 8-bit representation,
// thus we have one bit less precision
y.min_recall = y.min_recall.value() * 0.95;
y.min_recall = y.min_recall.value() * 0.90;
} else {
// In other cases it seems to perform a little bit better, still worse than L2
y.min_recall = y.min_recall.value() * 0.94;
}
}
y.index_params.metric = distance::DistanceType::InnerProduct;
Expand Down
52 changes: 28 additions & 24 deletions cpp/test/neighbors/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -110,28 +110,39 @@ __global__ void naive_distance_kernel(EvalT* dist,
IdxT m,
IdxT n,
IdxT k,
raft::distance::DistanceType type)
raft::distance::DistanceType metric)
{
IdxT midx = threadIdx.x + blockIdx.x * blockDim.x;
IdxT midx = IdxT(threadIdx.x) + IdxT(blockIdx.x) * IdxT(blockDim.x);
if (midx >= m) return;
for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n;
nidx += blockDim.y * gridDim.y) {
IdxT grid_size = IdxT(blockDim.y) * IdxT(gridDim.y);
for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n; nidx += grid_size) {
EvalT acc = EvalT(0);
for (IdxT i = 0; i < k; ++i) {
IdxT xidx = i + midx * k;
IdxT yidx = i + nidx * k;
EvalT xv = (EvalT)x[xidx];
EvalT yv = (EvalT)y[yidx];
if (type == raft::distance::DistanceType::InnerProduct) {
acc += xv * yv;
} else {
EvalT diff = xv - yv;
acc += diff * diff;
auto xv = EvalT(x[xidx]);
auto yv = EvalT(y[yidx]);
switch (metric) {
case raft::distance::DistanceType::InnerProduct: {
acc += xv * yv;
} break;
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded:
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2Unexpanded: {
auto diff = xv - yv;
acc += diff * diff;
} break;
default: break;
}
}
if (type == raft::distance::DistanceType::L2SqrtExpanded ||
type == raft::distance::DistanceType::L2SqrtUnexpanded)
acc = raft::sqrt(acc);
switch (metric) {
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded: {
acc = raft::sqrt(acc);
} break;
default: break;
}
dist[midx * n + nidx] = acc;
}
}
Expand Down Expand Up @@ -241,16 +252,9 @@ auto eval_neighbours(const std::vector<T>& expected_idx,
error_margin < 0 ? "above" : "below",
eps);
if (actual_recall < min_recall - eps) {
if (actual_recall < min_recall * min_recall - eps) {
RAFT_LOG_ERROR("Recall is much lower than the minimum (%f < %f)", actual_recall, min_recall);
} else {
RAFT_LOG_WARN("Recall is suspiciously too low (%f < %f)", actual_recall, min_recall);
}
if (match_count == 0 || actual_recall < min_recall * std::min(min_recall, 0.5) - eps) {
return testing::AssertionFailure()
<< "actual recall (" << actual_recall
<< ") is much smaller than the minimum expected recall (" << min_recall << ").";
}
return testing::AssertionFailure()
<< "actual recall (" << actual_recall << ") is lower than the minimum expected recall ("
<< min_recall << "); eps = " << eps << ". ";
}
return testing::AssertionSuccess();
}
Expand Down