Create cub-based argmin primitive and replace argmin_along_rows in ANN kmeans #912

Merged: 6 commits, Nov 9, 2022
2 changes: 1 addition & 1 deletion build.sh
@@ -73,7 +73,7 @@ COMPILE_DIST_LIBRARY=OFF
ENABLE_NN_DEPENDENCIES=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;SPARSE_BENCH;RANDOM_BENCH"
BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"
ENABLE_thrust_DEPENDENCY=ON

CACHE_ARGS=""
6 changes: 6 additions & 0 deletions cpp/bench/CMakeLists.txt
@@ -101,6 +101,12 @@ if(BUILD_BENCH)
bench/main.cpp
)

+ConfigureBench(NAME MATRIX_BENCH
+  PATH
+  bench/matrix/argmin.cu
+  bench/main.cpp
+)

ConfigureBench(NAME RANDOM_BENCH
PATH
bench/random/make_blobs.cu
71 changes: 71 additions & 0 deletions cpp/bench/matrix/argmin.cu
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/matrix/argmin.cuh>
#include <raft/random/rng.cuh>

#include <rmm/device_uvector.hpp>

namespace raft::bench::linalg {

template <typename IdxT>
struct ArgminParams {
IdxT rows, cols;
};

template <typename T, typename OutT, typename IdxT>
struct Argmin : public fixture {
Argmin(const ArgminParams<IdxT>& p) : params(p) {}

void allocate_data(const ::benchmark::State& state) override
{
matrix = raft::make_device_matrix<T, IdxT>(handle, params.rows, params.cols);
indices = raft::make_device_vector<OutT, IdxT>(handle, params.rows);

raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
handle.sync_stream(stream);
}

void run_benchmark(::benchmark::State& state) override
{
loop_on_state(state, [this]() {
auto matrix_const_view = raft::make_device_matrix_view<const T, IdxT, row_major>(
matrix.data_handle(), matrix.extent(0), matrix.extent(1));
raft::matrix::argmin(handle, matrix_const_view, indices.view());
});
}

private:
ArgminParams<IdxT> params;
raft::device_matrix<T, IdxT> matrix;
raft::device_vector<OutT, IdxT> indices;
}; // struct Argmin

const std::vector<ArgminParams<int64_t>> argmin_inputs_i64{
{1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024},
{10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024},
{100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024},
{1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024},
{10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024},
};

RAFT_BENCH_REGISTER((Argmin<float, uint32_t, int64_t>), "", argmin_inputs_i64);
RAFT_BENCH_REGISTER((Argmin<double, uint32_t, int64_t>), "", argmin_inputs_i64);

} // namespace raft::bench::linalg
6 changes: 3 additions & 3 deletions cpp/include/raft/matrix/argmax.cuh
@@ -22,10 +22,10 @@
namespace raft::matrix {

/**
- * @brief Argmax: find the row idx with maximum value for each column
+ * @brief Argmax: find the col idx with maximum value for each row
* @param[in] handle: raft handle
* @param[in] in: input matrix of size (n_rows, n_cols)
- * @param[out] out: output vector of size n_cols
+ * @param[out] out: output vector of size n_rows
*/
template <typename math_t, typename idx_t, typename matrix_idx_t>
void argmax(const raft::handle_t& handle,
@@ -35,6 +35,6 @@ void argmax(const raft::handle_t& handle,
RAFT_EXPECTS(out.extent(0) == in.extent(0),
"Size of output vector must equal number of rows in input matrix.");
detail::argmax(
-    in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream());
+    in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream());
}
} // namespace raft::matrix
40 changes: 40 additions & 0 deletions cpp/include/raft/matrix/argmin.cuh
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/matrix/detail/math.cuh>

namespace raft::matrix {

/**
* @brief Argmin: find the col idx with minimum value for each row
* @param[in] handle: raft handle
* @param[in] in: input matrix of size (n_rows, n_cols)
* @param[out] out: output vector of size n_rows
*/
template <typename math_t, typename idx_t, typename matrix_idx_t>
void argmin(const raft::handle_t& handle,
raft::device_matrix_view<const math_t, matrix_idx_t, row_major> in,
raft::device_vector_view<idx_t, matrix_idx_t> out)
{
RAFT_EXPECTS(out.extent(0) == in.extent(0),
"Size of output vector must equal number of rows in input matrix.");
detail::argmin(
in.data_handle(), in.extent(1), in.extent(0), out.data_handle(), handle.get_stream());
}
} // namespace raft::matrix
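
For orientation, the new mdspan-based argmin mirrors the existing argmax entry point: given a row-major matrix view, it writes one column index per row. Below is a minimal usage sketch adapted from the benchmark above; the shapes, value types, and the helper name argmin_example are illustrative, not part of the diff.

// --- usage sketch (illustrative, not from this PR) ---
#include <raft/core/device_mdarray.hpp>
#include <raft/core/handle.hpp>
#include <raft/matrix/argmin.cuh>

void argmin_example(const raft::handle_t& handle)
{
  int64_t n_rows = 1000, n_cols = 64;
  // Allocate a row-major matrix and one output slot per row.
  auto matrix  = raft::make_device_matrix<float, int64_t>(handle, n_rows, n_cols);
  auto indices = raft::make_device_vector<uint32_t, int64_t>(handle, n_rows);
  // ... fill matrix on the device ...
  // indices(i) = argmin over j of matrix(i, j)
  auto const_view = raft::make_device_matrix_view<const float, int64_t, raft::row_major>(
    matrix.data_handle(), matrix.extent(0), matrix.extent(1));
  raft::matrix::argmin(handle, const_view, indices.view());
}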
66 changes: 45 additions & 21 deletions cpp/include/raft/matrix/detail/math.cuh
@@ -362,45 +362,69 @@ void matrixVectorBinarySub(Type* data,
stream);
}

-// Computes the argmax(d_in) column-wise in a DxN matrix
-template <typename T, typename IdxT, int TPB>
-__global__ void argmaxKernel(const T* d_in, int D, int N, IdxT* argmax)
+// Computes an argmin/argmax column-wise in a DxN matrix
+template <typename RedOp, int TPB, typename T, typename OutT, typename IdxT>
+__global__ void argReduceKernel(const T* d_in, IdxT D, IdxT N, OutT* out)
{
-  typedef cub::BlockReduce<cub::KeyValuePair<int, T>, TPB> BlockReduce;
+  typedef cub::
+    BlockReduce<cub::KeyValuePair<IdxT, T>, TPB, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY>
+      BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;

-  // compute maxIndex=argMax index for column
-  using KVP = cub::KeyValuePair<int, T>;
-  int rowStart = blockIdx.x * D;
-  KVP thread_data(-1, -raft::myInf<T>());
+  using KVP = cub::KeyValuePair<IdxT, T>;
+  IdxT rowStart = static_cast<IdxT>(blockIdx.x) * D;
+  KVP thread_data(0, std::is_same_v<RedOp, cub::ArgMax> ? -raft::myInf<T>() : raft::myInf<T>());

-  for (int i = threadIdx.x; i < D; i += TPB) {
-    int idx = rowStart + i;
-    thread_data = cub::ArgMax()(thread_data, KVP(i, d_in[idx]));
+  for (IdxT i = threadIdx.x; i < D; i += TPB) {
+    IdxT idx    = rowStart + i;
+    thread_data = RedOp()(thread_data, KVP(i, d_in[idx]));
}

-  auto maxKV = BlockReduce(temp_storage).Reduce(thread_data, cub::ArgMax());
+  auto maxKV = BlockReduce(temp_storage).Reduce(thread_data, RedOp());

-  if (threadIdx.x == 0) { argmax[blockIdx.x] = maxKV.key; }
+  if (threadIdx.x == 0) { out[blockIdx.x] = maxKV.key; }
}

-template <typename math_t, typename idx_t>
-void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream)
+/**
+ * @brief Computes an argmin/argmax coalesced reduction
+ *
+ * @tparam RedOp Reduction operation (cub::ArgMin or cub::ArgMax)
+ * @tparam math_t Value type
+ * @tparam out_t Output key type
+ * @tparam idx_t Matrix index type
+ * @param[in] in Input matrix (DxN column-major or NxD row-major)
+ * @param[in] D Dimension of the axis to reduce along
+ * @param[in] N Number of reductions
+ * @param[out] out Output keys (N)
+ * @param[in] stream CUDA stream
+ */
+template <typename RedOp, typename math_t, typename out_t, typename idx_t>
+inline void argReduce(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
{
-  int D = n_rows;
-  int N = n_cols;
if (D <= 32) {
-    argmaxKernel<math_t, idx_t, 32><<<N, 32, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 32><<<N, 32, 0, stream>>>(in, D, N, out);
} else if (D <= 64) {
-    argmaxKernel<math_t, idx_t, 64><<<N, 64, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 64><<<N, 64, 0, stream>>>(in, D, N, out);
} else if (D <= 128) {
-    argmaxKernel<math_t, idx_t, 128><<<N, 128, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 128><<<N, 128, 0, stream>>>(in, D, N, out);
} else {
-    argmaxKernel<math_t, idx_t, 256><<<N, 256, 0, stream>>>(in, D, N, out);
+    argReduceKernel<RedOp, 256><<<N, 256, 0, stream>>>(in, D, N, out);
}
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

+template <typename math_t, typename out_t, typename idx_t>
+void argmin(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
+{
+  argReduce<cub::ArgMin>(in, D, N, out, stream);
+}
+
+template <typename math_t, typename out_t, typename idx_t>
+void argmax(const math_t* in, idx_t D, idx_t N, out_t* out, cudaStream_t stream)
+{
+  argReduce<cub::ArgMax>(in, D, N, out, stream);
+}

// Utility kernel needed for signFlip.
// Computes the argmax(abs(d_in)) column-wise in a DxN matrix followed by
// flipping the sign if the |max| value for each column is negative.
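A note on the reduction operator used above: cub::ArgMin and cub::ArgMax are plain binary functors over cub::KeyValuePair, which is what lets a single templated kernel serve both reductions. The sketch below illustrates their semantics on the host (CUB's thread operators are marked host/device); it is a standalone illustration, not code from this PR.

// --- illustrative sketch, not from this PR ---
#include <cub/cub.cuh>
#include <cstdio>

int main()
{
  // ArgMin keeps the pair with the smaller value; ties break toward the
  // smaller key. This is why argReduceKernel seeds thread_data with key 0
  // and +inf (ArgMin) or -inf (ArgMax) as the identity value.
  cub::KeyValuePair<int, float> a(3, 0.5f), b(7, 0.2f);
  auto m = cub::ArgMin()(a, b);
  std::printf("key=%d value=%f\n", m.key, m.value);  // prints key=7 value=0.200000
  return 0;
}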
20 changes: 17 additions & 3 deletions cpp/include/raft/matrix/math.cuh
@@ -302,16 +302,30 @@ void ratio(

/** @} */

+/**
+ * @brief Argmin: find the row idx with minimum value for each column
+ * @param in: input matrix (column-major)
+ * @param n_rows: number of rows of input matrix
+ * @param n_cols: number of columns of input matrix
+ * @param out: output vector of size n_cols
+ * @param stream: cuda stream
+ */
+template <typename math_t, typename out_t, typename idx_t = int>
+void argmin(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream)
+{
+  detail::argmin(in, n_rows, n_cols, out, stream);
+}

/**
* @brief Argmax: find the row idx with maximum value for each column
- * @param in: input matrix
+ * @param in: input matrix (column-major)
* @param n_rows: number of rows of input matrix
* @param n_cols: number of columns of input matrix
* @param out: output vector of size n_cols
* @param stream: cuda stream
*/
-template <typename math_t, typename idx_t>
-void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream)
+template <typename math_t, typename out_t, typename idx_t = int>
+void argmax(const math_t* in, idx_t n_rows, idx_t n_cols, out_t* out, cudaStream_t stream)
{
detail::argmax(in, n_rows, n_cols, out, stream);
}
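The legacy raw-pointer overloads keep their column-major semantics: one result per column, written to an output vector of length n_cols. A minimal sketch against the updated signatures follows; the rmm buffers and the helper name column_argmin are assumptions for illustration, and the data fill is elided.

// --- usage sketch (illustrative, not from this PR) ---
#include <raft/core/handle.hpp>
#include <raft/matrix/math.cuh>
#include <rmm/device_uvector.hpp>

void column_argmin(const raft::handle_t& handle, int n_rows, int n_cols)
{
  auto stream = handle.get_stream();
  rmm::device_uvector<float> in(static_cast<std::size_t>(n_rows) * n_cols, stream);
  rmm::device_uvector<int> out(n_cols, stream);  // one index per column
  // ... fill in with column-major data ...
  // out[j] = argmin over i of in[i + j * n_rows]
  raft::matrix::argmin(in.data(), n_rows, n_cols, out.data(), stream);
}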
8 changes: 6 additions & 2 deletions cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh
@@ -33,6 +33,7 @@
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/unary_op.cuh>
+#include <raft/matrix/argmin.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/util/cuda_utils.cuh>

@@ -147,8 +148,11 @@ inline void predict_float_core(const handle_t& handle,
distances.data(),
n_clusters,
stream);
-  utils::argmin_along_rows(
-    n_rows, static_cast<IdxT>(n_clusters), distances.data(), labels, stream);
+
+  auto distances_const_view = raft::make_device_matrix_view<const float, IdxT, row_major>(
+    distances.data(), n_rows, static_cast<IdxT>(n_clusters));
+  auto labels_view = raft::make_device_vector_view<LabelT, IdxT>(labels, n_rows);
+  raft::matrix::argmin(handle, distances_const_view, labels_view);
break;
}
default: {
60 changes: 0 additions & 60 deletions cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -151,66 +151,6 @@ inline void memzero(T* ptr, IdxT n_elems, rmm::cuda_stream_view stream)
}
}

template <typename IdxT, typename OutT>
__global__ void argmin_along_rows_kernel(IdxT n_rows, uint32_t n_cols, const float* a, OutT* out)
{
__shared__ OutT shm_ids[1024]; // NOLINT
__shared__ float shm_vals[1024]; // NOLINT
IdxT i = blockIdx.x;
if (i >= n_rows) return;
OutT min_idx = n_cols;
float min_val = raft::upper_bound<float>();
for (OutT j = threadIdx.x; j < n_cols; j += blockDim.x) {
if (min_val > a[j + n_cols * i]) {
min_val = a[j + n_cols * i];
min_idx = j;
}
}
shm_vals[threadIdx.x] = min_val;
shm_ids[threadIdx.x] = min_idx;
__syncthreads();
for (IdxT offset = blockDim.x / 2; offset > 0; offset >>= 1) {
if (threadIdx.x < offset) {
if (shm_vals[threadIdx.x] < shm_vals[threadIdx.x + offset]) {
} else if (shm_vals[threadIdx.x] > shm_vals[threadIdx.x + offset]) {
shm_vals[threadIdx.x] = shm_vals[threadIdx.x + offset];
shm_ids[threadIdx.x] = shm_ids[threadIdx.x + offset];
} else if (shm_ids[threadIdx.x] > shm_ids[threadIdx.x + offset]) {
shm_ids[threadIdx.x] = shm_ids[threadIdx.x + offset];
}
}
__syncthreads();
}
if (threadIdx.x == 0) { out[i] = shm_ids[0]; }
}

/**
* @brief Find index of the smallest element in each row.
*
* NB: device-only function
* TODO: specialize select_k for the case of `k == 1` and use that one instead.
*
* @tparam IdxT index type
* @tparam OutT output type
*
* @param n_rows
* @param n_cols
* @param[in] a device pointer to the row-major matrix [n_rows, n_cols]
* @param[out] out device pointer to the vector of selected indices [n_rows]
* @param stream
*/
template <typename IdxT, typename OutT>
inline void argmin_along_rows(
IdxT n_rows, IdxT n_cols, const float* a, OutT* out, rmm::cuda_stream_view stream)
{
IdxT block_dim = 1024;
while (block_dim > n_cols) {
block_dim /= 2;
}
block_dim = max(block_dim, (IdxT)128);
argmin_along_rows_kernel<IdxT, OutT><<<n_rows, block_dim, 0, stream>>>(n_rows, n_cols, a, out);
}

template <typename IdxT>
__global__ void dots_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, float* out)
{
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -162,6 +162,7 @@ if(BUILD_TESTS)
ConfigureTest(NAME MATRIX_TEST
PATH
test/matrix/argmax.cu
+test/matrix/argmin.cu
test/matrix/columnSort.cu
test/matrix/diagonal.cu
test/matrix/gather.cu
2 changes: 1 addition & 1 deletion cpp/test/matrix/argmax.cu
@@ -29,8 +29,8 @@ template <typename T, typename IdxT>
struct ArgMaxInputs {
std::vector<T> input_matrix;
std::vector<IdxT> output_matrix;
-  std::size_t n_cols;
  std::size_t n_rows;
+  std::size_t n_cols;
};

template <typename T, typename IdxT>