From 836bb58eab22420c4b75c5db194743a5f8d3075a Mon Sep 17 00:00:00 2001
From: Louis Sugy
Date: Wed, 9 Nov 2022 12:47:40 +0100
Subject: [PATCH] Create cub-based argmin primitive and replace
 `argmin_along_rows` in ANN kmeans (#912)

This PR follows up on [a suggestion](https://github.com/rapidsai/raft/pull/821#discussion_r984815136) from @cjnolet. The new `argmin` primitive is up to 5x faster than `argmin_along_rows` for the dimensions relevant to ANN kmeans, and it removes code duplication.

It is faster for the following reasons:

- `argmin_along_rows` often skips the sequential reduction that should precede the tree reduction, especially since it uses block sizes as large as 1024 threads.
- CUB has a better reduction algorithm than the basic shared-memory reduction used in `argmin_along_rows`.
- Switching the `argmin` prim to the `cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY` algorithm gives up to 30% further speedup. (I believe it's safe to use the commutative algorithm here because the offset is carried in the key-value pair, which makes the reduction operation commutative.)

The speedup I measured for the IVF-Flat build with the `InnerProduct` metric is around 15%.

(A short usage sketch of the new public API follows the patch below.)

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: https://github.com/rapidsai/raft/pull/912
---
 build.sh                                      |   2 +-
 cpp/bench/CMakeLists.txt                      |   6 +
 cpp/bench/matrix/argmin.cu                    |  71 ++++++++++++
 cpp/include/raft/matrix/argmax.cuh            |   6 +-
 cpp/include/raft/matrix/argmin.cuh            |  40 +++++++
 cpp/include/raft/matrix/detail/math.cuh       |  66 +++++++----
 cpp/include/raft/matrix/math.cuh              |  20 +++-
 .../knn/detail/ann_kmeans_balanced.cuh        |   8 +-
 .../raft/spatial/knn/detail/ann_utils.cuh     |  60 ----------
 cpp/test/CMakeLists.txt                       |   1 +
 cpp/test/matrix/argmax.cu                     |   2 +-
 cpp/test/matrix/argmin.cu                     | 109 ++++++++++++++++++
 12 files changed, 300 insertions(+), 91 deletions(-)
 create mode 100644 cpp/bench/matrix/argmin.cu
 create mode 100644 cpp/include/raft/matrix/argmin.cuh
 create mode 100644 cpp/test/matrix/argmin.cu

diff --git a/build.sh b/build.sh
index 61e6d1a007..b48465922a 100755
--- a/build.sh
+++ b/build.sh
@@ -73,7 +73,7 @@ COMPILE_DIST_LIBRARY=OFF
 ENABLE_NN_DEPENDENCIES=OFF
 TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
-BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;SPARSE_BENCH;RANDOM_BENCH"
+BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"
 ENABLE_thrust_DEPENDENCY=ON
 CACHE_ARGS=""
 
diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt
index ef91fe4e6c..81e894fbbc 100644
--- a/cpp/bench/CMakeLists.txt
+++ b/cpp/bench/CMakeLists.txt
@@ -101,6 +101,12 @@ if(BUILD_BENCH)
     bench/main.cpp
   )
 
+  ConfigureBench(NAME MATRIX_BENCH
+    PATH
+    bench/matrix/argmin.cu
+    bench/main.cpp
+  )
+
   ConfigureBench(NAME RANDOM_BENCH
     PATH
     bench/random/make_blobs.cu
diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu
new file mode 100644
index 0000000000..0d0dea0fdb
--- /dev/null
+++ b/cpp/bench/matrix/argmin.cu
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <common/benchmark.hpp>
+#include <raft/matrix/argmin.cuh>
+#include <raft/random/rng.cuh>
+
+#include
+
+namespace raft::bench::linalg {
+
+template <typename IdxT>
+struct ArgminParams {
+  IdxT rows, cols;
+};
+
+template <typename T, typename IdxT>
+struct Argmin : public fixture {
+  Argmin(const ArgminParams<IdxT>& p) : params(p) {}
+
+  void allocate_data(const ::benchmark::State& state) override
+  {
+    matrix  = raft::make_device_matrix<T, IdxT>(handle, params.rows, params.cols);
+    indices = raft::make_device_vector<IdxT, IdxT>(handle, params.rows);
+
+    raft::random::RngState rng{1234};
+    raft::random::uniform(
+      rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
+    handle.sync_stream(stream);
+  }
+
+  void run_benchmark(::benchmark::State& state) override
+  {
+    loop_on_state(state, [this]() {
+      auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT>(
+        matrix.data_handle(), matrix.extent(0), matrix.extent(1));
+      raft::matrix::argmin(handle, matrix_const_view, indices.view());
+    });
+  }
+
+ private:
+  ArgminParams<IdxT> params;
+  raft::device_matrix<T, IdxT> matrix;
+  raft::device_vector<IdxT, IdxT> indices;
+};  // struct Argmin
+
+const std::vector<ArgminParams<int64_t>> argmin_inputs_i64{
+  {1000, 64},     {1000, 128},     {1000, 256},     {1000, 512},     {1000, 1024},
+  {10000, 64},    {10000, 128},    {10000, 256},    {10000, 512},    {10000, 1024},
+  {100000, 64},   {100000, 128},   {100000, 256},   {100000, 512},   {100000, 1024},
+  {1000000, 64},  {1000000, 128},  {1000000, 256},  {1000000, 512},  {1000000, 1024},
+  {10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024},
+};
+
+RAFT_BENCH_REGISTER((Argmin<float, int64_t>), "", argmin_inputs_i64);
+RAFT_BENCH_REGISTER((Argmin<double, int64_t>), "", argmin_inputs_i64);
+
+}  // namespace raft::bench::linalg
diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh
index b3face1012..e6736b14de 100644
--- a/cpp/include/raft/matrix/argmax.cuh
+++ b/cpp/include/raft/matrix/argmax.cuh
@@ -22,10 +22,10 @@
 namespace raft::matrix {
 
 /**
- * @brief Argmax: find the row idx with maximum value for each column
+ * @brief Argmax: find the col idx with maximum value for each row
  * @param[in] handle: raft handle
  * @param[in] in: input matrix of size (n_rows, n_cols)
- * @param[out] out: output vector of size n_cols
+ * @param[out] out: output vector of size n_rows
  */
 template <typename math_t, typename idx_t, typename matrix_idx_t>
 void argmax(const raft::handle_t& handle,
@@ -35,6 +35,6 @@ void argmax(const raft::handle_t& handle,
   RAFT_EXPECTS(out.extent(0) == in.extent(0),
                "Size of output vector must equal number of rows in input matrix.");
   detail::argmax(
-    in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream());
+    in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream());
 }
 }  // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/argmin.cuh b/cpp/include/raft/matrix/argmin.cuh
new file mode 100644
index 0000000000..e8cf763f70
--- /dev/null
+++ b/cpp/include/raft/matrix/argmin.cuh
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <raft/core/device_mdspan.hpp>
+#include <raft/matrix/detail/math.cuh>
+
+namespace raft::matrix {
+
+/**
+ * @brief Argmin: find the col idx with minimum value for each row
+ * @param[in] handle: raft handle
+ * @param[in] in: input matrix of size (n_rows, n_cols)
+ * @param[out] out: output vector of size n_rows
+ */
+template <typename math_t, typename idx_t, typename matrix_idx_t>
+void argmin(const raft::handle_t& handle,
+            raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in,
+            raft::device_vector_view<idx_t, matrix_idx_t> out)
+{
+  RAFT_EXPECTS(out.extent(0) == in.extent(0),
+               "Size of output vector must equal number of rows in input matrix.");
+  detail::argmin(
+    in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream());
+}
+}  // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh
index 07b9ccc12b..64c85a03a5 100644
--- a/cpp/include/raft/matrix/detail/math.cuh
+++ b/cpp/include/raft/matrix/detail/math.cuh
@@ -362,45 +362,69 @@ void matrixVectorBinarySub(Type* data,
     stream);
 }
 
-// Computes the argmax(d_in) column-wise in a DxN matrix
-template <typename T, typename IdxT, int TPB>
-__global__ void argmaxKernel(const T* d_in, int D, int N, IdxT* argmax)
+// Computes an argmin/argmax column-wise in a DxN matrix
+template <typename RedOp, int TPB, typename T, typename OutT, typename IdxT>
+__global__ void argReduceKernel(const T* d_in, IdxT D, IdxT N, OutT* out)
 {
-  typedef cub::BlockReduce<cub::KeyValuePair<int, T>, TPB> BlockReduce;
+  typedef cub::
+    BlockReduce<cub::KeyValuePair<IdxT, T>, TPB, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
+      BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
-  // compute maxIndex=argMax index for column
-  using KVP    = cub::KeyValuePair<int, T>;
-  int rowStart = blockIdx.x * D;
-  KVP thread_data(-1, -raft::myInf<T>());
+  using KVP     = cub::KeyValuePair<IdxT, T>;
+  IdxT rowStart = static_cast<IdxT>(blockIdx.x) * D;
+  KVP thread_data(0, std::is_same_v<RedOp, cub::ArgMax> ? -raft::myInf<T>() : raft::myInf<T>());
 
-  for (int i = threadIdx.x; i < D; i += TPB) {
-    int idx     = rowStart + i;
-    thread_data = cub::ArgMax()(thread_data, KVP(i, d_in[idx]));
+  for (IdxT i = threadIdx.x; i < D; i += TPB) {
+    IdxT idx    = rowStart + i;
+    thread_data = RedOp()(thread_data, KVP(i, d_in[idx]));
   }
 
-  auto maxKV = BlockReduce(temp_storage).Reduce(thread_data, cub::ArgMax());
+  auto maxKV = BlockReduce(temp_storage).Reduce(thread_data, RedOp());
 
-  if (threadIdx.x == 0) { argmax[blockIdx.x] = maxKV.key; }
+  if (threadIdx.x == 0) { out[blockIdx.x] = maxKV.key; }
 }
 
-template <typename math_t, typename idx_t>
-void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream)
+/**
+ * @brief Computes an argmin/argmax coalesced reduction
+ *
+ * @tparam RedOp Reduction operation (cub::ArgMin or cub::ArgMax)
+ * @tparam math_t Value type
+ * @tparam out_t Output key type
+ * @tparam idx_t Matrix index type
+ * @param[in] in Input matrix (DxN column-major or NxD row-major)
+ * @param[in] D Dimension of the axis to reduce along
+ * @param[in] N Number of reductions
+ * @param[out] out Output keys (N)
+ * @param[in] stream CUDA stream
+ */
+template <typename RedOp, typename math_t, typename out_t, typename idx_t>
+inline void argReduce(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
 {
-  int D = n_rows;
-  int N = n_cols;
   if (D <= 32) {
-    argmaxKernel<math_t, idx_t, 32><<<N, 32, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 32><<<N, 32, 0, stream>>>(in, D, N, out);
   } else if (D <= 64) {
-    argmaxKernel<math_t, idx_t, 64><<<N, 64, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 64><<<N, 64, 0, stream>>>(in, D, N, out);
   } else if (D <= 128) {
-    argmaxKernel<math_t, idx_t, 128><<<N, 128, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 128><<<N, 128, 0, stream>>>(in, D, N, out);
   } else {
-    argmaxKernel<math_t, idx_t, 256><<<N, 256, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 256><<<N, 256, 0, stream>>>(in, D, N, out);
   }
   RAFT_CUDA_TRY(cudaPeekAtLastError());
 }
 
+template <typename math_t, typename out_t, typename idx_t>
+void argmin(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
+{
+  argReduce<cub::ArgMin>(in, D, N, out, stream);
+}
+
+template <typename math_t, typename out_t, typename idx_t>
+void argmax(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
+{
+  argReduce<cub::ArgMax>(in, D, N, out, stream);
+}
+
 // Utility kernel needed for signFlip.
 // Computes the argmax(abs(d_in)) column-wise in a DxN matrix followed by
 // flipping the sign if the |max| value for each column is negative.
diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh
index 3c2705cf87..fd5ddf2df3 100644
--- a/cpp/include/raft/matrix/math.cuh
+++ b/cpp/include/raft/matrix/math.cuh
@@ -302,16 +302,30 @@ void ratio(
 
 /** @} */
 
+/**
+ * @brief Argmin: find the row idx with minimum value for each column
+ * @param in: input matrix (column-major)
+ * @param n_rows: number of rows of input matrix
+ * @param n_cols: number of columns of input matrix
+ * @param out: output vector of size n_cols
+ * @param stream: cuda stream
+ */
+template <typename math_t, typename out_t, typename idx_t>
+void argmin(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream)
+{
+  detail::argmin(in, n_rows, n_cols, out, stream);
+}
+
 /**
  * @brief Argmax: find the row idx with maximum value for each column
- * @param in: input matrix
+ * @param in: input matrix (column-major)
  * @param n_rows: number of rows of input matrix
  * @param n_cols: number of columns of input matrix
  * @param out: output vector of size n_cols
  * @param stream: cuda stream
  */
-template <typename math_t, typename idx_t>
-void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream)
+template <typename math_t, typename out_t, typename idx_t>
+void argmax(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream)
 {
   detail::argmax(in, n_rows, n_cols, out, stream);
 }
diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
index fd009b30af..b766e12cbd 100644
--- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
@@ -33,6 +33,7 @@
 #include
 #include
 #include
+#include <raft/matrix/argmin.cuh>
 #include
 #include
 
@@ -147,8 +148,11 @@ inline void predict_float_core(const handle_t& handle,
                          distances.data(),
                          n_clusters,
                          stream);
-      utils::argmin_along_rows(
-        n_rows, static_cast<IdxT>(n_clusters), distances.data(), labels, stream);
+
+      auto distances_const_view = raft::make_device_matrix_view<const float, IdxT>(
+        distances.data(), n_rows, static_cast<IdxT>(n_clusters));
+      auto labels_view = raft::make_device_vector_view(labels, n_rows);
+      raft::matrix::argmin(handle, distances_const_view, labels_view);
       break;
     }
     default: {
diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
index 7b26ccfb42..5d031cc51d 100644
--- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -151,66 +151,6 @@ inline void memzero(T* ptr, IdxT n_elems, rmm::cuda_stream_view stream)
   }
 }
 
-template <typename IdxT, typename OutT>
-__global__ void argmin_along_rows_kernel(IdxT n_rows, uint32_t n_cols, const float* a, OutT* out)
-{
-  __shared__ OutT shm_ids[1024];    // NOLINT
-  __shared__ float shm_vals[1024];  // NOLINT
-  IdxT i = blockIdx.x;
-  if (i >= n_rows) return;
-  OutT min_idx  = n_cols;
-  float min_val = raft::upper_bound<float>();
-  for (OutT j = threadIdx.x; j < n_cols; j += blockDim.x) {
-    if (min_val > a[j + n_cols * i]) {
-      min_val = a[j + n_cols * i];
-      min_idx = j;
-    }
-  }
-  shm_vals[threadIdx.x] = min_val;
-  shm_ids[threadIdx.x]  = min_idx;
-  __syncthreads();
-  for (IdxT offset = blockDim.x / 2; offset > 0; offset >>= 1) {
-    if (threadIdx.x < offset) {
-      if (shm_vals[threadIdx.x] < shm_vals[threadIdx.x + offset]) {
-      } else if (shm_vals[threadIdx.x] > shm_vals[threadIdx.x + offset]) {
-        shm_vals[threadIdx.x] = shm_vals[threadIdx.x + offset];
-        shm_ids[threadIdx.x]  = shm_ids[threadIdx.x + offset];
-      } else if (shm_ids[threadIdx.x] > shm_ids[threadIdx.x + offset]) {
-        shm_ids[threadIdx.x] = shm_ids[threadIdx.x + offset];
-      }
-    }
-    __syncthreads();
-  }
-  if (threadIdx.x == 0) { out[i] = shm_ids[0]; }
-}
-
-/**
- * @brief Find index of the smallest element in each row.
- *
- * NB: device-only function
- * TODO: specialize select_k for the case of `k == 1` and use that one instead.
- *
- * @tparam IdxT index type
- * @tparam OutT output type
- *
- * @param n_rows
- * @param n_cols
- * @param[in] a device pointer to the row-major matrix [n_rows, n_cols]
- * @param[out] out device pointer to the vector of selected indices [n_rows]
- * @param stream
- */
-template <typename IdxT, typename OutT>
-inline void argmin_along_rows(
-  IdxT n_rows, IdxT n_cols, const float* a, OutT* out, rmm::cuda_stream_view stream)
-{
-  IdxT block_dim = 1024;
-  while (block_dim > n_cols) {
-    block_dim /= 2;
-  }
-  block_dim = max(block_dim, (IdxT)128);
-  argmin_along_rows_kernel<<<n_rows, block_dim, 0, stream>>>(n_rows, n_cols, a, out);
-}
-
 template <typename IdxT>
 __global__ void dots_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, float* out)
 {
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 088b15aaf1..792bcf1ec1 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -167,6 +167,7 @@ if(BUILD_TESTS)
   ConfigureTest(NAME MATRIX_TEST
     PATH
     test/matrix/argmax.cu
+    test/matrix/argmin.cu
     test/matrix/columnSort.cu
     test/matrix/diagonal.cu
     test/matrix/gather.cu
diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu
index 9568c06d93..0219eb1aff 100644
--- a/cpp/test/matrix/argmax.cu
+++ b/cpp/test/matrix/argmax.cu
@@ -29,8 +29,8 @@ template <typename T, typename IdxT>
 struct ArgMaxInputs {
   std::vector<T> input_matrix;
   std::vector<IdxT> output_matrix;
-  std::size_t n_cols;
   std::size_t n_rows;
+  std::size_t n_cols;
 };
 
 template <typename T, typename IdxT>
diff --git a/cpp/test/matrix/argmin.cu b/cpp/test/matrix/argmin.cu
new file mode 100644
index 0000000000..bdf178cd8a
--- /dev/null
+++ b/cpp/test/matrix/argmin.cu
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2018-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace matrix { + +template +struct ArgMinInputs { + std::vector input_matrix; + std::vector output_matrix; + std::size_t n_rows; + std::size_t n_cols; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const ArgMinInputs& dims) +{ + return os; +} + +template +class ArgMinTest : public ::testing::TestWithParam> { + public: + ArgMinTest() + : params(::testing::TestWithParam>::GetParam()), + input(raft::make_device_matrix( + handle, params.n_rows, params.n_cols)), + output(raft::make_device_vector(handle, params.n_rows)), + expected(raft::make_device_vector(handle, params.n_rows)) + { + raft::update_device(input.data_handle(), + params.input_matrix.data(), + params.input_matrix.size(), + handle.get_stream()); + raft::update_device(expected.data_handle(), + params.output_matrix.data(), + params.output_matrix.size(), + handle.get_stream()); + + auto input_const_view = raft::make_device_matrix_view( + input.data_handle(), input.extent(0), input.extent(1)); + + raft::matrix::argmin(handle, input_const_view, output.view()); + + handle.sync_stream(); + } + + protected: + raft::handle_t handle; + ArgMinInputs params; + + raft::device_matrix input; + raft::device_vector output; + raft::device_vector expected; +}; + +const std::vector> inputsf = { + {{0.1f, 0.2f, 0.3f, 0.4f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.5f, 0.0f}, {0, 3, 3}, 3, 4}}; + +const std::vector> inputsd = { + {{0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1, 0.2, 0.3, 0.5, 0.0}, {0, 3, 3}, 3, 4}}; + +typedef ArgMinTest ArgMinTestF; +TEST_P(ArgMinTestF, Result) +{ + ASSERT_TRUE(devArrMatch(expected.data_handle(), + output.data_handle(), + params.n_rows, + Compare(), + handle.get_stream())); +} + +typedef ArgMinTest ArgMinTestD; +TEST_P(ArgMinTestD, Result) +{ + ASSERT_TRUE(devArrMatch(expected.data_handle(), + output.data_handle(), + params.n_rows, + Compare(), + handle.get_stream())); +} + +INSTANTIATE_TEST_SUITE_P(ArgMinTest, ArgMinTestF, ::testing::ValuesIn(inputsf)); + +INSTANTIATE_TEST_SUITE_P(ArgMinTest, ArgMinTestD, ::testing::ValuesIn(inputsd)); + +} // namespace matrix +} // namespace raft \ No newline at end of file