Create cub-based argmin primitive and replace argmin_along_rows in ANN kmeans (#912)

This PR follows up on [a suggestion](#821 (comment)) from @cjnolet. The new `argmin` primitive is up to 5x faster than `argmin_along_rows` for dimensions relevant to ANN kmeans, and the change removes code duplication.

The reasons it is faster are:

- `argmin_along_rows` often skips the sequential reduction before the tree reduction, especially because it uses block sizes as large as 1024.
- CUB has a better reduction algorithm than the basic shared-mem reduction used in `argmin_along_rows`.
- If we switch the `argmin` prim to the `cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY` algorithm, we get up to a 30% further speedup! (I believe it's safe to use the commutative algorithm here: since the offset is carried inside the key-value pair, the reduction operator is commutative.)

The speedup I measured for the IVF-Flat build with the `InnerProduct` metric is around 15%.
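To make the mechanism concrete, below is a condensed sketch of the pattern, specialized to argmin. It mirrors the `argReduceKernel` added in `cpp/include/raft/matrix/detail/math.cuh` (see the diff further down); the kernel name `argmin_rows` and the launch line are illustrative, not part of the PR:

#include <cub/cub.cuh>
#include <raft/util/cuda_utils.cuh>  // raft::myInf

template <int TPB, typename T, typename IdxT>
__global__ void argmin_rows(const T* in, IdxT D, IdxT* out)
{
  using KVP         = cub::KeyValuePair<IdxT, T>;
  using BlockReduce = cub::BlockReduce<KVP, TPB, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>;
  __shared__ typename BlockReduce::TempStorage temp;

  const T* row = in + static_cast<IdxT>(blockIdx.x) * D;  // one block per row
  KVP best(0, raft::myInf<T>());                          // +inf sentinel for argmin

  // 1) Sequential reduction: each thread scans a strided slice of the row,
  //    amortizing the cost of the tree reduction that follows.
  for (IdxT i = threadIdx.x; i < D; i += TPB) {
    best = cub::ArgMin()(best, KVP(i, row[i]));
  }

  // 2) Tree reduction: the raking commutative-only algorithm is safe because
  //    the offset travels inside the key-value pair, so the operator commutes.
  KVP row_min = BlockReduce(temp).Reduce(best, cub::ArgMin());
  if (threadIdx.x == 0) { out[blockIdx.x] = row_min.key; }
}

// Launched with one block per row, e.g.:
//   argmin_rows<256><<<n_rows, 256, 0, stream>>>(d_in, n_cols, d_out);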

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #912
Nyrio authored Nov 9, 2022
1 parent e60cd1c commit 836bb58
Showing 12 changed files with 300 additions and 91 deletions.
build.sh: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ COMPILE_DIST_LIBRARY=OFF
ENABLE_NN_DEPENDENCIES=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
-BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;SPARSE_BENCH;RANDOM_BENCH"
+BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"
ENABLE_thrust_DEPENDENCY=ON

CACHE_ARGS=""
cpp/bench/CMakeLists.txt: 6 additions & 0 deletions
@@ -101,6 +101,12 @@ if(BUILD_BENCH)
bench/main.cpp
)

+ConfigureBench(NAME MATRIX_BENCH
+  PATH
+    bench/matrix/argmin.cu
+    bench/main.cpp
+)
+
ConfigureBench(NAME RANDOM_BENCH
PATH
bench/random/make_blobs.cu
cpp/bench/matrix/argmin.cu: 71 additions & 0 deletions (new file)
@@ -0,0 +1,71 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <common/benchmark.hpp>
#include <raft/matrix/argmin.cuh>
#include <raft/random/rng.cuh>

#include <rmm/device_uvector.hpp>

namespace raft::bench::linalg {

template <typename IdxT>
struct ArgminParams {
  IdxT rows, cols;
};

template <typename T, typename OutT, typename IdxT>
struct Argmin : public fixture {
  Argmin(const ArgminParams<IdxT>& p) : params(p) {}

  void allocate_data(const ::benchmark::State& state) override
  {
    matrix  = raft::make_device_matrix<T, IdxT>(handle, params.rows, params.cols);
    indices = raft::make_device_vector<OutT, IdxT>(handle, params.rows);

    raft::random::RngState rng{1234};
    raft::random::uniform(rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
    handle.sync_stream(stream);
  }

  void run_benchmark(::benchmark::State& state) override
  {
    loop_on_state(state, [this]() {
      auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT, row_major>(
        matrix.data_handle(), matrix.extent(0), matrix.extent(1));
      raft::matrix::argmin(handle, matrix_const_view, indices.view());
    });
  }

 private:
  ArgminParams<IdxT> params;
  raft::device_matrix<T, IdxT> matrix;
  raft::device_vector<OutT, IdxT> indices;
};  // struct Argmin

const std::vector<ArgminParams<int64_t>> argmin_inputs_i64{
  {1000, 64},     {1000, 128},     {1000, 256},     {1000, 512},     {1000, 1024},
  {10000, 64},    {10000, 128},    {10000, 256},    {10000, 512},    {10000, 1024},
  {100000, 64},   {100000, 128},   {100000, 256},   {100000, 512},   {100000, 1024},
  {1000000, 64},  {1000000, 128},  {1000000, 256},  {1000000, 512},  {1000000, 1024},
  {10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024},
};

RAFT_BENCH_REGISTER((Argmin<float, uint32_t, int64_t>), "", argmin_inputs_i64);
RAFT_BENCH_REGISTER((Argmin<double, uint32_t, int64_t>), "", argmin_inputs_i64);

}  // namespace raft::bench::linalg
cpp/include/raft/matrix/argmax.cuh: 3 additions & 3 deletions
@@ -22,10 +22,10 @@
namespace raft::matrix {

/**
- * @brief Argmax: find the row idx with maximum value for each column
+ * @brief Argmax: find the col idx with maximum value for each row
* @param[in] handle: raft handle
* @param[in] in: input matrix of size (n_rows, n_cols)
- * @param[out] out: output vector of size n_cols
+ * @param[out] out: output vector of size n_rows
*/
template <typename math_t, typename idx_t, typename matrix_idx_t>
void argmax(const raft::handle_t& handle,
@@ -35,6 +35,6 @@ void argmax(const raft::handle_t& handle,
RAFT_EXPECTS(out.extent(0) == in.extent(0),
"Size of output vector must equal number of rows in input matrix.");
detail::argmax(
-    in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream());
+    in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream());
}
} // namespace raft::matrix
cpp/include/raft/matrix/argmin.cuh: 40 additions & 0 deletions (new file)
@@ -0,0 +1,40 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/matrix/detail/math.cuh>

namespace raft::matrix {

/**
 * @brief Argmin: find the col idx with minimum value for each row
 * @param[in] handle: raft handle
 * @param[in] in: input matrix of size (n_rows, n_cols)
 * @param[out] out: output vector of size n_rows
 */
template <typename math_t, typename idx_t, typename matrix_idx_t>
void argmin(const raft::handle_t& handle,
            raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in,
            raft::device_vector_view<idx_t, matrix_idx_t> out)
{
  RAFT_EXPECTS(out.extent(0) == in.extent(0),
               "Size of output vector must equal number of rows in input matrix.");
  detail::argmin(
    in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream());
}
}  // namespace raft::matrix
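For reference, a minimal usage sketch of the new mdspan-based API, mirroring the benchmark above; the sizes, the helper name `run_argmin`, and the exact include set are assumptions for illustration:

#include <raft/core/device_mdarray.hpp>  // make_device_matrix / make_device_vector (assumed location)
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/matrix/argmin.cuh>

void run_argmin(const raft::handle_t& handle)
{
  // 1000x128 row-major matrix; one output index per row.
  auto matrix  = raft::make_device_matrix<float, int64_t>(handle, 1000, 128);
  auto indices = raft::make_device_vector<uint32_t, int64_t>(handle, 1000);

  // ... fill `matrix`, e.g. with raft::random::uniform as in the benchmark ...

  auto const_view = raft::make_device_matrix_view<const float, int64_t, raft::row_major>(
    matrix.data_handle(), matrix.extent(0), matrix.extent(1));
  raft::matrix::argmin(handle, const_view, indices.view());
  // indices now holds, for each row i, the column index of the row's minimum.
}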
cpp/include/raft/matrix/detail/math.cuh: 45 additions & 21 deletions
@@ -362,45 +362,69 @@ void matrixVectorBinarySub(Type* data,
stream);
}

-// Computes the argmax(d_in) column-wise in a DxN matrix
-template <typename T, typename IdxT, int TPB>
-__global__ void argmaxKernel(const T* d_in, int D, int N, IdxT* argmax)
+// Computes an argmin/argmax column-wise in a DxN matrix
+template <typename RedOp, int TPB, typename T, typename OutT, typename IdxT>
+__global__ void argReduceKernel(const T* d_in, IdxT D, IdxT N, OutT* out)
{
-  typedef cub::BlockReduce<cub::KeyValuePair<int, T>, TPB> BlockReduce;
+  typedef cub::
+    BlockReduce<cub::KeyValuePair<IdxT, T>, TPB, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
+      BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

-  // compute maxIndex=argMax index for column
-  using KVP    = cub::KeyValuePair<int, T>;
-  int rowStart = blockIdx.x * D;
-  KVP thread_data(-1, -raft::myInf<T>());
+  using KVP     = cub::KeyValuePair<IdxT, T>;
+  IdxT rowStart = static_cast<IdxT>(blockIdx.x) * D;
+  KVP thread_data(0, std::is_same_v<RedOp, cub::ArgMax> ? -raft::myInf<T>() : raft::myInf<T>());

-  for (int i = threadIdx.x; i < D; i += TPB) {
-    int idx     = rowStart + i;
-    thread_data = cub::ArgMax()(thread_data, KVP(i, d_in[idx]));
+  for (IdxT i = threadIdx.x; i < D; i += TPB) {
+    IdxT idx    = rowStart + i;
+    thread_data = RedOp()(thread_data, KVP(i, d_in[idx]));
  }

-  auto maxKV = BlockReduce(temp_storage).Reduce(thread_data, cub::ArgMax());
+  auto maxKV = BlockReduce(temp_storage).Reduce(thread_data, RedOp());

-  if (threadIdx.x == 0) { argmax[blockIdx.x] = maxKV.key; }
+  if (threadIdx.x == 0) { out[blockIdx.x] = maxKV.key; }
}

-template <typename math_t, typename idx_t>
-void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream)
+/**
+ * @brief Computes an argmin/argmax coalesced reduction
+ *
+ * @tparam RedOp  Reduction operation (cub::ArgMin or cub::ArgMax)
+ * @tparam math_t Value type
+ * @tparam out_t  Output key type
+ * @tparam idx_t  Matrix index type
+ * @param[in]  in     Input matrix (DxN column-major or NxD row-major)
+ * @param[in]  D      Dimension of the axis to reduce along
+ * @param[in]  N      Number of reductions
+ * @param[out] out    Output keys (N)
+ * @param[in]  stream CUDA stream
+ */
+template <typename RedOp, typename math_t, typename out_t, typename idx_t>
+inline void argReduce(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
{
-  int D = n_rows;
-  int N = n_cols;
  if (D <= 32) {
-    argmaxKernel<math_t, idx_t, 32><<<N, 32, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 32><<<N, 32, 0, stream>>>(in, D, N, out);
  } else if (D <= 64) {
-    argmaxKernel<math_t, idx_t, 64><<<N, 64, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 64><<<N, 64, 0, stream>>>(in, D, N, out);
  } else if (D <= 128) {
-    argmaxKernel<math_t, idx_t, 128><<<N, 128, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 128><<<N, 128, 0, stream>>>(in, D, N, out);
  } else {
-    argmaxKernel<math_t, idx_t, 256><<<N, 256, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 256><<<N, 256, 0, stream>>>(in, D, N, out);
  }
  RAFT_CUDA_TRY(cudaPeekAtLastError());
}

+template <typename math_t, typename out_t, typename idx_t>
+void argmin(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
+{
+  argReduce<cub::ArgMin>(in, D, N, out, stream);
+}
+
+template <typename math_t, typename out_t, typename idx_t>
+void argmax(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
+{
+  argReduce<cub::ArgMax>(in, D, N, out, stream);
+}

// Utility kernel needed for signFlip.
// Computes the argmax(abs(d_in)) column-wise in a DxN matrix followed by
// flipping the sign if the |max| value for each column is negative.
cpp/include/raft/matrix/math.cuh: 17 additions & 3 deletions
@@ -302,16 +302,30 @@ void ratio(

/** @} */

+/**
+ * @brief Argmin: find the row idx with minimum value for each column
+ * @param in: input matrix (column-major)
+ * @param n_rows: number of rows of input matrix
+ * @param n_cols: number of columns of input matrix
+ * @param out: output vector of size n_cols
+ * @param stream: cuda stream
+ */
+template <typename math_t, typename out_t, typename idx_t = int>
+void argmin(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream)
+{
+  detail::argmin(in, n_rows, n_cols, out, stream);
+}
+
/**
* @brief Argmax: find the row idx with maximum value for each column
- * @param in: input matrix
+ * @param in: input matrix (column-major)
* @param n_rows: number of rows of input matrix
* @param n_cols: number of columns of input matrix
* @param out: output vector of size n_cols
* @param stream: cuda stream
*/
-template <typename math_t, typename idx_t>
-void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream)
+template <typename math_t, typename out_t, typename idx_t = int>
+void argmax(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream)
{
detail::argmax(in, n_rows, n_cols, out, stream);
}
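Note the orientation difference between the two APIs: these pointer-based overloads reduce along the columns of a column-major matrix, while the mdspan-based `raft::matrix::argmin` above reduces along the rows of a row-major matrix; both lower to the same contiguous-stretch reduction in `detail`. A small sketch with hypothetical values (`raft::copy` from `<raft/util/cudart_utils.hpp>` and the sizes are assumptions for illustration):

raft::handle_t handle;
auto stream = handle.get_stream();

// Column-major 3x2 matrix: col 0 = {1, 0, 2}, col 1 = {5, 4, 6},
// stored contiguously as {1, 0, 2, 5, 4, 6}.
const float h_in[6] = {1.f, 0.f, 2.f, 5.f, 4.f, 6.f};
rmm::device_uvector<float> d_in(6, stream);
rmm::device_uvector<int> d_out(2, stream);
raft::copy(d_in.data(), h_in, 6, stream);

// Row index of the minimum of each of the 2 columns: expect {1, 1}.
raft::matrix::argmin(d_in.data(), 3, 2, d_out.data(), stream);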
cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh: 6 additions & 2 deletions
@@ -33,6 +33,7 @@
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/unary_op.cuh>
+#include <raft/matrix/argmin.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/util/cuda_utils.cuh>

@@ -147,8 +148,11 @@ inline void predict_float_core(const handle_t& handle,
distances.data(),
n_clusters,
stream);
-      utils::argmin_along_rows(
-        n_rows, static_cast<IdxT>(n_clusters), distances.data(), labels, stream);
+
+      auto distances_const_view = raft::make_device_matrix_view<const float, IdxT, row_major>(
+        distances.data(), n_rows, static_cast<IdxT>(n_clusters));
+      auto labels_view = raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows);
+      raft::matrix::argmin(handle, distances_const_view, labels_view);
break;
}
default: {
cpp/include/raft/spatial/knn/detail/ann_utils.cuh: 0 additions & 60 deletions
@@ -151,66 +151,6 @@ inline void memzero(T* ptr, IdxT n_elems, rmm::cuda_stream_view stream)
}
}

template <typename IdxT, typename OutT>
__global__ void argmin_along_rows_kernel(IdxT n_rows, uint32_t n_cols, const float* a, OutT* out)
{
  __shared__ OutT shm_ids[1024];    // NOLINT
  __shared__ float shm_vals[1024];  // NOLINT
  IdxT i = blockIdx.x;
  if (i >= n_rows) return;
  OutT min_idx  = n_cols;
  float min_val = raft::upper_bound<float>();
  for (OutT j = threadIdx.x; j < n_cols; j += blockDim.x) {
    if (min_val > a[j + n_cols * i]) {
      min_val = a[j + n_cols * i];
      min_idx = j;
    }
  }
  shm_vals[threadIdx.x] = min_val;
  shm_ids[threadIdx.x]  = min_idx;
  __syncthreads();
  for (IdxT offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (threadIdx.x < offset) {
      if (shm_vals[threadIdx.x] < shm_vals[threadIdx.x + offset]) {
      } else if (shm_vals[threadIdx.x] > shm_vals[threadIdx.x + offset]) {
        shm_vals[threadIdx.x] = shm_vals[threadIdx.x + offset];
        shm_ids[threadIdx.x]  = shm_ids[threadIdx.x + offset];
      } else if (shm_ids[threadIdx.x] > shm_ids[threadIdx.x + offset]) {
        shm_ids[threadIdx.x] = shm_ids[threadIdx.x + offset];
      }
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) { out[i] = shm_ids[0]; }
}

/**
 * @brief Find index of the smallest element in each row.
 *
 * NB: device-only function
 * TODO: specialize select_k for the case of `k == 1` and use that one instead.
 *
 * @tparam IdxT index type
 * @tparam OutT output type
 *
 * @param n_rows
 * @param n_cols
 * @param[in] a device pointer to the row-major matrix [n_rows, n_cols]
 * @param[out] out device pointer to the vector of selected indices [n_rows]
 * @param stream
 */
template <typename IdxT, typename OutT>
inline void argmin_along_rows(
  IdxT n_rows, IdxT n_cols, const float* a, OutT* out, rmm::cuda_stream_view stream)
{
  IdxT block_dim = 1024;
  while (block_dim > n_cols) {
    block_dim /= 2;
  }
  block_dim = max(block_dim, (IdxT)128);
  argmin_along_rows_kernel<IdxT, OutT><<<n_rows, block_dim, 0, stream>>>(n_rows, n_cols, a, out);
}

template <typename IdxT>
__global__ void dots_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, float* out)
{
cpp/test/CMakeLists.txt: 1 addition & 0 deletions
@@ -167,6 +167,7 @@ if(BUILD_TESTS)
ConfigureTest(NAME MATRIX_TEST
PATH
test/matrix/argmax.cu
+    test/matrix/argmin.cu
test/matrix/columnSort.cu
test/matrix/diagonal.cu
test/matrix/gather.cu
cpp/test/matrix/argmax.cu: 1 addition & 1 deletion
@@ -29,8 +29,8 @@ template <typename T, typename IdxT>
struct ArgMaxInputs {
std::vector<T> input_matrix;
std::vector<IdxT> output_matrix;
-  std::size_t n_cols;
  std::size_t n_rows;
+  std::size_t n_cols;
};

template <typename T, typename IdxT>
