Skip to content

Commit

Permalink
Fix failing TiledKNNTest unittest (#1533)
Browse files Browse the repository at this point in the history
The TiledKNNTest test was faiiling - and it seems to be because the matrix::select_k code isn't guaranteed to return elements in sorted order. The test was expecting outputs to be sorted, and was failing because of it. This change fixes the test to sort the outputs before comparing.

Closes #1526

Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1533
  • Loading branch information
benfred authored May 18, 2023
1 parent 650699b commit 29d1c15
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
21 changes: 18 additions & 3 deletions cpp/test/neighbors/knn_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ testing::AssertionResult devArrMatchKnnPair(const T* expected_idx,
size_t rows,
size_t cols,
const DistT eps,
cudaStream_t stream = 0)
cudaStream_t stream = 0,
bool sort_inputs = false)
{
size_t size = rows * cols;
std::unique_ptr<T[]> exp_idx_h(new T[size]);
Expand All @@ -57,16 +58,30 @@ testing::AssertionResult devArrMatchKnnPair(const T* expected_idx,
raft::update_host<T>(act_idx_h.get(), actual_idx, size, stream);
raft::update_host<DistT>(exp_dist_h.get(), expected_dist, size, stream);
raft::update_host<DistT>(act_dist_h.get(), actual_dist, size, stream);

RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
for (size_t i(0); i < rows; ++i) {
std::vector<std::pair<DistT, T>> actual;
std::vector<std::pair<DistT, T>> expected;
for (size_t j(0); j < cols; ++j) {
auto idx = i * cols + j; // row major assumption!
auto exp_idx = exp_idx_h.get()[idx];
auto act_idx = act_idx_h.get()[idx];
auto exp_dist = exp_dist_h.get()[idx];
auto act_dist = act_dist_h.get()[idx];
idx_dist_pair exp_kvp(exp_idx, exp_dist, raft::CompareApprox<DistT>(eps));
idx_dist_pair act_kvp(act_idx, act_dist, raft::CompareApprox<DistT>(eps));
actual.push_back(std::make_pair(act_dist, act_idx));
expected.push_back(std::make_pair(exp_dist, exp_idx));
}
if (sort_inputs) {
// inputs could be unsorted here, sort for comparison
std::sort(actual.begin(), actual.end());
std::sort(expected.begin(), expected.end());
}
for (size_t j(0); j < cols; ++j) {
auto act = actual[j];
auto exp = expected[j];
idx_dist_pair exp_kvp(exp.second, exp.first, raft::CompareApprox<DistT>(eps));
idx_dist_pair act_kvp(act.second, act.first, raft::CompareApprox<DistT>(eps));
if (!(exp_kvp == act_kvp)) {
return testing::AssertionFailure()
<< "actual=" << act_kvp.idx << "," << act_kvp.dist << "!="
Expand Down
3 changes: 2 additions & 1 deletion cpp/test/neighbors/tiled_knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ class TiledKNNTest : public ::testing::TestWithParam<TiledKNNInputs> {
num_queries,
k_,
float(0.001),
stream_));
stream_,
true));
}

void SetUp() override
Expand Down

0 comments on commit 29d1c15

Please sign in to comment.