diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index dfdbfa2d07..e8db6827b5 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -38,25 +38,23 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - rmm::mr::device_memory_resource* mr = nullptr, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY -#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ - extern template void raft::matrix::detail::select_k(raft::resources const& handle, \ - const T* in_val, \ - const IdxT* in_idx, \ - size_t batch_size, \ - size_t len, \ - int k, \ - T* out_val, \ - IdxT* out_idx, \ - bool select_min, \ - rmm::mr::device_memory_resource* mr, \ - bool sorted, \ +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + extern template void raft::matrix::detail::select_k(raft::resources const& handle, \ + const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + bool sorted, \ raft::matrix::SelectAlgo algo) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 0a6f292e68..8f40e6ae00 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -23,13 +23,12 @@ #include #include #include -#include +#include +#include +#include #include -#include -#include -#include -#include +#include namespace raft::matrix::detail { @@ -95,15 +94,17 @@ void segmented_sort_by_key(raft::resources const& handle, const ValT* offsets, bool asc) { - auto stream = raft::resource::get_cuda_stream(handle); - auto out_inds = raft::make_device_vector(handle, n_elements); - auto out_dists = raft::make_device_vector(handle, n_elements); + auto stream = resource::get_cuda_stream(handle); + auto mr = resource::get_workspace_resource(handle); + auto out_inds = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_elements)); + auto out_dists = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_elements)); // Determine temporary device storage requirements - auto d_temp_storage = raft::make_device_vector(handle, 0); size_t temp_storage_bytes = 0; if (asc) { - cub::DeviceSegmentedRadixSort::SortPairs((void*)d_temp_storage.data_handle(), + cub::DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, keys, out_dists.data_handle(), @@ -117,7 +118,7 @@ void segmented_sort_by_key(raft::resources const& handle, sizeof(ValT) * 8, stream); } else { - cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)d_temp_storage.data_handle(), + cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, keys, out_dists.data_handle(), @@ -132,7 +133,8 @@ void segmented_sort_by_key(raft::resources const& handle, stream); } - d_temp_storage = raft::make_device_vector(handle, temp_storage_bytes); + auto d_temp_storage = raft::make_device_mdarray( + handle, mr, raft::make_extents(temp_storage_bytes)); if (asc) { // Run sorting operation @@ -201,6 +203,7 @@ void segmented_sort_by_key(raft::resources const& handle, * @tparam IdxT * the index type (what is being selected together with the keys). * + * @param[in] handle container of reusable resources * @param[in] in_val * contiguous device array of inputs of size (len * batch_size); * these are compared and selected. @@ -222,9 +225,10 @@ void segmented_sort_by_key(raft::resources const& handle, * the payload selected together with `out_val`. * @param select_min * whether to select k smallest (true) or largest (false) keys. - * @param stream - * @param mr an optional memory resource to use across the calls (you can provide a large enough - * memory pool here to avoid memory allocations within the call). + * @param[in] sorted + * whether to make sure selected pairs are sorted by value + * @param[in] algo + * the selection algorithm to use */ template void select_k(raft::resources const& handle, @@ -236,24 +240,21 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - rmm::mr::device_memory_resource* mr = nullptr, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); - if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - if (algo == SelectAlgo::kAuto) { algo = choose_select_k_algorithm(batch_size, len, k); } - auto stream = raft::resource::get_cuda_stream(handle); switch (algo) { case SelectAlgo::kRadix8bits: case SelectAlgo::kRadix11bits: case SelectAlgo::kRadix11bitsExtraPass: { if (algo == SelectAlgo::kRadix8bits) { - detail::select::radix::select_k(in_val, + detail::select::radix::select_k(handle, + in_val, in_idx, batch_size, len, @@ -261,13 +262,13 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - true, // fused_last_filter - stream, - mr); + true // fused_last_filter + ); } else { bool fused_last_filter = algo == SelectAlgo::kRadix11bits; - detail::select::radix::select_k(in_val, + detail::select::radix::select_k(handle, + in_val, in_idx, batch_size, len, @@ -275,20 +276,12 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - fused_last_filter, - stream, - mr); + fused_last_filter); } if (sorted) { - auto offsets = raft::make_device_vector(handle, (IdxT)(batch_size + 1)); - - raft::matrix::fill(handle, offsets.view(), (IdxT)k); - - thrust::exclusive_scan(raft::resource::get_thrust_policy(handle), - offsets.data_handle(), - offsets.data_handle() + offsets.size(), - offsets.data_handle(), - 0); + auto offsets = make_device_mdarray( + handle, resource::get_workspace_resource(handle), make_extents(batch_size + 1)); + raft::linalg::map_offset(handle, offsets.view(), mul_const_op(k)); auto keys = raft::make_device_vector_view(out_val, (IdxT)(batch_size * k)); auto vals = raft::make_device_vector_view(out_idx, (IdxT)(batch_size * k)); @@ -301,22 +294,22 @@ void select_k(raft::resources const& handle, case SelectAlgo::kWarpDistributed: return detail::select::warpsort:: select_k_impl( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); case SelectAlgo::kWarpDistributedShm: return detail::select::warpsort:: select_k_impl( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); case SelectAlgo::kWarpAuto: return detail::select::warpsort::select_k( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); case SelectAlgo::kWarpImmediate: return detail::select::warpsort:: select_k_impl( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); case SelectAlgo::kWarpFiltered: return detail::select::warpsort:: select_k_impl( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + handle, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min); default: RAFT_FAIL("K-selection Algorithm not supported."); } } diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index b6ed03b93d..16b9ac0c6d 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -19,6 +19,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -1157,6 +1160,7 @@ void radix_topk_one_block(const T* in, * @tparam BlockSize * Number of threads in a kernel thread block. * + * @param[in] res container of reusable resources * @param[in] in * contiguous device array of inputs of size (len * batch_size); * these are compared and selected. @@ -1184,12 +1188,10 @@ void radix_topk_one_block(const T* in, * blocks is called. The later case is preferable when leading bits of input data are almost the * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. - * @param stream - * @param mr an optional memory resource to use across the calls (you can provide a large enough - * memory pool here to avoid memory allocations within the call). */ template -void select_k(const T* in, +void select_k(raft::resources const& res, + const T* in, const IdxT* in_idx, int batch_size, IdxT len, @@ -1197,10 +1199,10 @@ void select_k(const T* in, T* out, IdxT* out_idx, bool select_min, - bool fused_last_filter, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) + bool fused_last_filter) { + auto stream = resource::get_cuda_stream(res); + auto mr = resource::get_workspace_resource(res); if (k == len) { RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); @@ -1210,21 +1212,12 @@ void select_k(const T* in, } else { auto out_idx_view = raft::make_device_vector_view(out_idx, static_cast(len) * batch_size); - raft::resources handle; - resource::set_cuda_stream(handle, stream); - raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op(len)); + raft::linalg::map_offset(res, out_idx_view, raft::mod_const_op(len)); } return; } - // TODO: use device_resources::get_device_properties() instead; should change it when we refactor - // resource management - int sm_cnt; - { - int dev; - RAFT_CUDA_TRY(cudaGetDevice(&dev)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); - } + int sm_cnt = resource::get_device_properties(res).multiProcessorCount; constexpr int items_per_thread = 32; diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index 018eea2306..7cd43b030b 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,11 @@ #include #include +#include +#include +#include #include +#include #include #include #include @@ -773,6 +777,11 @@ __launch_bounds__(256) RAFT_KERNEL queue.store(out + block_id * k, out_idx + block_id * k); } +struct launch_params { + int block_size = 0; + int min_grid_size = 0; +}; + template