From 67e1328fc3680bbb2c11e32ca1999299c729336d Mon Sep 17 00:00:00 2001
From: Allard Hendriksen
Date: Fri, 26 Aug 2022 13:18:32 +0200
Subject: [PATCH] Optimize fusedL2NN when data is skinny

---
 cpp/bench/spatial/fused_l2_nn.cu              | 33 +++++++++--
 .../raft/distance/detail/fused_l2_nn.cuh      |  5 +-
 cpp/include/raft/distance/fused_l2_nn.cuh     | 59 ++++++++++++++++---
 cpp/include/raft/linalg/contractions.cuh      | 26 ++++++++
 4 files changed, 109 insertions(+), 14 deletions(-)

diff --git a/cpp/bench/spatial/fused_l2_nn.cu b/cpp/bench/spatial/fused_l2_nn.cu
index dc3b507fbf..2a9c6714a6 100644
--- a/cpp/bench/spatial/fused_l2_nn.cu
+++ b/cpp/bench/spatial/fused_l2_nn.cu
@@ -17,14 +17,13 @@
 #include 
 #include 
 #include 
-#include 
+#include 
 #include 
 #include 
 #include 
 
-#if defined RAFT_NN_COMPILED
-#include <raft/spatial/knn/specializations.hpp>
-#endif
+// Note: do not include raft/spatial/knn/specializations.hpp based on
+// RAFT_NN_COMPILED, as fusedL2NN is not specialized and not defined there.
 
 namespace raft::bench::spatial {
 
@@ -73,6 +72,30 @@
                  false,
                  stream);
     });
+
+    // Num distance calculations
+    int64_t num_dist_calcs = (int64_t)params.n * (int64_t)params.m;
+
+    int64_t num_flops = 3 * num_dist_calcs * params.k;
+
+    int64_t read_elts  = (int64_t)params.n * params.k + (int64_t)params.m * params.k;
+    int64_t write_elts = (int64_t)params.n;
+
+    state.counters["D/s"] = benchmark::Counter(num_dist_calcs,
+                                               benchmark::Counter::kIsIterationInvariantRate,
+                                               benchmark::Counter::OneK::kIs1000);
+
+    state.counters["FLOP/s"] = benchmark::Counter(
+      num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);
+
+    state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair),
+                                                 benchmark::Counter::kIsIterationInvariantRate,
+                                                 benchmark::Counter::OneK::kIs1000);
+    state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(float),
+                                                 benchmark::Counter::kIsIterationInvariantRate,
+                                                 benchmark::Counter::OneK::kIs1000);
+
+    state.counters["K"] = benchmark::Counter(params.k);
   }
 
  private:
@@ -88,9 +111,9 @@ const std::vector fused_l2_nn_input_vecs = {
   {32, 16384, 16384},   {64, 16384, 16384},    {128, 16384, 16384},  {256, 16384, 16384},
   {512, 16384, 16384},  {1024, 16384, 16384},  {16384, 32, 16384},   {16384, 64, 16384},
   {16384, 128, 16384},  {16384, 256, 16384},   {16384, 512, 16384},  {16384, 1024, 16384},
+  {16384, 16384, 2},    {16384, 16384, 4},     {16384, 16384, 8},    {16384, 16384, 16},
   {16384, 16384, 32},   {16384, 16384, 64},    {16384, 16384, 128},  {16384, 16384, 256},
   {16384, 16384, 512},  {16384, 16384, 1024},  {16384, 16384, 16384},
-
 };
 
 RAFT_BENCH_REGISTER(fused_l2_nn, "", fused_l2_nn_input_vecs);
diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh
index 81d02c410c..39bd1508f8 100644
--- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh
+++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -261,7 +261,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
 template <typename DataT,
           typename OutT,
           typename IdxT,
-          int VecLen,
+          typename Policy,
           typename ReduceOpT,
           typename KVPReduceOpT>
 void fusedL2NNImpl(OutT* min,
@@ -279,7 +279,8 @@ void fusedL2NNImpl(OutT* min,
                    bool initOutBuffer,
                    cudaStream_t stream)
 {
-  typedef typename linalg::Policy4x4<DataT, VecLen>::Policy P;
+  // The kernel policy is determined by fusedL2NN.
+  typedef Policy P;
   dim3 blk(P::Nthreads);
   auto nblks = raft::ceildiv(m, P::Nthreads);
 
diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh
index ac8895c9ce..121ccbf60d 100644
--- a/cpp/include/raft/distance/fused_l2_nn.cuh
+++ b/cpp/include/raft/distance/fused_l2_nn.cuh
@@ -24,6 +24,7 @@
 #include 
 #include 
 #include 
+#include <raft/linalg/contractions.cuh>
 #include 
 
 namespace raft {
@@ -99,20 +100,64 @@ void fusedL2NN(OutT* min,
                bool initOutBuffer,
                cudaStream_t stream)
 {
+  // When k is smaller than 32, the Policy4x4 results in redundant calculations
+  // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead
+  // that uses tiles with a smaller value of k.
+  bool is_skinny = k < 32;
+
   size_t bytes = sizeof(DataT) * k;
   if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 16 / sizeof(DataT), ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 16 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 16 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 8 / sizeof(DataT), ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 8 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 8 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   } else {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 1, ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 1>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 1>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   }
 }
 
 }  // namespace distance
 }  // namespace raft
 
-#endif
\ No newline at end of file
+#endif
diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh
index 5ccbd15c3d..800632ada5 100644
--- a/cpp/include/raft/linalg/contractions.cuh
+++ b/cpp/include/raft/linalg/contractions.cuh
@@ -167,6 +167,32 @@ struct Policy4x4 {
 };
 /** @} */
 
+/**
+ * @defgroup Policy4x4Skinny
+ *
+ * A smaller k-block (8 instead of 32) with fewer threads per block (8x8 instead
+ * of 16x16), which is faster for raft::distance::fusedL2NN on skinny matrices,
+ * i.e., matrices with a small k dimension.
+ *
+ * @{
+ */
+template <typename DataT, int _veclen>
+struct Policy4x4Skinny {
+};
+
+template <int _veclen>
+struct Policy4x4Skinny<float, _veclen> {
+  typedef KernelPolicy<float, _veclen, 8, 4, 4, 8, 8> Policy;
+  typedef ColKernelPolicy<float, _veclen, 8, 4, 4, 8, 8> ColPolicy;
+};
+
+template <int _veclen>
+struct Policy4x4Skinny<double, _veclen> {
+  typedef KernelPolicy<double, _veclen, 8, 4, 4, 8, 8> Policy;
+  typedef ColKernelPolicy<double, _veclen, 8, 4, 4, 8, 8> ColPolicy;
+};
+/** @} */
+
 /**
  * @defgroup Policy2x8 16 elements per thread Policy with k-block = 16
  * @{
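
Reviewer note (illustration, not part of the patch): the derived counters added to the
benchmark above can be sanity-checked on the host. The sketch below is plain C++,
assumes float data and an (int, float) key-value output pair (the benchmark itself uses
cub::KeyValuePair), and simply mirrors the num_dist_calcs / num_flops / read_elts /
write_elts formulas from the diff for one of the new skinny shapes,
{m, n, k} = {16384, 16384, 16}:

  // Illustrative sketch only: evaluates the benchmark's derived counters for a
  // skinny problem shape. Assumes float data and an (int, float) key-value pair.
  #include <cstdint>
  #include <cstdio>

  int main()
  {
    const int64_t m = 16384, n = 16384, k = 16;

    const int64_t num_dist_calcs = n * m;                   // one distance per pair of rows
    const int64_t num_flops      = 3 * num_dist_calcs * k;  // subtract, multiply, add per dimension

    const int64_t read_elts  = n * k + m * k;               // both input matrices streamed once
    const int64_t write_elts = n;                           // one key-value pair per output row

    const double read_mb  = read_elts * sizeof(float) / 1e6;
    const double write_mb = write_elts * (sizeof(int) + sizeof(float)) / 1e6;

    std::printf("distance calcs : %lld\n", (long long)num_dist_calcs);
    std::printf("FLOPs          : %lld\n", (long long)num_flops);
    std::printf("read           : %.2f MB\n", read_mb);
    std::printf("written        : %.2f MB\n", write_mb);

    // With k = 16, a 32-wide k-tile (Policy4x4) is half padding; the skinny
    // policy's smaller k-block avoids that redundant work.
    return 0;
  }

For this shape the counters come out to roughly 268M distances, about 12.9 GFLOP,
2.1 MB read and 0.13 MB written per iteration, which makes the tile shape (rather
than memory traffic) the main lever for the new skinny benchmark cases.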