diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu index 324d3aef84..6364ab17da 100644 --- a/cpp/bench/prims/matrix/select_k.cu +++ b/cpp/bench/prims/matrix/select_k.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,7 +52,7 @@ struct replace_with_mask { } }; -template +template struct selection : public fixture { explicit selection(const select::params& p) : fixture(p.use_memory_pool), @@ -110,16 +110,24 @@ struct selection : public fixture { int iter = 0; loop_on_state(state, [&iter, this]() { common::nvtx::range lap_scope("lap-", iter++); - select::select_k_impl(handle, - Algo, - in_dists_.data(), - params_.use_index_input ? in_ids_.data() : NULL, - params_.batch_size, - params_.len, - params_.k, - out_dists_.data(), - out_ids_.data(), - params_.select_min); + + std::optional> in_ids_view; + if (params_.use_index_input) { + in_ids_view = raft::make_device_matrix_view( + in_ids_.data(), params_.batch_size, params_.len); + } + + matrix::select_k(handle, + raft::make_device_matrix_view( + in_dists_.data(), params_.batch_size, params_.len), + in_ids_view, + raft::make_device_matrix_view( + out_dists_.data(), params_.batch_size, params_.k), + raft::make_device_matrix_view( + out_ids_.data(), params_.batch_size, params_.k), + params_.select_min, + false, + Algo); }); } catch (raft::exception& e) { state.SkipWithError(e.what()); @@ -213,13 +221,13 @@ const std::vector kInputs{ {1000, 10000, 256, true, false, false, true, 0.999}, }; -#define SELECTION_REGISTER(KeyT, IdxT, A) \ - namespace BENCHMARK_PRIVATE_NAME(selection) { \ - using SelectK = selection; \ - RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ +#define SELECTION_REGISTER(KeyT, IdxT, A) \ + namespace BENCHMARK_PRIVATE_NAME(selection) { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ } -SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT +SELECTION_REGISTER(float, uint32_t, kAuto); // NOLINT SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT @@ -252,7 +260,7 @@ SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT // register other benchmarks #define SELECTION_REGISTER_ALGO_INPUT(KeyT, IdxT, A, input) \ { \ - using SelectK = selection; \ + using SelectK = selection; \ std::stringstream name; \ name << "SelectKDataset/" << #KeyT "/" #IdxT "/" #A << "/" << input.batch_size << "/" \ << input.len << "/" << input.k << "/" << input.use_index_input << "/" \ diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 870f0c3240..dfdbfa2d07 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include // uint32_t #include // __half #include +#include #include // RAFT_EXPLICIT #include // rmm:cuda_stream_view #include // rmm::mr::device_memory_resource @@ -38,7 +39,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, bool select_min, rmm::mr::device_memory_resource* mr = nullptr, - bool sorted = false) RAFT_EXPLICIT; + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -54,7 +56,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 63aeff2f1c..0a6f292e68 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -1,5 +1,6 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +24,7 @@ #include #include #include +#include #include #include @@ -31,10 +33,6 @@ namespace raft::matrix::detail { -// this is a subset of algorithms, chosen by running the algorithm_selection -// notebook in cpp/scripts/heuristics/select_k -enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bitsExtraPass }; - /** * Predict the fastest select_k algorithm based on the number of rows/cols/k * @@ -47,31 +45,31 @@ enum class Algo { kRadix11bits, kWarpDistributedShm, kWarpImmediate, kRadix11bit * 'generate_heuristic' notebook there will replace the body of this function * with the latest learned heuristic */ -inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k) +inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k) { if (k > 256) { if (cols > 16862) { if (rows > 1020) { - return Algo::kRadix11bitsExtraPass; + return SelectAlgo::kRadix11bitsExtraPass; } else { - return Algo::kRadix11bits; + return SelectAlgo::kRadix11bits; } } else { - return Algo::kRadix11bitsExtraPass; + return SelectAlgo::kRadix11bitsExtraPass; } } else { if (k > 2) { if (cols > 22061) { - return Algo::kWarpDistributedShm; + return SelectAlgo::kWarpDistributedShm; } else { if (rows > 198) { - return Algo::kWarpDistributedShm; + return SelectAlgo::kWarpDistributedShm; } else { - return Algo::kWarpImmediate; + return SelectAlgo::kWarpImmediate; } } } else { - return Algo::kWarpImmediate; + return SelectAlgo::kWarpImmediate; } } } @@ -239,31 +237,48 @@ void select_k(raft::resources const& handle, IdxT* out_idx, bool select_min, rmm::mr::device_memory_resource* mr = nullptr, - bool sorted = false) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - auto stream = raft::resource::get_cuda_stream(handle); - auto algo = choose_select_k_algorithm(batch_size, len, k); + if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); } + + auto stream = raft::resource::get_cuda_stream(handle); switch (algo) { - case Algo::kRadix11bits: - case Algo::kRadix11bitsExtraPass: { - bool fused_last_filter = algo == Algo::kRadix11bits; - detail::select::radix::select_k(in_val, - in_idx, - batch_size, - len, - k, - out_val, - out_idx, - select_min, - fused_last_filter, - stream, - mr); + case SelectAlgo::kRadix8bits: + case SelectAlgo::kRadix11bits: + case SelectAlgo::kRadix11bitsExtraPass: { + if (algo == SelectAlgo::kRadix8bits) { + detail::select::radix::select_k(in_val, + in_idx, + batch_size, + len, + k, + out_val, + out_idx, + select_min, + true, // fused_last_filter + stream, + mr); + } else { + bool fused_last_filter = algo == SelectAlgo::kRadix11bits; + detail::select::radix::select_k(in_val, + in_idx, + batch_size, + len, + k, + out_val, + out_idx, + select_min, + fused_last_filter, + stream, + mr); + } if (sorted) { auto offsets = raft::make_device_vector(handle, (IdxT)(batch_size + 1)); @@ -283,14 +298,25 @@ void select_k(raft::resources const& handle, } return; } - case Algo::kWarpDistributedShm: + case SelectAlgo::kWarpDistributed: + return detail::select::warpsort:: + select_k_impl( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + case SelectAlgo::kWarpDistributedShm: return detail::select::warpsort:: select_k_impl( in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); - case Algo::kWarpImmediate: + case SelectAlgo::kWarpAuto: + return detail::select::warpsort::select_k( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + case SelectAlgo::kWarpImmediate: return detail::select::warpsort:: select_k_impl( in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + case SelectAlgo::kWarpFiltered: + return detail::select::warpsort:: + select_k_impl( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); default: RAFT_FAIL("K-selection Algorithm not supported."); } } diff --git a/cpp/include/raft/matrix/select_k.cuh b/cpp/include/raft/matrix/select_k.cuh index 37a36cbf6b..92d7db006d 100644 --- a/cpp/include/raft/matrix/select_k.cuh +++ b/cpp/include/raft/matrix/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -76,6 +77,8 @@ namespace raft::matrix { * whether to select k smallest (true) or largest (false) keys. * @param[in] sorted * whether to make sure selected pairs are sorted by value + * @param[in] algo + * the selection algorithm to use */ template void select_k(raft::resources const& handle, @@ -84,7 +87,8 @@ void select_k(raft::resources const& handle, raft::device_matrix_view out_val, raft::device_matrix_view out_idx, bool select_min, - bool sorted = false) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) { RAFT_EXPECTS(out_val.extent(1) <= int64_t(std::numeric_limits::max()), "output k must fit the int type."); @@ -109,7 +113,8 @@ void select_k(raft::resources const& handle, out_idx.data_handle(), select_min, nullptr, - sorted); + sorted, + algo); } /** @} */ // end of group select_k diff --git a/cpp/include/raft/matrix/select_k_types.hpp b/cpp/include/raft/matrix/select_k_types.hpp new file mode 100644 index 0000000000..f001f91770 --- /dev/null +++ b/cpp/include/raft/matrix/select_k_types.hpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft::matrix { + +/** + * @defgroup select_k Batched-select k smallest or largest key/values + * @{ + */ + +/** + * @brief Algorithm used to select the k largest neighbors + * + * Details about how the the select-k algorithms in RAFT work can be found in the + * paper "Parallel Top-K Algorithms on GPU: A Comprehensive Study and New Methods" + * https://doi.org/10.1145/3581784.3607062. The kRadix* variants below correspond + * to the 'Air Top-k' algorithm described in the paper, and the kWarp* variants + * correspond to the 'GridSelect' algorithm. + */ +enum class SelectAlgo : uint8_t { + /** Automatically pick the select-k algorithm based off the input dimensions and k value */ + kAuto = 0, + /** Radix Select using 8 bits per pass */ + kRadix8bits = 1, + /** Radix Select using 11 bits per pass, fusing the last filter step */ + kRadix11bits = 2, + /** Radix Select using 11 bits per pass, without fusing the last filter step */ + kRadix11bitsExtraPass = 3, + /** + * Automatically switches between the kWarpImmediate and kWarpFiltered algorithms + * based off of input size + */ + kWarpAuto = 4, + /** + * This version of warp_sort adds every input element into the intermediate sorting + * buffer, and thus does the sorting step every `Capacity` input elements. + * + * This implementation is preferred for very small len values. + */ + kWarpImmediate = 5, + /** + * This version of warp_sort compares each input element against the current + * estimate of k-th value before adding it to the intermediate sorting buffer. + * This makes the algorithm do less sorting steps for long input sequences + * at the cost of extra checks on each step. + * + * This implementation is preferred for large len values. + */ + kWarpFiltered = 6, + /** + * This version of warp_sort compares each input element against the current + * estimate of k-th value before adding it to the intermediate sorting buffer. + * In contrast to `warp_sort_filtered`, it keeps one distributed buffer for + * all threads in a warp (independently of the subwarp size), which makes its flushing less often. + */ + kWarpDistributed = 7, + /** + * The same as `warp_sort_distributed`, but keeps the temporary value and index buffers + * in the given external pointers (normally, a shared memory pointer should be passed in). + */ + kWarpDistributedShm = 8, +}; + +inline auto operator<<(std::ostream& os, const SelectAlgo& algo) -> std::ostream& +{ + auto underlying_value = static_cast::type>(algo); + + switch (algo) { + case SelectAlgo::kAuto: return os << "kAuto=" << underlying_value; + case SelectAlgo::kRadix8bits: return os << "kRadix8bits=" << underlying_value; + case SelectAlgo::kRadix11bits: return os << "kRadix11bits=" << underlying_value; + case SelectAlgo::kRadix11bitsExtraPass: + return os << "kRadix11bitsExtraPass=" << underlying_value; + case SelectAlgo::kWarpAuto: return os << "kWarpAuto=" << underlying_value; + case SelectAlgo::kWarpImmediate: return os << "kWarpImmediate=" << underlying_value; + case SelectAlgo::kWarpFiltered: return os << "kWarpFiltered=" << underlying_value; + case SelectAlgo::kWarpDistributed: return os << "kWarpDistributed=" << underlying_value; + case SelectAlgo::kWarpDistributedShm: return os << "kWarpDistributedShm=" << underlying_value; + default: throw std::invalid_argument("invalid value for SelectAlgo"); + } +} + +/** @} */ // end of group select_k + +} // namespace raft::matrix diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index 93095ff82e..b899978f1c 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,123 +47,4 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& os << "}"; return os; } - -enum class Algo { - kPublicApi, - kRadix8bits, - kRadix11bits, - kRadix11bitsExtraPass, - kWarpAuto, - kWarpImmediate, - kWarpFiltered, - kWarpDistributed, - kWarpDistributedShm, -}; - -inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream& -{ - switch (algo) { - case Algo::kPublicApi: return os << "kPublicApi"; - case Algo::kRadix8bits: return os << "kRadix8bits"; - case Algo::kRadix11bits: return os << "kRadix11bits"; - case Algo::kRadix11bitsExtraPass: return os << "kRadix11bitsExtraPass"; - case Algo::kWarpAuto: return os << "kWarpAuto"; - case Algo::kWarpImmediate: return os << "kWarpImmediate"; - case Algo::kWarpFiltered: return os << "kWarpFiltered"; - case Algo::kWarpDistributed: return os << "kWarpDistributed"; - case Algo::kWarpDistributedShm: return os << "kWarpDistributedShm"; - default: return os << "unknown enum value"; - } -} - -template -void select_k_impl(const resources& handle, - const Algo& algo, - const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min) -{ - auto stream = resource::get_cuda_stream(handle); - switch (algo) { - case Algo::kPublicApi: { - auto in_extent = make_extents(batch_size, len); - auto out_extent = make_extents(batch_size, k); - auto in_span = make_mdspan(in, in_extent); - auto in_idx_span = - make_mdspan(in_idx, in_extent); - auto out_span = make_mdspan(out, out_extent); - auto out_idx_span = make_mdspan(out_idx, out_extent); - if (in_idx == nullptr) { - // NB: std::nullopt prevents automatic inference of the template parameters. - return matrix::select_k( - handle, in_span, std::nullopt, out_span, out_idx_span, select_min, true); - } else { - return matrix::select_k(handle, - in_span, - std::make_optional(in_idx_span), - out_span, - out_idx_span, - select_min, - true); - } - } - case Algo::kRadix8bits: - return detail::select::radix::select_k(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - true, // fused_last_filter - stream); - case Algo::kRadix11bits: - return detail::select::radix::select_k(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - true, // fused_last_filter - stream); - case Algo::kRadix11bitsExtraPass: - return detail::select::radix::select_k(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - false, // fused_last_filter - stream); - case Algo::kWarpAuto: - return detail::select::warpsort::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kWarpImmediate: - return detail::select::warpsort:: - select_k_impl( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kWarpFiltered: - return detail::select::warpsort:: - select_k_impl( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kWarpDistributed: - return detail::select::warpsort:: - select_k_impl( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - case Algo::kWarpDistributedShm: - return detail::select::warpsort:: - select_k_impl( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); - } -} } // namespace raft::matrix::select diff --git a/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb b/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb index 50bc12556a..f764d2f88f 100644 --- a/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb +++ b/cpp/scripts/heuristics/select_k/generate_heuristic.ipynb @@ -405,31 +405,31 @@ "name": "stdout", "output_type": "stream", "text": [ - "inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)\n", + "inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k)\n", "{\n", " if (k > 256) {\n", " if (cols > 16862) {\n", " if (rows > 1020) {\n", - " return Algo::kRadix11bitsExtraPass;\n", + " return SelectAlgo::kRadix11bitsExtraPass;\n", " } else {\n", - " return Algo::kRadix11bits;\n", + " return SelectAlgo::kRadix11bits;\n", " }\n", " } else {\n", - " return Algo::kRadix11bitsExtraPass;\n", + " return SelectAlgo::kRadix11bitsExtraPass;\n", " }\n", " } else {\n", " if (k > 2) {\n", " if (cols > 22061) {\n", - " return Algo::kWarpDistributedShm;\n", + " return SelectAlgo::kWarpDistributedShm;\n", " } else {\n", " if (rows > 198) {\n", - " return Algo::kWarpDistributedShm;\n", + " return SelectAlgo::kWarpDistributedShm;\n", " } else {\n", - " return Algo::kWarpImmediate;\n", + " return SelectAlgo::kWarpImmediate;\n", " }\n", " }\n", " } else {\n", - " return Algo::kWarpImmediate;\n", + " return SelectAlgo::kWarpImmediate;\n", " }\n", " }\n", "}\n" @@ -466,7 +466,7 @@ " if _is_leaf_node(nodeid):\n", " # we're a leaf node, just output the label of the most frequent algorithm\n", " class_name = _get_label(nodeid)\n", - " code.append(\" \" * indent + f\"return Algo::{class_name};\")\n", + " code.append(\" \" * indent + f\"return SelectAlgo::{class_name};\")\n", " else: \n", " feature = feature_names[tree.feature[nodeid]]\n", " threshold = int(np.floor(tree.threshold[nodeid]))\n", @@ -476,7 +476,7 @@ " _convert_node(tree.children_left[nodeid], indent + 2)\n", " code.append(\" \" * indent + \"}\")\n", " \n", - " code.append(\"inline Algo choose_select_k_algorithm(size_t rows, size_t cols, int k)\")\n", + " code.append(\"inline SelectAlgo choose_select_k_algorithm(size_t rows, size_t cols, int k)\")\n", " code.append(\"{\")\n", " _convert_node(0, indent=2)\n", " code.append(\"}\")\n", diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index c75a5b5261..87e5d49d29 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 171c8a1ae7..67dce0e166 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index a21444dc0c..4be7c54839 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 9542874ec0..6337994e86 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index fbf311d9bd..ad26547812 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index fdbfd66c46..e3c29a2033 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index 48a3e91f9d..3e3a738915 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ rmm::mr::device_memory_resource* mr, \ - bool sorted) + bool sorted, \ + raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(__half, uint32_t); diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index ce4e3e867e..f3eb32b2e1 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -93,56 +93,56 @@ auto inputs_random_many_infs = select::params{1000, 10000, 256, true, false, false, true, 0.999}); using ReferencedRandomFloatInt = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatInt, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomFloatInt, testing::Combine(inputs_random_longlist, - testing::Values(select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass, - select::Algo::kWarpImmediate, - select::Algo::kWarpFiltered, - select::Algo::kWarpDistributed, - select::Algo::kWarpDistributedShm))); + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed, + SelectAlgo::kWarpDistributedShm))); using ReferencedRandomDoubleSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomDoubleSizeT, testing::Combine(inputs_random_longlist, - testing::Values(select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass, - select::Algo::kWarpImmediate, - select::Algo::kWarpFiltered, - select::Algo::kWarpDistributed, - select::Algo::kWarpDistributedShm))); + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed, + SelectAlgo::kWarpDistributedShm))); using ReferencedRandomDoubleInt = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomDoubleInt, testing::Combine(inputs_random_largesize, - testing::Values(select::Algo::kWarpAuto, - select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass))); + testing::Values(SelectAlgo::kWarpAuto, + SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); using ReferencedRandomFloatIntkWarpsortAsGT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatIntkWarpsortAsGT, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomFloatIntkWarpsortAsGT, testing::Combine(inputs_random_many_infs, - testing::Values(select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass))); + testing::Values(SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); } // namespace raft::matrix diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh index fdea982d6c..412a9ae5a2 100644 --- a/cpp/test/matrix/select_k.cuh +++ b/cpp/test/matrix/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,8 +49,8 @@ auto gen_simple_ids(uint32_t batch_size, uint32_t len) -> std::vector template struct io_simple { public: - bool not_supported = false; - std::optional algo = std::nullopt; + bool not_supported = false; + std::optional algo = std::nullopt; io_simple(const select::params& spec, const std::vector& in_dists, @@ -80,10 +80,10 @@ template struct io_computed { public: bool not_supported = false; - select::Algo algo; + SelectAlgo algo; io_computed(const select::params& spec, - const select::Algo& algo, + const SelectAlgo& algo, const std::vector& in_dists, const std::optional>& in_ids = std::nullopt) : algo(algo), @@ -94,11 +94,11 @@ struct io_computed { { // check if the size is supported by the algorithm switch (algo) { - case select::Algo::kWarpAuto: - case select::Algo::kWarpImmediate: - case select::Algo::kWarpFiltered: - case select::Algo::kWarpDistributed: - case select::Algo::kWarpDistributedShm: { + case SelectAlgo::kWarpAuto: + case SelectAlgo::kWarpImmediate: + case SelectAlgo::kWarpFiltered: + case SelectAlgo::kWarpDistributed: + case SelectAlgo::kWarpDistributedShm: { if (spec.k > raft::matrix::detail::select::warpsort::kMaxCapacity) { not_supported = true; return; @@ -118,16 +118,22 @@ struct io_computed { update_device(in_dists_d.data(), in_dists_.data(), in_dists_.size(), stream); update_device(in_ids_d.data(), in_ids_.data(), in_ids_.size(), stream); - select::select_k_impl(handle, - algo, - in_dists_d.data(), - spec.use_index_input ? in_ids_d.data() : nullptr, - spec.batch_size, - spec.len, - spec.k, - out_dists_d.data(), - out_ids_d.data(), - spec.select_min); + std::optional> in_ids_view; + if (spec.use_index_input) { + in_ids_view = raft::make_device_matrix_view( + in_ids_d.data(), spec.batch_size, spec.len); + } + + matrix::select_k( + handle, + raft::make_device_matrix_view( + in_dists_d.data(), spec.batch_size, spec.len), + in_ids_view, + raft::make_device_matrix_view(out_dists_d.data(), spec.batch_size, spec.k), + raft::make_device_matrix_view(out_ids_d.data(), spec.batch_size, spec.k), + spec.select_min, + false, + algo); update_host(out_dists_.data(), out_dists_d.data(), out_dists_.size(), stream); update_host(out_ids_.data(), out_ids_d.data(), out_ids_.size(), stream); @@ -194,13 +200,13 @@ struct io_computed { }; template -using Params = std::tuple; +using Params = std::tuple; template typename ParamsReader> struct SelectK // NOLINT : public testing::TestWithParam::params_t> { const select::params spec; - const select::Algo algo; + const SelectAlgo algo; typename ParamsReader::io_t ref; io_computed res; @@ -255,18 +261,18 @@ struct SelectK // NOLINT ASSERT_TRUE(hostVecMatch(ref.get_out_ids(), res.get_out_ids(), compare_ids)); } - auto forgive_algo(const std::optional& algo, IdxT ix) const -> bool + auto forgive_algo(const std::optional& algo, IdxT ix) const -> bool { if (!algo.has_value()) { return false; } switch (algo.value()) { // not sure which algo this is. - case select::Algo::kPublicApi: return true; + case SelectAlgo::kAuto: return true; // warp-sort-based algos currently return zero index for inf distances. - case select::Algo::kWarpAuto: - case select::Algo::kWarpImmediate: - case select::Algo::kWarpFiltered: - case select::Algo::kWarpDistributed: - case select::Algo::kWarpDistributedShm: return ix == 0; + case SelectAlgo::kWarpAuto: + case SelectAlgo::kWarpImmediate: + case SelectAlgo::kWarpFiltered: + case SelectAlgo::kWarpDistributed: + case SelectAlgo::kWarpDistributedShm: return ix == 0; // Do not forgive by default default: return false; } @@ -281,7 +287,7 @@ struct params_simple { std::optional>, std::vector, std::vector>; - using params_t = std::tuple; + using params_t = std::tuple; static auto read(params_t ps) -> Params { @@ -387,13 +393,13 @@ INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, SimpleFloatInt, testing::Combine(inputs_simple_f, - testing::Values(select::Algo::kPublicApi, - select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass, - select::Algo::kWarpImmediate, - select::Algo::kWarpFiltered, - select::Algo::kWarpDistributed))); + testing::Values(SelectAlgo::kAuto, + SelectAlgo::kRadix8bits, + SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass, + SelectAlgo::kWarpImmediate, + SelectAlgo::kWarpFiltered, + SelectAlgo::kWarpDistributed))); template struct replace_with_mask { @@ -401,12 +407,12 @@ struct replace_with_mask { constexpr auto inline operator()(KeyT x, uint8_t mask) -> KeyT { return mask ? replacement : x; } }; -template +template struct with_ref { template struct params_random { using io_t = io_computed; - using params_t = std::tuple; + using params_t = std::tuple; static auto read(params_t ps) -> Params { diff --git a/cpp/test/matrix/select_large_k.cu b/cpp/test/matrix/select_large_k.cu index 2772e84eb3..baa07f5e87 100644 --- a/cpp/test/matrix/select_large_k.cu +++ b/cpp/test/matrix/select_large_k.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,12 +25,12 @@ auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, tr select::params{100, 100000, 1237, true}); using ReferencedRandomFloatSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, testing::Combine(inputs_random_largek, - testing::Values(select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass))); + testing::Values(SelectAlgo::kRadix11bits, + SelectAlgo::kRadix11bitsExtraPass))); } // namespace raft::matrix