diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index af5a5770fb..d5dc61eddf 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -327,7 +327,7 @@ void select_k(raft::resources const& handle, case Algo::kWarpDistributedShm: return detail::select::warpsort:: select_k_impl( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); case Algo::kFaissBlockSelect: return neighbors::detail::select_k( in_val, in_idx, batch_size, len, out_val, out_idx, select_min, k, stream); diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index dc86a04733..935a9a1d13 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -18,7 +18,9 @@ #include #include +#include #include +#include #include #include #include @@ -773,6 +775,11 @@ __launch_bounds__(256) __global__ queue.store(out + block_id * k, out_idx + block_id * k); } +struct launch_params { + int block_size = 0; + int min_grid_size = 0; +}; + template