Shared-memory-cached kernel for reduce_cols_by_key to limit atomic conflicts #1050

Merged · 4 commits · Dec 8, 2022
Changes from 1 commit
1 change: 1 addition & 0 deletions cpp/bench/CMakeLists.txt
@@ -96,6 +96,7 @@ if(BUILD_BENCH)
bench/linalg/matrix_vector_op.cu
bench/linalg/norm.cu
bench/linalg/normalize.cu
bench/linalg/reduce_cols_by_key.cu
bench/linalg/reduce_rows_by_key.cu
bench/linalg/reduce.cu
bench/main.cpp
78 changes: 78 additions & 0 deletions cpp/bench/linalg/reduce_cols_by_key.cu
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/linalg/reduce_cols_by_key.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

namespace raft::bench::linalg {

template <typename IdxT>
struct rcbk_params {
IdxT rows, cols;
IdxT keys;
};

template <typename IdxT>
inline auto operator<<(std::ostream& os, const rcbk_params<IdxT>& p) -> std::ostream&
{
os << p.rows << "#" << p.cols << "#" << p.keys;
return os;
}

template <typename T, typename KeyT, typename IdxT>
struct reduce_cols_by_key : public fixture {
reduce_cols_by_key(const rcbk_params<IdxT>& p)
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.keys, stream), keys(p.cols, stream)
{
raft::random::RngState rng{42};
raft::random::uniformInt(rng, keys.data(), p.cols, (KeyT)0, (KeyT)p.keys, stream);
}

void run_benchmark(::benchmark::State& state) override
{
std::ostringstream label_stream;
label_stream << params;
state.SetLabel(label_stream.str());

loop_on_state(state, [this]() {
raft::linalg::reduce_cols_by_key(
in.data(), keys.data(), out.data(), params.rows, params.cols, params.keys, stream, false);
});
}

protected:
rcbk_params<IdxT> params;
rmm::device_uvector<T> in, out;
rmm::device_uvector<KeyT> keys;
}; // struct reduce_cols_by_key

const std::vector<rcbk_params<int>> rcbk_inputs_i32 =
raft::util::itertools::product<rcbk_params<int>>(
{1, 10, 100, 1000}, {1000, 10000, 100000}, {8, 32, 128, 512, 2048});
const std::vector<rcbk_params<int64_t>> rcbk_inputs_i64 =
raft::util::itertools::product<rcbk_params<int64_t>>(
{1, 10, 100, 1000}, {1000, 10000, 100000}, {8, 32, 128, 512, 2048});

RAFT_BENCH_REGISTER((reduce_cols_by_key<float, uint32_t, int>), "", rcbk_inputs_i32);
RAFT_BENCH_REGISTER((reduce_cols_by_key<double, uint32_t, int>), "", rcbk_inputs_i32);
RAFT_BENCH_REGISTER((reduce_cols_by_key<float, uint32_t, int64_t>), "", rcbk_inputs_i64);
RAFT_BENCH_REGISTER((reduce_cols_by_key<double, uint32_t, int64_t>), "", rcbk_inputs_i64);

} // namespace raft::bench::linalg
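For reference (not produced by the diff): each raft::util::itertools::product call above expands to the Cartesian product of its argument lists, so every RAFT_BENCH_REGISTER line runs over 4 × 3 × 5 = 60 (rows, cols, keys) combinations, roughly:

// Illustration only: the generated cases for rcbk_inputs_i32 range over
//   rcbk_params<int>{1, 1000, 8}, rcbk_params<int>{1, 1000, 32}, ..., rcbk_params<int>{1000, 100000, 2048}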
1 change: 0 additions & 1 deletion cpp/include/raft/cluster/detail/kmeans_common.cuh
@@ -37,7 +37,6 @@
#include <raft/distance/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/linalg/reduce_cols_by_key.cuh>
#include <raft/linalg/reduce_rows_by_key.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/gather.cuh>
67 changes: 60 additions & 7 deletions cpp/include/raft/linalg/detail/reduce_cols_by_key.cuh
@@ -29,12 +29,12 @@ namespace detail {
///@todo: specialize this to support shared-mem based atomics

template <typename T, typename KeyIteratorT, typename IdxType>
__global__ void reduce_cols_by_key_kernel(
__global__ void reduce_cols_by_key_direct_kernel(
const T* data, const KeyIteratorT keys, T* out, IdxType nrows, IdxType ncols, IdxType nkeys)
{
typedef typename std::iterator_traits<KeyIteratorT>::value_type KeyType;

IdxType idx = blockIdx.x * blockDim.x + threadIdx.x;
IdxType idx = static_cast<IdxType>(blockIdx.x) * blockDim.x + threadIdx.x;
if (idx >= (nrows * ncols)) return;
///@todo: yikes! use fast-int-div
IdxType colId = idx % ncols;
@@ -43,6 +43,39 @@ __global__ void reduce_cols_by_key_kernel(
raft::myAtomicAdd(out + rowId * nkeys + key, data[idx]);
}

template <typename T, typename KeyIteratorT, typename IdxType>
__global__ void reduce_cols_by_key_cached_kernel(
const T* data, const KeyIteratorT keys, T* out, IdxType nrows, IdxType ncols, IdxType nkeys)
{
typedef typename std::iterator_traits<KeyIteratorT>::value_type KeyType;
extern __shared__ char smem[];
T* out_cache = reinterpret_cast<T*>(smem);

// Initialize the shared memory accumulators to 0.
for (IdxType idx = threadIdx.x; idx < nrows * nkeys; idx += blockDim.x) {
out_cache[idx] = T{0};
}
__syncthreads();

// Accumulate in shared memory
for (IdxType idx = static_cast<IdxType>(blockIdx.x) * blockDim.x + threadIdx.x;
idx < nrows * ncols;
idx += blockDim.x * static_cast<IdxType>(gridDim.x)) {
IdxType colId = idx % ncols;
IdxType rowId = idx / ncols;
KeyType key = keys[colId];
raft::myAtomicAdd(out_cache + rowId * nkeys + key, data[idx]);
}

// Add the shared-memory accumulators to the global results.
__syncthreads();
for (IdxType idx = threadIdx.x; idx < nrows * nkeys; idx += blockDim.x) {
T val = out_cache[idx];
if (val != T{0}) { raft::myAtomicAdd(out + idx, val); }
}
}

/**
* @brief Computes the sum-reduction of matrix columns for each given key
* @tparam T the input data type (as well as the output reduced matrix)
@@ -60,6 +93,7 @@ __global__ void reduce_cols_by_key_kernel(
* @param ncols number of columns in the input data
* @param nkeys number of unique keys in the keys array
* @param stream cuda stream to launch the kernel onto
* @param reset_sums Whether to reset the output sums to zero before reducing
*/
template <typename T, typename KeyIteratorT, typename IdxType = int>
void reduce_cols_by_key(const T* data,
@@ -68,16 +102,35 @@ void reduce_cols_by_key(const T* data,
IdxType nrows,
IdxType ncols,
IdxType nkeys,
cudaStream_t stream)
cudaStream_t stream,
bool reset_sums)
{
typedef typename std::iterator_traits<KeyIteratorT>::value_type KeyType;

RAFT_CUDA_TRY(cudaMemsetAsync(out, 0, sizeof(T) * nrows * nkeys, stream));
constexpr int TPB = 256;
int nblks = (int)raft::ceildiv<IdxType>(nrows * ncols, TPB);
reduce_cols_by_key_kernel<<<nblks, TPB, 0, stream>>>(data, keys, out, nrows, ncols, nkeys);
// Memset the output to zero to use atomics-based reduction.
if (reset_sums) { RAFT_CUDA_TRY(cudaMemsetAsync(out, 0, sizeof(T) * nrows * nkeys, stream)); }

// The cached version is used when the cache fits in shared memory and the number of input
// elements is above a threshold (the cached version is slightly slower for small input arrays,
// and orders of magnitude faster for large input arrays).
size_t cache_size = static_cast<size_t>(nrows) * static_cast<size_t>(nkeys) * sizeof(T);
if (cache_size <= 48000ull && nrows * ncols >= IdxType{8192}) {
constexpr int TPB = 256;
int n_sm = raft::getMultiProcessorCount();
int target_nblks = 4 * n_sm;
int max_nblks = raft::ceildiv<IdxType>(nrows * ncols, TPB);
int nblks = std::min(target_nblks, max_nblks);
reduce_cols_by_key_cached_kernel<<<nblks, TPB, cache_size, stream>>>(
data, keys, out, nrows, ncols, nkeys);
} else {
constexpr int TPB = 256;
int nblks = raft::ceildiv<IdxType>(nrows * ncols, TPB);
reduce_cols_by_key_direct_kernel<<<nblks, TPB, 0, stream>>>(
data, keys, out, nrows, ncols, nkeys);
}
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

}; // end namespace detail
}; // end namespace linalg
}; // end namespace raft
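As a reading aid, a minimal host-side sketch of the selection logic above (the helper name is made up and is not part of the diff), with two worked examples in the comments:

// Hypothetical helper mirroring the dispatch heuristic in detail::reduce_cols_by_key.
// e.g. T=float,  nrows=10,   nkeys=512 -> 10*512*4   = 20,480 B    <= 48,000 B -> cached kernel
//      (provided nrows*ncols >= 8192);
//      T=double, nrows=1000, nkeys=128 -> 1000*128*8 = 1,024,000 B >  48,000 B -> direct kernel
template <typename T, typename IdxType>
bool would_use_cached_kernel(IdxType nrows, IdxType ncols, IdxType nkeys)
{
  size_t cache_size = static_cast<size_t>(nrows) * static_cast<size_t>(nkeys) * sizeof(T);
  return cache_size <= 48000ull && nrows * ncols >= IdxType{8192};
}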
13 changes: 9 additions & 4 deletions cpp/include/raft/linalg/reduce_cols_by_key.cuh
@@ -43,6 +43,7 @@ namespace linalg {
* @param ncols number of columns in the input data
* @param nkeys number of unique keys in the keys array
* @param stream cuda stream to launch the kernel onto
* @param reset_sums Whether to reset the output sums to zero before reducing
*/
template <typename T, typename KeyIteratorT, typename IdxType = int>
void reduce_cols_by_key(const T* data,
@@ -51,9 +52,10 @@ void reduce_cols_by_key(const T* data,
IdxType nrows,
IdxType ncols,
IdxType nkeys,
cudaStream_t stream)
cudaStream_t stream,
bool reset_sums = true)
{
detail::reduce_cols_by_key(data, keys, out, nrows, ncols, nkeys, stream);
detail::reduce_cols_by_key(data, keys, out, nrows, ncols, nkeys, stream, reset_sums);
}

/**
@@ -77,14 +79,16 @@ void reduce_cols_by_key(const T* data,
* @param[out] out the output reduced raft::device_matrix_view along columns (dim = nrows x nkeys).
* This will be assumed to be in row-major layout
* @param[in] nkeys number of unique keys in the keys array
* @param[in] reset_sums Whether to reset the output sums to zero before reducing
*/
template <typename ElementType, typename KeyType = ElementType, typename IndexType = std::uint32_t>
void reduce_cols_by_key(
const raft::handle_t& handle,
raft::device_matrix_view<const ElementType, IndexType, raft::row_major> data,
raft::device_vector_view<const KeyType, IndexType> keys,
raft::device_matrix_view<ElementType, IndexType, raft::row_major> out,
IndexType nkeys)
IndexType nkeys,
Member:
Not part of this PR, but I wonder if there could someday be an overload of this function that would compute this value when not specified.

Contributor Author:
Sure, but what kind of use cases do you have in mind? Since this primitive works with keys in the range [0, nkeys), it's fair to assume nkeys will typically be known. For a more generic primitive working with arbitrary keys, we would first need to map them to such a range, and we could compute nkeys in the process.

Contributor Author:
Unless you mean that instead of passing nkeys we should get it from the dimensions of out, which actually makes a lot of sense.

Member:
Yeah, I guess the latter option makes sense. I was considering doing something similar in the knn APIs, since the shape of the output matrices already tells us k.

Contributor Author:
I see the following options, let me know if you have a preference:

  • Give nkeys a default value and ignore it, using the output dims instead (might break compilation due to an unused-variable warning?)
  • Give nkeys a default value and, if not overridden by the user, get it from the output dims.
  • Remove the nkeys arg (potentially breaking, but I doubt anyone is using this prim outside the raft codebase?)

Contributor:
The second option looks good to me.

Member:
@Nyrio do you want me to go ahead and merge this PR in the meantime?

Contributor Author:
Since we aren't in a hurry, I just did the change in this PR. I'd rather put nkeys in last position so we can specify reset_sums and ignore nkeys, but in the interest of not breaking the API I can't do that. If we don't mind breaking the API, I can simply remove the arg and always infer.

bool reset_sums = true)
{
RAFT_EXPECTS(out.extent(0) == data.extent(0) && out.extent(1) == nkeys,
"Output is not of size nrows * nkeys");
@@ -96,7 +100,8 @@ void reduce_cols_by_key(
data.extent(0),
data.extent(1),
nkeys,
handle.get_stream());
handle.get_stream(),
reset_sums);
}
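A hypothetical sketch of the second option from the review thread above (not part of this PR; the sentinel default of 0 and the inference from out are illustrative only):

// Hypothetical: nkeys defaults to 0 and, when left at the default, is inferred from the
// output view, so callers could write reduce_cols_by_key(handle, data, keys, out).
template <typename ElementType, typename KeyType = ElementType, typename IndexType = std::uint32_t>
void reduce_cols_by_key(
  const raft::handle_t& handle,
  raft::device_matrix_view<const ElementType, IndexType, raft::row_major> data,
  raft::device_vector_view<const KeyType, IndexType> keys,
  raft::device_matrix_view<ElementType, IndexType, raft::row_major> out,
  IndexType nkeys = 0,
  bool reset_sums = true)
{
  if (nkeys == 0) { nkeys = out.extent(1); }  // infer nkeys from the output dims when not specified
  detail::reduce_cols_by_key(data.data_handle(),
                             keys.data_handle(),
                             out.data_handle(),
                             data.extent(0),
                             data.extent(1),
                             nkeys,
                             handle.get_stream(),
                             reset_sums);
}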

/** @} */ // end of group reduce_cols_by_key
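To round out the picture, a minimal usage sketch of the updated mdspan overload (sizes are illustrative, the buffers are assumed to be filled elsewhere, and keys must lie in [0, nkeys)):

// Usage sketch only, not part of the diff.
raft::handle_t handle;
int nrows = 100, ncols = 1000, nkeys = 32;
rmm::device_uvector<float> data(nrows * ncols, handle.get_stream());
rmm::device_uvector<uint32_t> keys(ncols, handle.get_stream());
rmm::device_uvector<float> out(nrows * nkeys, handle.get_stream());
// ... fill data and keys (each key in [0, nkeys)) ...

auto data_v = raft::make_device_matrix_view<const float, int>(data.data(), nrows, ncols);
auto keys_v = raft::make_device_vector_view<const uint32_t, int>(keys.data(), ncols);
auto out_v  = raft::make_device_matrix_view<float, int>(out.data(), nrows, nkeys);

// reset_sums defaults to true: out is zeroed before the reduction.
raft::linalg::reduce_cols_by_key(handle, data_v, keys_v, out_v, nkeys);
// Pass reset_sums = false to accumulate into an already-initialized out instead.
raft::linalg::reduce_cols_by_key(handle, data_v, keys_v, out_v, nkeys, false);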