Skip to content

Commit

Permalink
Fix ANN bench ground truth generation for k>1024 (#2180)
Browse files Browse the repository at this point in the history
Generating ANN bench ground truth is affected by bug #2171, when k>1024. This PR fixes the issue for the ground truth generation.

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2180
  • Loading branch information
tfeher authored Mar 19, 2024
1 parent d14cac2 commit bd50c37
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
3 changes: 3 additions & 0 deletions cpp/include/raft/neighbors/detail/knn_merge_parts.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/error.hpp>
#include <raft/neighbors/detail/faiss_select/DistanceUtils.h>
#include <raft/neighbors/detail/faiss_select/Select.cuh>
#include <raft/util/cuda_utils.cuh>
Expand Down Expand Up @@ -168,5 +169,7 @@ inline void knn_merge_parts(const value_t* inK,
else if (k <= 1024)
knn_merge_parts_impl<value_idx, value_t, 1024, 8>(
inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
else
THROW("Unimplemented for k=%d, knn_merge_parts works for k<=1024", k);
}
} // namespace raft::neighbors::detail
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
#
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -62,17 +62,12 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"):

X = cp.asarray(dataset[i : i + n_batch, :], cp.float32)

D, Ind = knn(
X,
queries,
k,
metric=metric,
handle=handle,
global_id_offset=i, # shift neighbor index by offset i
)
D, Ind = knn(X, queries, k, metric=metric, handle=handle)
handle.sync()

D, Ind = cp.asarray(D), cp.asarray(Ind)
Ind += i # shift neighbor index by offset i

if distances is None:
distances = D
indices = Ind
Expand Down

0 comments on commit bd50c37

Please sign in to comment.