Skip to content

Commit

Permalink
Launch neighborhood_recall kernel on CUDA stream (#2156)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2156
  • Loading branch information
divyegala authored Feb 6, 2024
1 parent 6328be7 commit 9f6af2f
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cpp/include/raft/stats/detail/neighborhood_recall.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
#include <raft/core/math.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

#include <cub/cub.cuh>
Expand Down Expand Up @@ -108,7 +109,7 @@ void neighborhood_recall(
auto constexpr kThreadsPerBlock = 32;
auto const num_blocks = indices.extent(0);

neighborhood_recall<<<num_blocks, kThreadsPerBlock>>>(
neighborhood_recall<<<num_blocks, kThreadsPerBlock, 0, raft::resource::get_cuda_stream(res)>>>(
indices, ref_indices, distances, ref_distances, recall_score, eps);
}

Expand Down

0 comments on commit 9f6af2f

Please sign in to comment.