
Commit

Replace normalize_rows in ann_utils.cuh by a new rowNormalize prim and improve performance for thin matrices (small `n_cols`) (#979)

This follows up on a discussion at #652 (comment). The main goal of this PR is to make this helper accessible as a raft primitive.

I also took the opportunity to look at the performance of this primitive and improved it for:

- Thin matrices: fewer than 32 threads per row, using shuffle-based reductions.
- Thick matrices: a cub-based reduction processing one row per block.

Here is an overview of the before/after performance on an A100:

![2022-11-11_normalize_perf_float_int32](https://user-images.githubusercontent.com/17441062/201403965-bf68d368-b64b-4a1f-92f0-a5de03b9d1a8.png)
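
For reference, a minimal usage sketch of the new public primitive, mirroring the call made in the benchmark added in this PR (the wrapper name `l2_normalize_rows` and the raw device pointers are only illustrative; the handle and buffers are assumed to be set up by the caller):

```cpp
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/normalize.cuh>

// Sketch: L2-normalize every row of a row-major device matrix.
// `in_ptr` and `out_ptr` are assumed to point to device memory of size n_rows * n_cols.
void l2_normalize_rows(
  const raft::handle_t& handle, const float* in_ptr, float* out_ptr, int n_rows, int n_cols)
{
  auto input = raft::make_device_matrix_view<const float, int, raft::row_major>(
    in_ptr, n_rows, n_cols);
  auto output = raft::make_device_matrix_view<float, int, raft::row_major>(
    out_ptr, n_rows, n_cols);
  raft::linalg::row_normalize(handle, input, output, raft::linalg::L2Norm);
}
```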

Authors:
  - Louis Sugy (https://github.com/Nyrio)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: #979
Nyrio authored Nov 17, 2022
1 parent f755fd8 commit e14bcbd
Showing 12 changed files with 669 additions and 65 deletions.
1 change: 1 addition & 0 deletions cpp/bench/CMakeLists.txt
@@ -94,6 +94,7 @@ if(BUILD_BENCH)
bench/linalg/add.cu
bench/linalg/map_then_reduce.cu
bench/linalg/matrix_vector_op.cu
bench/linalg/normalize.cu
bench/linalg/reduce_rows_by_key.cu
bench/linalg/reduce.cu
bench/main.cpp
79 changes: 79 additions & 0 deletions cpp/bench/linalg/normalize.cu
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <common/benchmark.hpp>
#include <raft/linalg/normalize.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>

namespace raft::bench::linalg {

template <typename IdxT>
struct normalize_input {
IdxT rows, cols;
};

template <typename IdxT>
inline auto operator<<(std::ostream& os, const normalize_input<IdxT>& p) -> std::ostream&
{
os << p.rows << "#" << p.cols;
return os;
}

template <typename T, typename IdxT>
struct rowNormalize : public fixture {
rowNormalize(const normalize_input<IdxT>& p)
: params(p), in(p.rows * p.cols, stream), out(p.rows * p.cols, stream)
{
raft::random::RngState rng{1234};
raft::random::uniform(rng, in.data(), p.rows * p.cols, (T)-10.0, (T)10.0, stream);
}

void run_benchmark(::benchmark::State& state) override
{
std::ostringstream label_stream;
label_stream << params;
state.SetLabel(label_stream.str());

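// Each iteration re-creates the (non-owning) mdspan views and times only the row_normalize call.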
loop_on_state(state, [this]() {
auto input_view = raft::make_device_matrix_view<const T, IdxT, raft::row_major>(
in.data(), params.rows, params.cols);
auto output_view = raft::make_device_matrix_view<T, IdxT, raft::row_major>(
out.data(), params.rows, params.cols);
raft::linalg::row_normalize(handle, input_view, output_view, raft::linalg::L2Norm);
});
}

private:
normalize_input<IdxT> params;
rmm::device_uvector<T> in, out;
}; // struct rowNormalize

const std::vector<normalize_input<int>> normalize_inputs_i32 =
raft::util::itertools::product<normalize_input<int>>(
{10, 100, 1000, 10000, 100000}, {8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384});
const std::vector<normalize_input<int64_t>> normalize_inputs_i64 =
raft::util::itertools::product<normalize_input<int64_t>>(
{10, 100, 1000, 10000, 100000}, {8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384});

RAFT_BENCH_REGISTER((rowNormalize<float, int>), "", normalize_inputs_i32);
RAFT_BENCH_REGISTER((rowNormalize<double, int>), "", normalize_inputs_i32);
RAFT_BENCH_REGISTER((rowNormalize<float, int64_t>), "", normalize_inputs_i64);
RAFT_BENCH_REGISTER((rowNormalize<double, int64_t>), "", normalize_inputs_i64);

} // namespace raft::bench::linalg
40 changes: 33 additions & 7 deletions cpp/include/raft/linalg/detail/norm.cuh
@@ -16,15 +16,13 @@

#pragma once

#include <raft/linalg/norm_types.hpp>
#include <raft/linalg/reduce.cuh>

namespace raft {
namespace linalg {
namespace detail {

/** different types of norms supported on the input buffers */
enum NormType { L1Norm = 0, L2Norm };

template <typename Type, typename IdxType, typename Lambda>
void rowNormCaller(Type* dots,
const Type* data,
@@ -64,7 +62,21 @@ void rowNormCaller(Type* dots,
raft::Sum<Type>(),
fin_op);
break;
default: ASSERT(false, "Invalid norm type passed! [%d]", type);
case LinfNorm:
raft::linalg::reduce<Type, Type, IdxType>(dots,
data,
D,
N,
(Type)0,
rowMajor,
true,
stream,
false,
raft::L1Op<Type>(),
raft::Max<Type>(),
fin_op);
break;
default: THROW("Unsupported norm type: %d", type);
};
}

@@ -89,7 +101,7 @@ void colNormCaller(Type* dots,
false,
stream,
false,
raft::L1Op<Type, IdxType>(),
raft::L1Op<Type>(),
raft::Sum<Type>(),
fin_op);
break;
@@ -103,11 +115,25 @@
false,
stream,
false,
raft::L2Op<Type, IdxType>(),
raft::L2Op<Type>(),
raft::Sum<Type>(),
fin_op);
break;
default: ASSERT(false, "Invalid norm type passed! [%d]", type);
case LinfNorm:
raft::linalg::reduce<Type, Type, IdxType>(dots,
data,
D,
N,
(Type)0,
rowMajor,
false,
stream,
false,
raft::L1Op<Type>(),
raft::Max<Type>(),
fin_op);
break;
default: THROW("Unsupported norm type: %d", type);
};
}

187 changes: 187 additions & 0 deletions cpp/include/raft/linalg/detail/normalize.cuh
@@ -0,0 +1,187 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/util/cuda_utils.cuh>

namespace raft {
namespace linalg {
namespace detail {

template <int warpSize, int rpb>
struct NormalizeThinPolicy {
static constexpr int LogicalWarpSize = warpSize;
static constexpr int RowsPerBlock = rpb;
static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock;
};

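// Thin variant: each row is handled by a logical warp of Policy::LogicalWarpSize threads
// (Policy::RowsPerBlock rows per block), and the row-wise reduction is finished with warp
// shuffles via raft::logicalWarpReduce.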
template <typename Policy,
typename Type,
typename IdxType,
typename MainLambda,
typename ReduceLambda,
typename FinalLambda>
__global__ void __launch_bounds__(Policy::ThreadsPerBlock)
coalesced_normalize_thin_kernel(Type* out,
const Type* in,
IdxType D,
IdxType N,
Type init,
MainLambda main_op,
ReduceLambda reduce_op,
FinalLambda fin_op,
Type eps)
{
IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast<IdxType>(blockIdx.x));
if (i >= N) return;

Type acc = init;
for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
Type val = in[j + D * i];
acc = reduce_op(acc, main_op(val, j));
}
acc = raft::logicalWarpReduce<Policy::LogicalWarpSize>(acc, reduce_op);
acc = fin_op(acc);
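// Rows whose norm does not exceed eps are skipped, leaving the corresponding output rows
// untouched and avoiding a division by (nearly) zero.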
if (acc <= eps) return;
for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) {
out[j + D * i] = in[j + D * i] / acc;
}
}

template <typename Policy,
typename Type,
typename IdxType,
typename MainLambda,
typename ReduceLambda,
typename FinalLambda>
inline void coalesced_normalize_thin(Type* out,
const Type* in,
IdxType D,
IdxType N,
Type init,
cudaStream_t stream,
MainLambda main_op,
ReduceLambda reduce_op,
FinalLambda fin_op,
Type eps)
{
dim3 grid(ceildiv(N, (IdxType)Policy::RowsPerBlock), 1, 1);
dim3 block(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1);
coalesced_normalize_thin_kernel<Policy>
<<<grid, block, 0, stream>>>(out, in, D, N, init, main_op, reduce_op, fin_op, eps);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

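// Medium variant: one block per row; the row-wise reduction uses cub::BlockReduce and the
// final norm is broadcast through shared memory before the element-wise division.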
template <int TPB,
typename Type,
typename IdxType,
typename MainLambda,
typename ReduceLambda,
typename FinalLambda>
__global__ void __launch_bounds__(TPB) coalesced_normalize_medium_kernel(Type* out,
const Type* in,
IdxType D,
IdxType N,
Type init,
MainLambda main_op,
ReduceLambda reduce_op,
FinalLambda fin_op,
Type eps)
{
typedef cub::BlockReduce<Type, TPB, cub::BLOCK_REDUCE_RAKING> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ Type bcast_acc;
Type thread_data = init;
IdxType rowStart = blockIdx.x * D;
for (IdxType i = threadIdx.x; i < D; i += TPB) {
IdxType idx = rowStart + i;
thread_data = reduce_op(thread_data, main_op(in[idx], i));
}
Type acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op);
if (threadIdx.x == 0) { bcast_acc = fin_op(acc); }
__syncthreads();
if (bcast_acc <= eps) return;
for (IdxType i = threadIdx.x; i < D; i += TPB) {
IdxType idx = rowStart + i;
out[idx] = in[idx] / bcast_acc;
}
}

template <int TPB,
typename Type,
typename IdxType,
typename MainLambda,
typename ReduceLambda,
typename FinalLambda>
inline void coalesced_normalize_medium(Type* out,
const Type* in,
IdxType D,
IdxType N,
Type init,
cudaStream_t stream,
MainLambda main_op,
ReduceLambda reduce_op,
FinalLambda fin_op,
Type eps)
{
coalesced_normalize_medium_kernel<TPB>
<<<N, TPB, 0, stream>>>(out, in, D, N, init, main_op, reduce_op, fin_op, eps);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

template <typename Type,
typename IdxType,
typename MainLambda,
typename ReduceLambda,
typename FinalLambda>
void coalesced_normalize(Type* out,
const Type* in,
IdxType D,
IdxType N,
Type init,
cudaStream_t stream,
MainLambda main_op,
ReduceLambda reduce_op,
FinalLambda fin_op,
Type eps)
{
const IdxType numSMs = raft::getMultiProcessorCount();
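// Heuristic: matrices up to 256 columns (or up to 512 columns when there are enough rows to
// fill the GPU) use the thin kernel with a logical warp size matched to the row width;
// otherwise, fall back to the block-per-row medium kernel with 256 threads per block.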
if (D <= IdxType(256) || (D <= IdxType(512) && N >= 4 * numSMs)) {
if (D <= IdxType(2)) {
coalesced_normalize_thin<NormalizeThinPolicy<2, 64>>(
out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
} else if (D <= IdxType(4)) {
coalesced_normalize_thin<NormalizeThinPolicy<4, 32>>(
out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
} else if (D <= IdxType(8)) {
coalesced_normalize_thin<NormalizeThinPolicy<8, 16>>(
out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
} else if (D <= IdxType(16)) {
coalesced_normalize_thin<NormalizeThinPolicy<16, 8>>(
out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
} else {
coalesced_normalize_thin<NormalizeThinPolicy<32, 4>>(
out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
}
} else {
coalesced_normalize_medium<256>(out, in, D, N, init, stream, main_op, reduce_op, fin_op, eps);
}
}

} // namespace detail
} // namespace linalg
} // namespace raft
10 changes: 3 additions & 7 deletions cpp/include/raft/linalg/norm.cuh
@@ -22,16 +22,12 @@
#include "linalg_types.hpp"

#include <raft/core/device_mdspan.hpp>
#include <raft/linalg/norm_types.hpp>
#include <raft/util/input_validation.hpp>

namespace raft {
namespace linalg {

/** different types of norms supported on the input buffers */
using detail::L1Norm;
using detail::L2Norm;
using detail::NormType;

/**
* @brief Compute row-wise norm of the input matrix and perform fin_op lambda
*
@@ -44,7 +40,7 @@ using detail::NormType;
* @tparam Lambda device final lambda
* @tparam IdxType Integer type used for addressing
* @param dots the output vector of row-wise dot products
* @param data the input matrix (currently assumed to be row-major)
* @param data the input matrix
* @param D number of columns of data
* @param N number of rows of data
* @param type the type of norm to be applied
@@ -71,7 +67,7 @@ void rowNorm(
* @tparam Lambda device final lambda
* @tparam IdxType Integer type used for addressing
* @param dots the output vector of column-wise dot products
* @param data the input matrix (currently assumed to be row-major)
* @param data the input matrix
* @param D number of columns of data
* @param N number of rows of data
* @param type the type of norm to be applied
