From 8c38daf7fab9ab91f687fd35860a61fdf5e20276 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 11 Oct 2022 16:48:23 +0200 Subject: [PATCH 1/4] Create cub-based argmin primitive and replace argmin_along_rows in ANN kmeans --- cpp/include/raft/matrix/detail/math.cuh | 53 +++++++++------- cpp/include/raft/matrix/math.cuh | 18 +++++- .../knn/detail/ann_kmeans_balanced.cuh | 4 +- .../raft/spatial/knn/detail/ann_utils.cuh | 60 ------------------- 4 files changed, 50 insertions(+), 85 deletions(-) diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 07b9ccc12b..025883e25c 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -362,45 +362,56 @@ void matrixVectorBinarySub(Type* data, stream); } -// Computes the argmax(d_in) column-wise in a DxN matrix -template -__global__ void argmaxKernel(const T* d_in, int D, int N, IdxT* argmax) +// Computes an argmin/argmax column-wise in a DxN matrix +template +__global__ void argReduceKernel(const T* d_in, IdxT D, IdxT N, OutT* out) { - typedef cub::BlockReduce, TPB> BlockReduce; + typedef cub::BlockReduce, TPB> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - // compute maxIndex=argMax index for column - using KVP = cub::KeyValuePair; - int rowStart = blockIdx.x * D; - KVP thread_data(-1, -raft::myInf()); + using KVP = cub::KeyValuePair; + IdxT rowStart = static_cast(blockIdx.x) * D; + KVP thread_data(0, std::is_same_v ? -raft::myInf() : raft::myInf()); - for (int i = threadIdx.x; i < D; i += TPB) { - int idx = rowStart + i; - thread_data = cub::ArgMax()(thread_data, KVP(i, d_in[idx])); + for (IdxT i = threadIdx.x; i < D; i += TPB) { + IdxT idx = rowStart + i; + thread_data = RedOp()(thread_data, KVP(i, d_in[idx])); } - auto maxKV = BlockReduce(temp_storage).Reduce(thread_data, cub::ArgMax()); + auto maxKV = BlockReduce(temp_storage).Reduce(thread_data, RedOp()); - if (threadIdx.x == 0) { argmax[blockIdx.x] = maxKV.key; } + if (threadIdx.x == 0) { out[blockIdx.x] = maxKV.key; } } -template -void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream) +template +inline void argReduce(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream) { - int D = n_rows; - int N = n_cols; + idx_t D = n_rows; + idx_t N = n_cols; if (D <= 32) { - argmaxKernel<<>>(in, D, N, out); + argReduceKernel<<>>(in, D, N, out); } else if (D <= 64) { - argmaxKernel<<>>(in, D, N, out); + argReduceKernel<<>>(in, D, N, out); } else if (D <= 128) { - argmaxKernel<<>>(in, D, N, out); + argReduceKernel<<>>(in, D, N, out); } else { - argmaxKernel<<>>(in, D, N, out); + argReduceKernel<<>>(in, D, N, out); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } +template +void argmin(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream) +{ + argReduce(in, n_rows, n_cols, out, stream); +} + +template +void argmax(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream) +{ + argReduce(in, n_rows, n_cols, out, stream); +} + // Utility kernel needed for signFlip. // Computes the argmax(abs(d_in)) column-wise in a DxN matrix followed by // flipping the sign if the |max| value for each column is negative. diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh index 3c2705cf87..e2f5cb822b 100644 --- a/cpp/include/raft/matrix/math.cuh +++ b/cpp/include/raft/matrix/math.cuh @@ -302,6 +302,20 @@ void ratio( /** @} */ +/** + * @brief Argmin: find the row idx with minimum value for each column + * @param in: input matrix + * @param n_rows: number of rows of input matrix + * @param n_cols: number of columns of input matrix + * @param out: output vector of size n_cols + * @param stream: cuda stream + */ +template +void argmin(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream) +{ + detail::argmin(in, n_rows, n_cols, out, stream); +} + /** * @brief Argmax: find the row idx with maximum value for each column * @param in: input matrix @@ -310,8 +324,8 @@ void ratio( * @param out: output vector of size n_cols * @param stream: cuda stream */ -template -void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream) +template +void argmax(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream) { detail::argmax(in, n_rows, n_cols, out, stream); } diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index bf0df065b2..00c5cf7303 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -144,8 +145,7 @@ inline void predict_float_core(const handle_t& handle, distances.data(), n_clusters, stream); - utils::argmin_along_rows( - n_rows, static_cast(n_clusters), distances.data(), labels, stream); + raft::matrix::argmin(distances.data(), (IdxT)n_clusters, n_rows, labels, stream); break; } default: { diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 8dda574314..f7b358beb6 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -151,66 +151,6 @@ inline void memzero(T* ptr, IdxT n_elems, rmm::cuda_stream_view stream) } } -template -__global__ void argmin_along_rows_kernel(IdxT n_rows, uint32_t n_cols, const float* a, OutT* out) -{ - __shared__ OutT shm_ids[1024]; // NOLINT - __shared__ float shm_vals[1024]; // NOLINT - IdxT i = blockIdx.x; - if (i >= n_rows) return; - OutT min_idx = n_cols; - float min_val = raft::upper_bound(); - for (OutT j = threadIdx.x; j < n_cols; j += blockDim.x) { - if (min_val > a[j + n_cols * i]) { - min_val = a[j + n_cols * i]; - min_idx = j; - } - } - shm_vals[threadIdx.x] = min_val; - shm_ids[threadIdx.x] = min_idx; - __syncthreads(); - for (IdxT offset = blockDim.x / 2; offset > 0; offset >>= 1) { - if (threadIdx.x < offset) { - if (shm_vals[threadIdx.x] < shm_vals[threadIdx.x + offset]) { - } else if (shm_vals[threadIdx.x] > shm_vals[threadIdx.x + offset]) { - shm_vals[threadIdx.x] = shm_vals[threadIdx.x + offset]; - shm_ids[threadIdx.x] = shm_ids[threadIdx.x + offset]; - } else if (shm_ids[threadIdx.x] > shm_ids[threadIdx.x + offset]) { - shm_ids[threadIdx.x] = shm_ids[threadIdx.x + offset]; - } - } - __syncthreads(); - } - if (threadIdx.x == 0) { out[i] = shm_ids[0]; } -} - -/** - * @brief Find index of the smallest element in each row. - * - * NB: device-only function - * TODO: specialize select_k for the case of `k == 1` and use that one instead. - * - * @tparam IdxT index type - * @tparam OutT output type - * - * @param n_rows - * @param n_cols - * @param[in] a device pointer to the row-major matrix [n_rows, n_cols] - * @param[out] out device pointer to the vector of selected indices [n_rows] - * @param stream - */ -template -inline void argmin_along_rows( - IdxT n_rows, IdxT n_cols, const float* a, OutT* out, rmm::cuda_stream_view stream) -{ - IdxT block_dim = 1024; - while (block_dim > n_cols) { - block_dim /= 2; - } - block_dim = max(block_dim, (IdxT)128); - argmin_along_rows_kernel<<>>(n_rows, n_cols, a, out); -} - template __global__ void dots_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, float* out) { From 2cdcc902c099d82068395e780362e565684bd056 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 11 Oct 2022 18:18:05 +0200 Subject: [PATCH 2/4] Add benchmark for argmin --- build.sh | 2 +- cpp/bench/CMakeLists.txt | 6 ++++ cpp/bench/matrix/argmin.cu | 59 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 cpp/bench/matrix/argmin.cu diff --git a/build.sh b/build.sh index d1dd8bdde1..02b6c54dfe 100755 --- a/build.sh +++ b/build.sh @@ -73,7 +73,7 @@ COMPILE_DIST_LIBRARY=OFF ENABLE_NN_DEPENDENCIES=OFF TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NN_TEST;SPATIAL_TEST;STATS_TEST;UTILS_TEST" -BENCH_TARGETS="CLUSTER_BENCH;SPATIAL_BENCH;DISTANCE_BENCH;LINALG_BENCH;SPARSE_BENCH;RANDOM_BENCH" +BENCH_TARGETS="CLUSTER_BENCH;SPATIAL_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" ENABLE_thrust_DEPENDENCY=ON CACHE_ARGS="" diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 51170e4265..e98196bc26 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -100,6 +100,12 @@ if(BUILD_BENCH) bench/main.cpp ) + ConfigureBench(NAME MATRIX_BENCH + PATH + bench/matrix/argmin.cu + bench/main.cpp + ) + ConfigureBench(NAME RANDOM_BENCH PATH bench/random/make_blobs.cu diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu new file mode 100644 index 0000000000..63a672274a --- /dev/null +++ b/cpp/bench/matrix/argmin.cu @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +namespace raft::bench::linalg { + +struct ArgminParams { + int64_t rows, cols; +}; + +template +struct Argmin : public fixture { + Argmin(const ArgminParams& p) + : params(p), matrix(p.rows * p.cols, stream), indices(p.cols, stream) + { + } + + void run_benchmark(::benchmark::State& state) override + { + loop_on_state(state, [this]() { + raft::matrix::argmin(matrix.data(), params.rows, params.cols, indices.data(), stream); + }); + } + + private: + ArgminParams params; + rmm::device_uvector matrix; + rmm::device_uvector indices; +}; // struct Argmin + +const std::vector kInputSizes{ + {64, 1000}, {128, 1000}, {256, 1000}, {512, 1000}, {1024, 1000}, + {64, 10000}, {128, 10000}, {256, 10000}, {512, 10000}, {1024, 10000}, + {64, 100000}, {128, 100000}, {256, 100000}, {512, 100000}, {1024, 100000}, + {64, 1000000}, {128, 1000000}, {256, 1000000}, {512, 1000000}, {1024, 1000000}, + {64, 10000000}, {128, 10000000}, {256, 10000000}, {512, 10000000}, {1024, 10000000}, +}; + +RAFT_BENCH_REGISTER((Argmin), "", kInputSizes); +RAFT_BENCH_REGISTER((Argmin), "", kInputSizes); + +} // namespace raft::bench::linalg From 94289adb304cbc3aaa39c3668b299548e0e2bc24 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Tue, 11 Oct 2022 19:05:12 +0200 Subject: [PATCH 3/4] Use faster commutative algorithm --- cpp/include/raft/matrix/detail/math.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 025883e25c..d8bdfad7bd 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -366,7 +366,9 @@ void matrixVectorBinarySub(Type* data, template __global__ void argReduceKernel(const T* d_in, IdxT D, IdxT N, OutT* out) { - typedef cub::BlockReduce, TPB> BlockReduce; + typedef cub:: + BlockReduce, TPB, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY> + BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; using KVP = cub::KeyValuePair; From d9995346ce029504045e24dce35119442627c52d Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Fri, 28 Oct 2022 20:58:52 +0200 Subject: [PATCH 4/4] Add argmin header and test --- cpp/bench/matrix/argmin.cu | 46 +++++--- cpp/include/raft/matrix/argmax.cuh | 6 +- cpp/include/raft/matrix/argmin.cuh | 40 +++++++ cpp/include/raft/matrix/detail/math.cuh | 25 ++-- cpp/include/raft/matrix/math.cuh | 4 +- .../knn/detail/ann_kmeans_balanced.cuh | 8 +- cpp/test/CMakeLists.txt | 1 + cpp/test/matrix/argmax.cu | 2 +- cpp/test/matrix/argmin.cu | 109 ++++++++++++++++++ 9 files changed, 209 insertions(+), 32 deletions(-) create mode 100644 cpp/include/raft/matrix/argmin.cuh create mode 100644 cpp/test/matrix/argmin.cu diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu index 63a672274a..0d0dea0fdb 100644 --- a/cpp/bench/matrix/argmin.cu +++ b/cpp/bench/matrix/argmin.cu @@ -15,45 +15,57 @@ */ #include -#include +#include +#include #include namespace raft::bench::linalg { +template struct ArgminParams { - int64_t rows, cols; + IdxT rows, cols; }; -template +template struct Argmin : public fixture { - Argmin(const ArgminParams& p) - : params(p), matrix(p.rows * p.cols, stream), indices(p.cols, stream) + Argmin(const ArgminParams& p) : params(p) {} + + void allocate_data(const ::benchmark::State& state) override { + matrix = raft::make_device_matrix(handle, params.rows, params.cols); + indices = raft::make_device_vector(handle, params.rows); + + raft::random::RngState rng{1234}; + raft::random::uniform( + rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); + handle.sync_stream(stream); } void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - raft::matrix::argmin(matrix.data(), params.rows, params.cols, indices.data(), stream); + auto matrix_const_view = raft::make_device_matrix_view( + matrix.data_handle(), matrix.extent(0), matrix.extent(1)); + raft::matrix::argmin(handle, matrix_const_view, indices.view()); }); } private: - ArgminParams params; - rmm::device_uvector matrix; - rmm::device_uvector indices; + ArgminParams params; + raft::device_matrix matrix; + raft::device_vector indices; }; // struct Argmin -const std::vector kInputSizes{ - {64, 1000}, {128, 1000}, {256, 1000}, {512, 1000}, {1024, 1000}, - {64, 10000}, {128, 10000}, {256, 10000}, {512, 10000}, {1024, 10000}, - {64, 100000}, {128, 100000}, {256, 100000}, {512, 100000}, {1024, 100000}, - {64, 1000000}, {128, 1000000}, {256, 1000000}, {512, 1000000}, {1024, 1000000}, - {64, 10000000}, {128, 10000000}, {256, 10000000}, {512, 10000000}, {1024, 10000000}, +const std::vector> argmin_inputs_i64{ + {1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024}, + {10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024}, + {100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024}, + {1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024}, + {10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024}, }; -RAFT_BENCH_REGISTER((Argmin), "", kInputSizes); -RAFT_BENCH_REGISTER((Argmin), "", kInputSizes); +RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); +RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); } // namespace raft::bench::linalg diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index b3face1012..e6736b14de 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -22,10 +22,10 @@ namespace raft::matrix { /** - * @brief Argmax: find the row idx with maximum value for each column + * @brief Argmax: find the col idx with maximum value for each row * @param[in] handle: raft handle * @param[in] in: input matrix of size (n_rows, n_cols) - * @param[out] out: output vector of size n_cols + * @param[out] out: output vector of size n_rows */ template void argmax(const raft::handle_t& handle, @@ -35,6 +35,6 @@ void argmax(const raft::handle_t& handle, RAFT_EXPECTS(out.extent(0) == in.extent(0), "Size of output vector must equal number of rows in input matrix."); detail::argmax( - in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream()); + in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream()); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/argmin.cuh b/cpp/include/raft/matrix/argmin.cuh new file mode 100644 index 0000000000..e8cf763f70 --- /dev/null +++ b/cpp/include/raft/matrix/argmin.cuh @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::matrix { + +/** + * @brief Argmin: find the col idx with minimum value for each row + * @param[in] handle: raft handle + * @param[in] in: input matrix of size (n_rows, n_cols) + * @param[out] out: output vector of size n_rows + */ +template +void argmin(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view out) +{ + RAFT_EXPECTS(out.extent(0) == in.extent(0), + "Size of output vector must equal number of rows in input matrix."); + detail::argmin( + in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream()); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index d8bdfad7bd..64c85a03a5 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -385,11 +385,22 @@ __global__ void argReduceKernel(const T* d_in, IdxT D, IdxT N, OutT* out) if (threadIdx.x == 0) { out[blockIdx.x] = maxKV.key; } } +/** + * @brief Computes an argmin/argmax coalesced reduction + * + * @tparam RedOp Reduction operation (cub::ArgMin or cub::ArgMax) + * @tparam math_t Value type + * @tparam out_t Output key type + * @tparam idx_t Matrix index type + * @param[in] in Input matrix (DxN column-major or NxD row-major) + * @param[in] D Dimension of the axis to reduce along + * @param[in] N Number of reductions + * @param[out] out Output keys (N) + * @param[in] stream CUDA stream + */ template -inline void argReduce(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream) +inline void argReduce(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream) { - idx_t D = n_rows; - idx_t N = n_cols; if (D <= 32) { argReduceKernel<<>>(in, D, N, out); } else if (D <= 64) { @@ -403,15 +414,15 @@ inline void argReduce(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, } template -void argmin(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream) +void argmin(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream) { - argReduce(in, n_rows, n_cols, out, stream); + argReduce(in, D, N, out, stream); } template -void argmax(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream) +void argmax(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream) { - argReduce(in, n_rows, n_cols, out, stream); + argReduce(in, D, N, out, stream); } // Utility kernel needed for signFlip. diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh index e2f5cb822b..fd5ddf2df3 100644 --- a/cpp/include/raft/matrix/math.cuh +++ b/cpp/include/raft/matrix/math.cuh @@ -304,7 +304,7 @@ void ratio( /** * @brief Argmin: find the row idx with minimum value for each column - * @param in: input matrix + * @param in: input matrix (column-major) * @param n_rows: number of rows of input matrix * @param n_cols: number of columns of input matrix * @param out: output vector of size n_cols @@ -318,7 +318,7 @@ void argmin(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream /** * @brief Argmax: find the row idx with maximum value for each column - * @param in: input matrix + * @param in: input matrix (column-major) * @param n_rows: number of rows of input matrix * @param n_cols: number of columns of input matrix * @param out: output vector of size n_cols diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 214e74c0e5..8659c39f45 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -33,7 +33,7 @@ #include #include #include -#include +#include #include #include @@ -148,7 +148,11 @@ inline void predict_float_core(const handle_t& handle, distances.data(), n_clusters, stream); - raft::matrix::argmin(distances.data(), (IdxT)n_clusters, n_rows, labels, stream); + + auto distances_const_view = raft::make_device_matrix_view( + distances.data(), n_rows, static_cast(n_clusters)); + auto labels_view = raft::make_device_vector_view(labels, n_rows); + raft::matrix::argmin(handle, distances_const_view, labels_view); break; } default: { diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 4280be91ff..767c33bbaa 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -162,6 +162,7 @@ if(BUILD_TESTS) ConfigureTest(NAME MATRIX_TEST PATH test/matrix/argmax.cu + test/matrix/argmin.cu test/matrix/columnSort.cu test/matrix/diagonal.cu test/matrix/gather.cu diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu index 9568c06d93..0219eb1aff 100644 --- a/cpp/test/matrix/argmax.cu +++ b/cpp/test/matrix/argmax.cu @@ -29,8 +29,8 @@ template struct ArgMaxInputs { std::vector input_matrix; std::vector output_matrix; - std::size_t n_cols; std::size_t n_rows; + std::size_t n_cols; }; template diff --git a/cpp/test/matrix/argmin.cu b/cpp/test/matrix/argmin.cu new file mode 100644 index 0000000000..bdf178cd8a --- /dev/null +++ b/cpp/test/matrix/argmin.cu @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace matrix { + +template +struct ArgMinInputs { + std::vector input_matrix; + std::vector output_matrix; + std::size_t n_rows; + std::size_t n_cols; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const ArgMinInputs& dims) +{ + return os; +} + +template +class ArgMinTest : public ::testing::TestWithParam> { + public: + ArgMinTest() + : params(::testing::TestWithParam>::GetParam()), + input(raft::make_device_matrix( + handle, params.n_rows, params.n_cols)), + output(raft::make_device_vector(handle, params.n_rows)), + expected(raft::make_device_vector(handle, params.n_rows)) + { + raft::update_device(input.data_handle(), + params.input_matrix.data(), + params.input_matrix.size(), + handle.get_stream()); + raft::update_device(expected.data_handle(), + params.output_matrix.data(), + params.output_matrix.size(), + handle.get_stream()); + + auto input_const_view = raft::make_device_matrix_view( + input.data_handle(), input.extent(0), input.extent(1)); + + raft::matrix::argmin(handle, input_const_view, output.view()); + + handle.sync_stream(); + } + + protected: + raft::handle_t handle; + ArgMinInputs params; + + raft::device_matrix input; + raft::device_vector output; + raft::device_vector expected; +}; + +const std::vector> inputsf = { + {{0.1f, 0.2f, 0.3f, 0.4f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.5f, 0.0f}, {0, 3, 3}, 3, 4}}; + +const std::vector> inputsd = { + {{0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1, 0.2, 0.3, 0.5, 0.0}, {0, 3, 3}, 3, 4}}; + +typedef ArgMinTest ArgMinTestF; +TEST_P(ArgMinTestF, Result) +{ + ASSERT_TRUE(devArrMatch(expected.data_handle(), + output.data_handle(), + params.n_rows, + Compare(), + handle.get_stream())); +} + +typedef ArgMinTest ArgMinTestD; +TEST_P(ArgMinTestD, Result) +{ + ASSERT_TRUE(devArrMatch(expected.data_handle(), + output.data_handle(), + params.n_rows, + Compare(), + handle.get_stream())); +} + +INSTANTIATE_TEST_SUITE_P(ArgMinTest, ArgMinTestF, ::testing::ValuesIn(inputsf)); + +INSTANTIATE_TEST_SUITE_P(ArgMinTest, ArgMinTestD, ::testing::ValuesIn(inputsd)); + +} // namespace matrix +} // namespace raft \ No newline at end of file