From 0de9ece90043de236e0718444cd9ac4ad6c883a6 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Wed, 19 Oct 2022 20:03:24 +0200 Subject: [PATCH] Integrate `accumulate_into_selected` from ANN utils into `linalg::reduce_rows_by_keys` (#909) `accumulate_into_selected` achieves much better performance than the previous implementation of `reduce_rows_by_keys` for large `nkeys` (`sum_rows_by_key_large_nkeys_kernel_rowmajor`). According to the benchmark that I added for this primitive, the difference is a factor of 240x for sizes relevant to IVF-Flat (and a factor of ~10x for smaller `nkeys`, e.g 64). This is mostly because the legacy implementation, probably in an attempt to reduce atomic conflicts, assigned a key and a tile of the matrix to each block, and the block only reduces the rows corresponding to the assigned key. With a very large number of keys, e.g 1k, this results in blocks iterating over a large number of rows (possibly tens of thousands) and only reading and accumulating 1 in 1k rows. This PR: - Replaces `sum_rows_by_key_large_nkeys_rowmajor` with `accumulate_into_selected` (I didn't find any cases in which the old kernel performed better). - Removes `accumulate_into_selected` from `ann_utils.cuh`. - Fixes support for custom iterators in `reduce_rows_by_keys`. - Uses the raft prims in `calc_centers_and_sizes`. Perf notes: - The original kmeans gets a 15-20% speedup for large numbers of clusters. - The performance of `ivf_flat::build` stays the same as before. - There are a bunch of extra steps since I separated the cluster size count from the reduction by key, but they are quite neglectable in comparison. Question: the change breaks support for host-side-only arrays in `calc_centers_and_sizes`, is it actually a possibility? Should I add a branch and not use the raft prims when all arrays are host-side? cc @achirkin @tfeher @cjnolet Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/909 --- cpp/bench/CMakeLists.txt | 1 + cpp/bench/common/benchmark.hpp | 3 +- cpp/bench/linalg/reduce_rows_by_key.cu | 88 ++++++ .../raft/linalg/detail/reduce_rows_by_key.cuh | 280 ++++++++---------- .../raft/linalg/reduce_rows_by_key.cuh | 65 ++-- .../knn/detail/ann_kmeans_balanced.cuh | 114 +++++-- .../raft/spatial/knn/detail/ann_utils.cuh | 70 +---- .../spatial/knn/detail/ivf_flat_build.cuh | 3 +- cpp/test/linalg/reduce_rows_by_key.cu | 2 +- 9 files changed, 350 insertions(+), 276 deletions(-) create mode 100644 cpp/bench/linalg/reduce_rows_by_key.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 9c6b60d83b..e0f42d1803 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -95,6 +95,7 @@ if(BUILD_BENCH) bench/linalg/add.cu bench/linalg/map_then_reduce.cu bench/linalg/matrix_vector_op.cu + bench/linalg/reduce_rows_by_key.cu bench/linalg/reduce.cu bench/main.cpp ) diff --git a/cpp/bench/common/benchmark.hpp b/cpp/bench/common/benchmark.hpp index adfe5218e2..13ca40a033 100644 --- a/cpp/bench/common/benchmark.hpp +++ b/cpp/bench/common/benchmark.hpp @@ -53,8 +53,9 @@ struct using_pool_memory_res { rmm::mr::set_current_device_resource(&pool_res_); } - using_pool_memory_res() : using_pool_memory_res(size_t(1) << size_t(30), size_t(16) << size_t(30)) + using_pool_memory_res() : orig_res_(rmm::mr::get_current_device_resource()), pool_res_(&cuda_res_) { + rmm::mr::set_current_device_resource(&pool_res_); } ~using_pool_memory_res() { rmm::mr::set_current_device_resource(orig_res_); } diff --git a/cpp/bench/linalg/reduce_rows_by_key.cu b/cpp/bench/linalg/reduce_rows_by_key.cu new file mode 100644 index 0000000000..075bc7c8c4 --- /dev/null +++ b/cpp/bench/linalg/reduce_rows_by_key.cu @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include + +namespace raft::bench::linalg { + +struct rrbk_params { + int64_t rows, cols; + int64_t keys; +}; + +template +struct reduce_rows_by_key : public fixture { + reduce_rows_by_key(const rrbk_params& p) + : params(p), + in(p.rows * p.cols, stream), + out(p.keys * p.cols, stream), + keys(p.rows, stream), + workspace(p.rows, stream) + { + raft::random::RngState rng{42}; + raft::random::uniformInt(rng, keys.data(), p.rows, (KeyT)0, (KeyT)p.keys, stream); + } + + void run_benchmark(::benchmark::State& state) override + { + loop_on_state(state, [this]() { + raft::linalg::reduce_rows_by_key(in.data(), + params.cols, + keys.data(), + workspace.data(), + params.rows, + params.cols, + params.keys, + out.data(), + stream, + false); + }); + } + + protected: + rrbk_params params; + rmm::device_uvector in, out; + rmm::device_uvector keys; + rmm::device_uvector workspace; +}; // struct reduce_rows_by_key + +const std::vector kInputSizes{ + {10000, 128, 64}, + {100000, 128, 64}, + {1000000, 128, 64}, + {10000000, 128, 64}, + {10000, 128, 256}, + {100000, 128, 256}, + {1000000, 128, 256}, + {10000000, 128, 256}, + {10000, 128, 1024}, + {100000, 128, 1024}, + {1000000, 128, 1024}, + {10000000, 128, 1024}, + {10000, 128, 4096}, + {100000, 128, 4096}, + {1000000, 128, 4096}, + {10000000, 128, 4096}, +}; + +RAFT_BENCH_REGISTER((reduce_rows_by_key), "", kInputSizes); +RAFT_BENCH_REGISTER((reduce_rows_by_key), "", kInputSizes); + +} // namespace raft::bench::linalg diff --git a/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh index 9ddcbae20b..dc92271141 100644 --- a/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh @@ -92,48 +92,48 @@ struct quadSum { #define SUM_ROWS_SMALL_K_DIMX 256 #define SUM_ROWS_BY_KEY_SMALL_K_MAX_K 4 -template +template __launch_bounds__(SUM_ROWS_SMALL_K_DIMX, 4) - __global__ void sum_rows_by_key_small_nkeys_kernel(const DataIteratorT* d_A, - int lda, + __global__ void sum_rows_by_key_small_nkeys_kernel(const DataIteratorT d_A, + IdxT lda, const char* d_keys, const WeightT* d_weights, - int nrows, - int ncols, - int nkeys, - DataIteratorT* d_sums) + IdxT nrows, + IdxT ncols, + IdxT nkeys, + SumsT* d_sums) { - typedef typename std::iterator_traits::value_type DataType; - typedef cub::BlockReduce, SUM_ROWS_SMALL_K_DIMX> BlockReduce; + typedef typename std::iterator_traits::value_type DataType; + typedef cub::BlockReduce, SUM_ROWS_SMALL_K_DIMX> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - for (int idim = static_cast(blockIdx.y); idim < ncols; idim += gridDim.y) { - if (idim != static_cast(blockIdx.y)) __syncthreads(); // we're reusing temp_storage + for (IdxT idim = static_cast(blockIdx.y); idim < ncols; idim += gridDim.y) { + if (idim != static_cast(blockIdx.y)) __syncthreads(); // we're reusing temp_storage // threadIdx.x stores partial sum for current dim and key=threadIdx.x in this reg - quad thread_sums; + quad thread_sums; thread_sums.x = 0.0; thread_sums.y = 0.0; thread_sums.z = 0.0; thread_sums.w = 0.0; // May use vectorized load - not necessary for doubles - for (int block_offset_irow = blockIdx.x * blockDim.x; + for (IdxT block_offset_irow = blockIdx.x * blockDim.x; block_offset_irow < nrows; // we will syncthreads() inside the loop, no CTA divergence block_offset_irow += blockDim.x * gridDim.x) { - int irow = block_offset_irow + threadIdx.x; + IdxT irow = block_offset_irow + threadIdx.x; DataType val = (irow < nrows) ? d_A[irow * lda + idim] : 0.0; if (d_weights && irow < nrows) { val = val * d_weights[irow]; } // we are not reusing the keys - after profiling // d_keys is mainly loaded from L2, and this kernel is DRAM BW bounded // (experimentation gave a 10% speed up - not worth the many code lines added) - int row_key = (irow < nrows) ? d_keys[irow] : -1; + IdxT row_key = (irow < nrows) ? d_keys[irow] : std::numeric_limits::max(); - thread_sums.x += (row_key == 0) ? val : 0.0; - thread_sums.y += (row_key == 1) ? val : 0.0; - thread_sums.z += (row_key == 2) ? val : 0.0; - thread_sums.w += (row_key == 3) ? val : 0.0; + thread_sums.x += (row_key == 0) ? static_cast(val) : 0.0; + thread_sums.y += (row_key == 1) ? static_cast(val) : 0.0; + thread_sums.z += (row_key == 2) ? static_cast(val) : 0.0; + thread_sums.w += (row_key == 3) ? static_cast(val) : 0.0; } // End of column @@ -142,12 +142,12 @@ __launch_bounds__(SUM_ROWS_SMALL_K_DIMX, 4) // Strided access // Reducing by key - thread_sums = BlockReduce(temp_storage).Reduce(thread_sums, quadSum()); + thread_sums = BlockReduce(temp_storage).Reduce(thread_sums, quadSum()); if (threadIdx.x < 32) { // We only need 4 thread_sums = cub::ShuffleIndex<32>(thread_sums, 0, 0xffffffff); - if (static_cast(threadIdx.x) < nkeys) { + if (static_cast(threadIdx.x) < nkeys) { if (threadIdx.x == 0) raft::myAtomicAdd(&d_sums[threadIdx.x * ncols + idim], thread_sums.x); if (threadIdx.x == 1) raft::myAtomicAdd(&d_sums[threadIdx.x * ncols + idim], thread_sums.y); if (threadIdx.x == 2) raft::myAtomicAdd(&d_sums[threadIdx.x * ncols + idim], thread_sums.z); @@ -157,22 +157,22 @@ __launch_bounds__(SUM_ROWS_SMALL_K_DIMX, 4) } } -template -void sum_rows_by_key_small_nkeys(const DataIteratorT* d_A, - int lda, +template +void sum_rows_by_key_small_nkeys(const DataIteratorT d_A, + IdxT lda, const char* d_keys, const WeightT* d_weights, - int nrows, - int ncols, - int nkeys, - DataIteratorT* d_sums, + IdxT nrows, + IdxT ncols, + IdxT nkeys, + SumsT* d_sums, cudaStream_t st) { dim3 grid, block; block.x = SUM_ROWS_SMALL_K_DIMX; block.y = 1; // Necessary - grid.x = raft::ceildiv(nrows, (int)block.x); + grid.x = raft::ceildiv(nrows, (IdxT)block.x); grid.x = std::min(grid.x, 32u); grid.y = ncols; grid.y = std::min(grid.y, MAX_BLOCKS); @@ -188,45 +188,49 @@ void sum_rows_by_key_small_nkeys(const DataIteratorT* d_A, #define SUM_ROWS_BY_KEY_LARGE_K_MAX_K 1024 -template -__global__ void sum_rows_by_key_large_nkeys_kernel_colmajor(const DataIteratorT* d_A, - int lda, +template +__global__ void sum_rows_by_key_large_nkeys_kernel_colmajor(const DataIteratorT d_A, + IdxT lda, KeysIteratorT d_keys, const WeightT* d_weights, - int nrows, - int ncols, + IdxT nrows, + IdxT ncols, int key_offset, - int nkeys, - DataIteratorT* d_sums) + IdxT nkeys, + SumsT* d_sums) { typedef typename std::iterator_traits::value_type KeyType; - typedef typename std::iterator_traits::value_type DataType; - __shared__ DataType local_sums[SUM_ROWS_BY_KEY_LARGE_K_MAX_K]; + typedef typename std::iterator_traits::value_type DataType; + __shared__ SumsT local_sums[SUM_ROWS_BY_KEY_LARGE_K_MAX_K]; - for (int local_key = threadIdx.x; local_key < nkeys; local_key += blockDim.x) + for (IdxT local_key = threadIdx.x; local_key < nkeys; local_key += blockDim.x) local_sums[local_key] = 0.0; - for (int idim = blockIdx.y; idim < ncols; idim += gridDim.y) { + for (IdxT idim = blockIdx.y; idim < ncols; idim += gridDim.y) { __syncthreads(); // local_sums // At this point local_sums if full of zeros - for (int irow = blockIdx.x * blockDim.x + threadIdx.x; irow < nrows; + for (IdxT irow = blockIdx.x * blockDim.x + threadIdx.x; irow < nrows; irow += blockDim.x * gridDim.x) { // Branch div in this loop - not an issue with current code DataType val = d_A[idim * lda + irow]; if (d_weights) val = val * d_weights[irow]; - int local_key = d_keys[irow] - key_offset; + IdxT local_key = d_keys[irow] - key_offset; // We could load next val here - raft::myAtomicAdd(&local_sums[local_key], val); + raft::myAtomicAdd(&local_sums[local_key], static_cast(val)); } __syncthreads(); // local_sums - for (int local_key = threadIdx.x; local_key < nkeys; local_key += blockDim.x) { - DataType local_sum = local_sums[local_key]; + for (IdxT local_key = threadIdx.x; local_key < nkeys; local_key += blockDim.x) { + SumsT local_sum = local_sums[local_key]; if (local_sum != 0.0) { KeyType global_key = key_offset + local_key; @@ -237,22 +241,22 @@ __global__ void sum_rows_by_key_large_nkeys_kernel_colmajor(const DataIteratorT* } } -template -void sum_rows_by_key_large_nkeys_colmajor(const DataIteratorT* d_A, - int lda, +template +void sum_rows_by_key_large_nkeys_colmajor(const DataIteratorT d_A, + IdxT lda, KeysIteratorT d_keys, - int nrows, - int ncols, + IdxT nrows, + IdxT ncols, int key_offset, - int nkeys, - DataIteratorT* d_sums, + IdxT nkeys, + SumsT* d_sums, cudaStream_t st) { dim3 grid, block; block.x = SUM_ROWS_SMALL_K_DIMX; block.y = 1; // Necessary - grid.x = raft::ceildiv(nrows, (int)block.x); + grid.x = raft::ceildiv(nrows, (IdxT)block.x); grid.x = std::min(grid.x, 32u); grid.y = ncols; grid.y = std::min(grid.y, MAX_BLOCKS); @@ -260,91 +264,47 @@ void sum_rows_by_key_large_nkeys_colmajor(const DataIteratorT* d_A, d_A, lda, d_keys, nrows, ncols, key_offset, nkeys, d_sums); } -#define RRBK_SHMEM_SZ 32 - -//#define RRBK_SHMEM -template -__global__ void sum_rows_by_key_large_nkeys_kernel_rowmajor(const DataIteratorT* d_A, - int lda, +template +__global__ void sum_rows_by_key_large_nkeys_kernel_rowmajor(const DataIteratorT d_A, + IdxT lda, const WeightT* d_weights, KeysIteratorT d_keys, - int nrows, - int ncols, - int key_offset, - int nkeys, - DataIteratorT* d_sums) + IdxT nrows, + IdxT ncols, + SumsT* d_sums) { - typedef typename std::iterator_traits::value_type KeyType; - typedef typename std::iterator_traits::value_type DataType; - -#ifdef RRBK_SHMEM - __shared__ KeyType sh_keys[RRBK_SHMEM_SZ]; -#endif - int rows_per_partition = nrows / gridDim.z + 1; - int start_row = blockIdx.z * rows_per_partition; - int end_row = start_row + rows_per_partition; - end_row = end_row > nrows ? nrows : end_row; - - KeyType local_key = blockIdx.y; - if (local_key >= nkeys) return; - int this_col = threadIdx.x + blockIdx.x * blockDim.x; - if (this_col >= ncols) return; - - DataType sum = 0.0; - KeyType global_key = key_offset + local_key; -#ifdef RRBK_SHMEM - int sh_key_inx = 0; -#endif - for (int r = start_row; r < end_row; r++) { -#ifdef RRBK_SHMEM - if (0 == sh_key_inx % RRBK_SHMEM_SZ) { - for (int x = threadIdx.x; x < RRBK_SHMEM_SZ; x += blockDim.x) - sh_keys[x] = d_keys[r + x]; - __syncthreads(); - } - if (sh_keys[sh_key_inx] != global_key) continue; // No divergence since global_key is the - // same for the whole block - sh_key_inx++; -#else - if (d_keys[r] != global_key) - continue; // No divergence since global_key is the - // same for the whole block -#endif - // if ((end_row-start_row) / (r-start_row) != global_key) continue; - DataType val = __ldcg(&d_A[r * lda + this_col]); - if (d_weights) { val = val * d_weights[r]; } - sum += val; - } - - if (sum != 0.0) raft::myAtomicAdd(&d_sums[global_key * ncols + this_col], sum); + IdxT gid = threadIdx.x + (blockDim.x * static_cast(blockIdx.x)); + IdxT j = gid % ncols; + IdxT i = gid / ncols; + if (i >= nrows) return; + IdxT l = static_cast(d_keys[i]); + SumsT val = d_A[j + lda * i]; + if (d_weights != nullptr) val *= d_weights[i]; + raft::myAtomicAdd(&d_sums[j + ncols * l], val); } -template -void sum_rows_by_key_large_nkeys_rowmajor(const DataIteratorT* d_A, - int lda, - KeysIteratorT d_keys, +template +void sum_rows_by_key_large_nkeys_rowmajor(const DataIteratorT d_A, + IdxT lda, + const KeysIteratorT d_keys, const WeightT* d_weights, - int nrows, - int ncols, - int key_offset, - int nkeys, - DataIteratorT* d_sums, + IdxT nrows, + IdxT ncols, + SumsT* d_sums, cudaStream_t st) { - // x-dim refers to the column in the input data - // y-dim refers to the key - // z-dim refers to a partitioning of the rows among the threadblocks - dim3 grid, block; - block.x = 256; // Adjust me! - block.y = 1; // Don't adjust me! - grid.x = raft::ceildiv(ncols, (int)block.x); - grid.y = nkeys; - grid.z = std::max(40960000 / nkeys / ncols, (int)1); // Adjust me! - grid.z = std::min(grid.z, (unsigned int)nrows); - grid.z = std::min(grid.z, MAX_BLOCKS); - - sum_rows_by_key_large_nkeys_kernel_rowmajor<<>>( - d_A, lda, d_weights, d_keys, nrows, ncols, key_offset, nkeys, d_sums); + uint32_t block_dim = 128; + auto grid_dim = static_cast(ceildiv(nrows * ncols, (IdxT)block_dim)); + sum_rows_by_key_large_nkeys_kernel_rowmajor<<>>( + d_A, lda, d_weights, d_keys, nrows, ncols, d_sums); } /** @@ -354,6 +314,8 @@ void sum_rows_by_key_large_nkeys_rowmajor(const DataIteratorT* d_A, * (may be a simple pointer type) * @tparam KeysIteratorT Random-access iterator type, for reading input keys * (may be a simple pointer type) + * @tparam SumsT Type of the output sums + * @tparam IdxT Index type * * @param[in] d_A Input data array (lda x nrows) * @param[in] lda Real row size for input data, d_A @@ -365,26 +327,31 @@ void sum_rows_by_key_large_nkeys_rowmajor(const DataIteratorT* d_A, * @param[in] nkeys Number of unique keys in d_keys * @param[out] d_sums Row sums by key (ncols x d_keys) * @param[in] stream CUDA stream + * @param[in] reset_sums Whether to reset the output sums to zero before reducing */ -template -void reduce_rows_by_key(const DataIteratorT* d_A, - int lda, +template +void reduce_rows_by_key(const DataIteratorT d_A, + IdxT lda, KeysIteratorT d_keys, const WeightT* d_weights, char* d_keys_char, - int nrows, - int ncols, - int nkeys, - DataIteratorT* d_sums, - cudaStream_t stream) + IdxT nrows, + IdxT ncols, + IdxT nkeys, + SumsT* d_sums, + cudaStream_t stream, + bool reset_sums) { typedef typename std::iterator_traits::value_type KeyType; - typedef typename std::iterator_traits::value_type DataType; // Following kernel needs memset - cudaMemsetAsync(d_sums, 0, ncols * nkeys * sizeof(DataType), stream); + if (reset_sums) { cudaMemsetAsync(d_sums, 0, ncols * nkeys * sizeof(SumsT), stream); } - if (nkeys <= SUM_ROWS_BY_KEY_SMALL_K_MAX_K) { + if (d_keys_char != nullptr && nkeys <= SUM_ROWS_BY_KEY_SMALL_K_MAX_K) { // sum_rows_by_key_small_k is BW bounded. d_keys is loaded ncols time - avoiding wasting BW // with doubles we have ~20% speed up - with floats we can hope something around 2x // Converting d_keys to char @@ -392,12 +359,7 @@ void reduce_rows_by_key(const DataIteratorT* d_A, sum_rows_by_key_small_nkeys( d_A, lda, d_keys_char, d_weights, nrows, ncols, nkeys, d_sums, stream); } else { - for (KeyType key_offset = 0; key_offset < static_cast(nkeys); - key_offset += SUM_ROWS_BY_KEY_LARGE_K_MAX_K) { - KeyType this_call_nkeys = std::min(SUM_ROWS_BY_KEY_LARGE_K_MAX_K, nkeys); - sum_rows_by_key_large_nkeys_rowmajor( - d_A, lda, d_keys, d_weights, nrows, ncols, key_offset, this_call_nkeys, d_sums, stream); - } + sum_rows_by_key_large_nkeys_rowmajor(d_A, lda, d_keys, d_weights, nrows, ncols, d_sums, stream); } } @@ -407,6 +369,8 @@ void reduce_rows_by_key(const DataIteratorT* d_A, * pointer type) * @tparam KeysIteratorT Random-access iterator type, for reading input keys (may be a simple * pointer type) + * @tparam SumsT Type of the output sums + * @tparam IdxT Index type * @param[in] d_A Input data array (lda x nrows) * @param[in] lda Real row size for input data, d_A * @param[in] d_keys Keys for each row (1 x nrows) @@ -417,18 +381,19 @@ void reduce_rows_by_key(const DataIteratorT* d_A, * @param[out] d_sums Row sums by key (ncols x d_keys) * @param[in] stream CUDA stream */ -template -void reduce_rows_by_key(const DataIteratorT* d_A, - int lda, +template +void reduce_rows_by_key(const DataIteratorT d_A, + IdxT lda, KeysIteratorT d_keys, char* d_keys_char, - int nrows, - int ncols, - int nkeys, - DataIteratorT* d_sums, - cudaStream_t stream) + IdxT nrows, + IdxT ncols, + IdxT nkeys, + SumsT* d_sums, + cudaStream_t stream, + bool reset_sums) { - typedef typename std::iterator_traits::value_type DataType; + typedef typename std::iterator_traits::value_type DataType; reduce_rows_by_key(d_A, lda, d_keys, @@ -438,7 +403,8 @@ void reduce_rows_by_key(const DataIteratorT* d_A, ncols, nkeys, d_sums, - stream); + stream, + reset_sums); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/reduce_rows_by_key.cuh index 39c54e8b0c..1dabd92087 100644 --- a/cpp/include/raft/linalg/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_rows_by_key.cuh @@ -43,6 +43,8 @@ void convert_array(IteratorT1 dst, IteratorT2 src, int n, cudaStream_t st) * (may be a simple pointer type) * @tparam KeysIteratorT Random-access iterator type, for reading input keys * (may be a simple pointer type) + * @tparam SumsT Type of the output sums + * @tparam IdxT Index type * * @param[in] d_A Input data array (lda x nrows) * @param[in] lda Real row size for input data, d_A @@ -54,21 +56,27 @@ void convert_array(IteratorT1 dst, IteratorT2 src, int n, cudaStream_t st) * @param[in] nkeys Number of unique keys in d_keys * @param[out] d_sums Row sums by key (ncols x d_keys) * @param[in] stream CUDA stream + * @param[in] reset_sums Whether to reset the output sums to zero before reducing */ -template -void reduce_rows_by_key(const DataIteratorT* d_A, - int lda, +template +void reduce_rows_by_key(const DataIteratorT d_A, + IdxT lda, const KeysIteratorT d_keys, const WeightT* d_weights, char* d_keys_char, - int nrows, - int ncols, - int nkeys, - DataIteratorT* d_sums, - cudaStream_t stream) + IdxT nrows, + IdxT ncols, + IdxT nkeys, + SumsT* d_sums, + cudaStream_t stream, + bool reset_sums = true) { detail::reduce_rows_by_key( - d_A, lda, d_keys, d_weights, d_keys_char, nrows, ncols, nkeys, d_sums, stream); + d_A, lda, d_keys, d_weights, d_keys_char, nrows, ncols, nkeys, d_sums, stream, reset_sums); } /** @@ -77,6 +85,8 @@ void reduce_rows_by_key(const DataIteratorT* d_A, * pointer type) * @tparam KeysIteratorT Random-access iterator type, for reading input keys (may be a simple * pointer type) + * @tparam SumsT Type of the output sums + * @tparam IdxT Index type * @param[in] d_A Input data array (lda x nrows) * @param[in] lda Real row size for input data, d_A * @param[in] d_keys Keys for each row (1 x nrows) @@ -86,19 +96,21 @@ void reduce_rows_by_key(const DataIteratorT* d_A, * @param[in] nkeys Number of unique keys in d_keys * @param[out] d_sums Row sums by key (ncols x d_keys) * @param[in] stream CUDA stream + * @param[in] reset_sums Whether to reset the output sums to zero before reducing */ -template -void reduce_rows_by_key(const DataIteratorT* d_A, - int lda, - KeysIteratorT d_keys, +template +void reduce_rows_by_key(const DataIteratorT d_A, + IdxT lda, + const KeysIteratorT d_keys, char* d_keys_char, - int nrows, - int ncols, - int nkeys, - DataIteratorT* d_sums, - cudaStream_t stream) + IdxT nrows, + IdxT ncols, + IdxT nkeys, + SumsT* d_sums, + cudaStream_t stream, + bool reset_sums = true) { - typedef typename std::iterator_traits::value_type DataType; + typedef typename std::iterator_traits::value_type DataType; reduce_rows_by_key(d_A, lda, d_keys, @@ -108,7 +120,8 @@ void reduce_rows_by_key(const DataIteratorT* d_A, ncols, nkeys, d_sums, - stream); + stream, + reset_sums); } /** @@ -128,9 +141,10 @@ void reduce_rows_by_key(const DataIteratorT* d_A, * @param[in] d_keys Keys for each row raft::device_vector_view (1 x nrows) * @param[out] d_sums Row sums by key raft::device_matrix_view (ncols x d_keys) * @param[in] n_unique_keys Number of unique keys in d_keys + * @param[out] d_keys_char Scratch memory for conversion of keys to char, raft::device_vector_view * @param[in] d_weights Weights for each observation in d_A raft::device_vector_view optional (1 * x nrows) - * @param[out] d_keys_char Scratch memory for conversion of keys to char, raft::device_vector_view + * @param[in] reset_sums Whether to reset the output sums to zero before reducing */ template void reduce_rows_by_key( @@ -140,7 +154,8 @@ void reduce_rows_by_key( raft::device_matrix_view d_sums, IndexType n_unique_keys, raft::device_vector_view d_keys_char, - std::optional> d_weights = std::nullopt) + std::optional> d_weights = std::nullopt, + bool reset_sums = true) { RAFT_EXPECTS(d_A.extent(0) == d_A.extent(0) && d_sums.extent(1) == n_unique_keys, "Output is not of size ncols * n_unique_keys"); @@ -158,7 +173,8 @@ void reduce_rows_by_key( d_A.extent(0), n_unique_keys, d_sums.data_handle(), - handle.get_stream()); + handle.get_stream(), + reset_sums); } else { reduce_rows_by_key(d_A.data_handle(), d_A.extent(0), @@ -168,7 +184,8 @@ void reduce_rows_by_key( d_A.extent(0), n_unique_keys, d_sums.data_handle(), - handle.get_stream()); + handle.get_stream(), + reset_sums); } } diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index bf0df065b2..ff4708bb7b 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -21,13 +21,16 @@ #include #include +#include #include #include #include #include #include #include +#include #include +#include #include #include #include @@ -207,14 +210,7 @@ constexpr inline auto calc_minibatch_size(uint32_t n_clusters, * multiple times with different datasets with the same effect as if calling this function once * on the combined dataset_. * - * NB: `centers` and `cluster_sizes` must be accessible on GPU due to - * divide_along_rows/normalize_rows. The rest can be both, under assumption that all pointers are - * accessible from the same place. - * - * i.e. two variants are possible: - * - * 1. All pointers are on the device. - * 2. All pointers are on the host, but `centers` and `cluster_sizes` are accessible from GPU. + * NB: all pointers must be accessible on the device. * * @tparam T element type * @tparam IdxT index type @@ -231,9 +227,11 @@ constexpr inline auto calc_minibatch_size(uint32_t n_clusters, * When set to `false`, this function may be used to update existing centers and sizes using * the weighted average principle. * @param stream + * @param mr (optional) memory resource to use for temporary allocations on the device */ template -void calc_centers_and_sizes(float* centers, +void calc_centers_and_sizes(const handle_t& handle, + float* centers, uint32_t* cluster_sizes, uint32_t n_clusters, uint32_t dim, @@ -241,12 +239,12 @@ void calc_centers_and_sizes(float* centers, IdxT n_rows, const LabelT* labels, bool reset_counters, - rmm::cuda_stream_view stream) + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { - if (reset_counters) { - utils::memzero(centers, n_clusters * dim, stream); - utils::memzero(cluster_sizes, n_clusters, stream); - } else { + if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } + + if (!reset_counters) { utils::map_along_rows( n_clusters, dim, @@ -255,13 +253,70 @@ void calc_centers_and_sizes(float* centers, [] __device__(float c, uint32_t s) -> float { return c * s; }, stream); } - utils::accumulate_into_selected(n_rows, dim, centers, cluster_sizes, dataset, labels, stream); - utils::map_along_rows( - n_clusters, - dim, + + rmm::device_uvector workspace(0, stream, mr); + rmm::device_uvector cluster_sizes_f(n_clusters, stream, mr); + float* sizes_f = cluster_sizes_f.data(); + + // If we reset the counters, we can compute directly the new sizes in cluster_sizes. + // If we don't reset, we compute in a temporary buffer and add in a separate step. + rmm::device_uvector temp_cluster_sizes(0, stream, mr); + uint32_t* temp_sizes = cluster_sizes; + if (!reset_counters) { + temp_cluster_sizes.resize(n_clusters, stream); + temp_sizes = temp_cluster_sizes.data(); + } + + utils::mapping mapping_op; + cub::TransformInputIterator, const T*> mapping_itr(dataset, + mapping_op); + + // todo(lsugy): use iterator from KV output of fusedL2NN + raft::linalg::reduce_rows_by_key(mapping_itr, + static_cast(dim), + labels, + nullptr, + static_cast(n_rows), + static_cast(dim), + static_cast(n_clusters), + centers, + stream, + reset_counters); + + // Compute weight of each cluster + raft::cluster::detail::countLabels(handle, + labels, + temp_sizes, + static_cast(n_rows), + static_cast(n_clusters), + workspace); + + // Add previous sizes if necessary and cast to float + auto counting = thrust::make_counting_iterator(0); + thrust::for_each( + handle.get_thrust_policy(), counting, counting + n_clusters, [=] __device__(int idx) { + uint32_t temp_size = temp_sizes[idx]; + if (!reset_counters) { + temp_size += cluster_sizes[idx]; + cluster_sizes[idx] = temp_size; + } + sizes_f[idx] = static_cast(temp_size); + }); + + raft::linalg::matrixVectorOp( + centers, centers, - cluster_sizes, - [] __device__(float c, uint32_t s) -> float { return s == 0 ? 0.0f : c / float(s); }, + sizes_f, + static_cast(dim), + static_cast(n_clusters), + true, + false, + [=] __device__(float mat, float vec) { + if (vec == 0.0f) + return 0.0f; + else + return mat / vec; + }, stream); } @@ -627,7 +682,8 @@ void balancing_em_iters(const handle_t& handle, device_memory, dataset_norm); // M: Maximization step - calculate optimal cluster centers - calc_centers_and_sizes(cluster_centers, + calc_centers_and_sizes(handle, + cluster_centers, cluster_sizes, n_clusters, dim, @@ -635,7 +691,8 @@ void balancing_em_iters(const handle_t& handle, n_rows, cluster_labels, true, - stream); + stream, + device_memory); } } @@ -666,8 +723,17 @@ void build_clusters(const handle_t& handle, linalg::writeOnlyUnaryOp(cluster_labels, n_rows, f, stream); // update centers to match the initialized labels. - calc_centers_and_sizes( - cluster_centers, cluster_sizes, n_clusters, dim, dataset, n_rows, cluster_labels, true, stream); + calc_centers_and_sizes(handle, + cluster_centers, + cluster_sizes, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + true, + stream, + device_memory); // run EM balancing_em_iters(handle, diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 8dda574314..dbd509216b 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -112,13 +112,13 @@ struct mapping { * @{ */ template - HDI auto operator()(const S& x) -> std::enable_if_t, T> + HDI auto operator()(const S& x) const -> std::enable_if_t, T> { return x; }; template - HDI auto operator()(const S& x) -> std::enable_if_t, T> + HDI auto operator()(const S& x) const -> std::enable_if_t, T> { constexpr double kMult = config::kDivisor / config::kDivisor; if constexpr (std::is_floating_point_v) { return static_cast(x * static_cast(kMult)); } @@ -259,72 +259,6 @@ inline void dots_along_rows( */ } -template -__global__ void accumulate_into_selected_kernel(IdxT n_rows, - uint32_t n_cols, - float* output, - uint32_t* selection_counters, - const T* input, - const LabelT* row_ids) -{ - IdxT gid = threadIdx.x + (blockDim.x * static_cast(blockIdx.x)); - IdxT j = gid % n_cols; - IdxT i = gid / n_cols; - if (i >= n_rows) return; - IdxT l = static_cast(row_ids[i]); - if (j == 0) { atomicAdd(&(selection_counters[l]), 1); } - atomicAdd(&(output[j + n_cols * l]), mapping{}(input[gid])); -} - -/** - * @brief Add all rows of input matrix into a selection of rows in the output matrix - * (cast and possibly scale the data input type). Count the number of times every output - * row was selected along the way. - * - * @tparam T element type - * @tparam IdxT index type - * @tparam LabelT label type - * - * @param n_cols number of columns in all matrices - * @param[out] output output matrix [..., n_cols] - * @param[inout] selection_counters number of occurrences of each row id in row_ids [..., n_cols] - * @param n_rows number of rows in the input - * @param[in] input row-major input matrix [n_rows, n_cols] - * @param[in] row_ids row indices in the output matrix [n_rows] - */ -template -void accumulate_into_selected(IdxT n_rows, - uint32_t n_cols, - float* output, - uint32_t* selection_counters, - const T* input, - const LabelT* row_ids, - rmm::cuda_stream_view stream) -{ - switch (check_pointer_residency(output, input, selection_counters, row_ids)) { - case pointer_residency::host_and_device: - case pointer_residency::device_only: { - uint32_t block_dim = 128; - auto grid_dim = - static_cast(ceildiv(n_rows * static_cast(n_cols), block_dim)); - accumulate_into_selected_kernel<<>>( - n_rows, n_cols, output, selection_counters, input, row_ids); - } break; - case pointer_residency::host_only: { - stream.synchronize(); - for (IdxT i = 0; i < n_rows; i++) { - IdxT l = static_cast(row_ids[i]); - selection_counters[l]++; - for (IdxT j = 0; j < n_cols; j++) { - output[j + n_cols * l] += mapping{}(input[j + n_cols * i]); - } - } - stream.synchronize(); - } break; - default: RAFT_FAIL("All pointers must reside on the same side, host or device."); - } -} - template __global__ void normalize_rows_kernel(IdxT n_rows, IdxT n_cols, float* a) { diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index af1cb97d36..14f5ae4516 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -144,7 +144,8 @@ inline auto extend(const handle_t& handle, raft::copy( list_sizes_ptr, orig_index.list_sizes().data_handle(), ext_index.list_sizes().size(), stream); - kmeans::calc_centers_and_sizes(centers_ptr, + kmeans::calc_centers_and_sizes(handle, + centers_ptr, list_sizes_ptr, n_lists, dim, diff --git a/cpp/test/linalg/reduce_rows_by_key.cu b/cpp/test/linalg/reduce_rows_by_key.cu index e575f37dd6..7b124cb7bb 100644 --- a/cpp/test/linalg/reduce_rows_by_key.cu +++ b/cpp/test/linalg/reduce_rows_by_key.cu @@ -103,7 +103,7 @@ class ReduceRowTest : public ::testing::TestWithParam> { raft::random::RngState r(params.seed); raft::random::RngState r_int(params.seed); - int nobs = params.nobs; + uint32_t nobs = params.nobs; uint32_t cols = params.cols; uint32_t nkeys = params.nkeys; uniform(handle, r, in.data(), nobs * cols, T(0.0), T(2.0 / nobs));