From 8c639d994801cb00b7af6c4756cf75aedcf693b9 Mon Sep 17 00:00:00 2001
From: Allard Hendriksen
Date: Wed, 7 Sep 2022 00:38:14 +0200
Subject: [PATCH] Optimize fusedL2NN when data is skinny (#794)

The fusedL2NN kernel uses tiling to maximize performance. The current
implementation assumes that the input matrices are at least 32 elements wide.
When they are narrower, it performs redundant computations.

This PR adds a policy that is applied when the matrix is skinny (less than 32
elements wide). This results in a 1.5-2x performance improvement across GPU
architectures.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: https://github.com/rapidsai/raft/pull/794
---
 cpp/bench/spatial/fused_l2_nn.cu             | 35 ++++++-
 .../raft/distance/detail/fused_l2_nn.cuh     | 27 +++---
 cpp/include/raft/distance/fused_l2_nn.cuh    | 59 ++++++++++--
 cpp/include/raft/linalg/contractions.cuh     | 22 +++++
 cpp/test/distance/fused_l2_nn.cu             | 91 ++++++++++++++++---
 5 files changed, 195 insertions(+), 39 deletions(-)

diff --git a/cpp/bench/spatial/fused_l2_nn.cu b/cpp/bench/spatial/fused_l2_nn.cu
index e5b5dc377a..aa36483145 100644
--- a/cpp/bench/spatial/fused_l2_nn.cu
+++ b/cpp/bench/spatial/fused_l2_nn.cu
@@ -22,9 +22,12 @@
 #include
 #include

-#if defined RAFT_NN_COMPILED
-#include
-#endif
+// TODO: Once fusedL2NN is specialized in the raft_distance shared library,
+// add back:
+//
+// #if defined RAFT_NN_COMPILED
+// #include
+// #endif

 namespace raft::bench::spatial {

@@ -73,6 +76,30 @@ struct fused_l2_nn : public fixture {
                      false,
                      stream);
     });
+
+    // Number of distance calculations
+    int64_t num_dist_calcs = (int64_t)params.n * (int64_t)params.m;
+
+    int64_t num_flops = 3 * num_dist_calcs * params.k;
+
+    int64_t read_elts  = (int64_t)params.n * params.k + (int64_t)params.m * params.k;
+    int64_t write_elts = (int64_t)params.n;
+
+    state.counters["D/s"] = benchmark::Counter(num_dist_calcs,
+                                               benchmark::Counter::kIsIterationInvariantRate,
+                                               benchmark::Counter::OneK::kIs1000);
+
+    state.counters["FLOP/s"] = benchmark::Counter(
+      num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000);
+
+    state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(cub::KeyValuePair<int, float>),
+                                                 benchmark::Counter::kIsIterationInvariantRate,
+                                                 benchmark::Counter::OneK::kIs1000);
+    state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(float),
+                                                 benchmark::Counter::kIsIterationInvariantRate,
+                                                 benchmark::Counter::OneK::kIs1000);
+
+    state.counters["K"] = benchmark::Counter(params.k);
   }

  private:
@@ -88,9 +115,9 @@
 const std::vector fused_l2_nn_input_vecs = {
   {32, 16384, 16384},  {64, 16384, 16384},   {128, 16384, 16384},  {256, 16384, 16384},
   {512, 16384, 16384}, {1024, 16384, 16384}, {16384, 32, 16384},   {16384, 64, 16384},
   {16384, 128, 16384}, {16384, 256, 16384},  {16384, 512, 16384},  {16384, 1024, 16384},
+  {16384, 16384, 2},   {16384, 16384, 4},    {16384, 16384, 8},    {16384, 16384, 16},
   {16384, 16384, 32},  {16384, 16384, 64},   {16384, 16384, 128},  {16384, 16384, 256},
   {16384, 16384, 512}, {16384, 16384, 1024}, {16384, 16384, 16384},
-
 };

 RAFT_BENCH_REGISTER(fused_l2_nn, "", fused_l2_nn_input_vecs);

diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh
index 81d02c410c..308f8a096a 100644
--- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh
+++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -92,14 +92,14 @@ DI void updateReducedVal(
   const auto lid      = threadIdx.x % raft::WarpSize;
   const auto accrowid = threadIdx.x / P::AccThCols;

-  // for now have first lane from each warp update a unique output row. This
-  // will resolve hang issues with pre-Volta architectures
+  // Update each output row in order within a warp. This will resolve hang
+  // issues with pre-Volta architectures.
 #pragma unroll
   for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) {
-    if (lid == 0) {
+    if (lid == j * P::AccThCols) {
 #pragma unroll
       for (int i = 0; i < P::AccRowsPerTh; ++i) {
-        auto rid = gridStrideY + accrowid + j + i * P::AccThRows;
+        auto rid = gridStrideY + accrowid + i * P::AccThRows;
         if (rid < m) {
           auto value = val[i];
           while (atomicCAS(mutex + rid, 0, 1) == 1)
@@ -111,14 +111,6 @@
         }
       }
     }
-    if (j < (raft::WarpSize / P::AccThCols) - 1) {
-#pragma unroll
-      for (int i = 0; i < P::AccRowsPerTh; ++i) {
-        auto tmpkey   = raft::shfl(val[i].key, (j + 1) * P::AccThCols);
-        auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols);
-        val[i]        = {tmpkey, tmpvalue};
-      }
-    }
   }
 }

@@ -210,8 +202,10 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,
     for (int i = 0; i < P::AccRowsPerTh; ++i) {
 #pragma unroll
       for (int j = P::AccThCols / 2; j > 0; j >>= 1) {
-        auto tmpkey   = raft::shfl(val[i].key, lid + j);
-        auto tmpvalue = raft::shfl(val[i].value, lid + j);
+        // Strictly, the srcLane should be (lid + j) % P::AccThCols, but
+        // raft::shfl applies this modulo internally via the width argument.
+        auto tmpkey   = raft::shfl(val[i].key, lid + j, P::AccThCols);
+        auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols);
         KVPair tmp    = {tmpkey, tmpvalue};
         val[i]        = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]);
       }
@@ -261,7 +255,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min,

-template <typename DataT, typename OutT, typename IdxT, int VecLen, typename ReduceOpT, typename KVPReduceOpT>
+template <typename DataT, typename OutT, typename IdxT, typename Policy, typename ReduceOpT, typename KVPReduceOpT>
 void fusedL2NNImpl(OutT* min,
@@ -279,7 +273,8 @@ void fusedL2NNImpl(OutT* min,
                    bool initOutBuffer,
                    cudaStream_t stream)
 {
-  typedef typename linalg::Policy4x4<DataT, VecLen>::Policy P;
+  // The kernel policy is determined by fusedL2NN.
+  typedef Policy P;
   dim3 blk(P::Nthreads);
   auto nblks = raft::ceildiv(m, P::Nthreads);

diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh
index ac8895c9ce..121ccbf60d 100644
--- a/cpp/include/raft/distance/fused_l2_nn.cuh
+++ b/cpp/include/raft/distance/fused_l2_nn.cuh
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <raft/linalg/contractions.cuh>
 #include

 namespace raft {

@@ -99,20 +100,64 @@ void fusedL2NN(OutT* min,
                bool initOutBuffer,
                cudaStream_t stream)
 {
+  // When k is smaller than 32, Policy4x4 results in redundant calculations
+  // because its tiles span k=32. Therefore, use a "skinny" policy instead,
+  // whose tiles have a smaller extent along k.
+  bool is_skinny = k < 32;
+
   size_t bytes = sizeof(DataT) * k;
   if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 16 / sizeof(DataT), ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 16 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 16 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 8 / sizeof(DataT), ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 8 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 8 / sizeof(DataT)>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   } else {
-    detail::fusedL2NNImpl<DataT, OutT, IdxT, 1, ReduceOpT>(
-      min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    if (is_skinny) {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4Skinny<DataT, 1>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    } else {
+      detail::fusedL2NNImpl<DataT,
+                            OutT,
+                            IdxT,
+                            typename linalg::Policy4x4<DataT, 1>::Policy,
+                            ReduceOpT>(
+        min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
+    }
   }
 }

 }  // namespace distance
 }  // namespace raft

-#endif
\ No newline at end of file
+#endif

diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh
index 5ccbd15c3d..8aed0cb4be 100644
--- a/cpp/include/raft/linalg/contractions.cuh
+++ b/cpp/include/raft/linalg/contractions.cuh
@@ -167,6 +167,28 @@ struct Policy4x4 {
 };
 /** @} */

+/**
+ * A smaller k-block (8 instead of 32) with fewer threads per block (8x8 instead
+ * of 16x16), which is faster for raft::distance::fusedL2NN on skinny matrices,
+ * i.e., matrices with a small k dimension.
+ *
+ */
+template <typename DataT, int _veclen>
+struct Policy4x4Skinny {
+};
+
+template <int _veclen>
+struct Policy4x4Skinny<float, _veclen> {
+  typedef KernelPolicy<float, _veclen, 8, 4, 4, 8, 8> Policy;
+  typedef ColKernelPolicy<float, _veclen, 8, 4, 4, 8, 8> ColPolicy;
+};
+
+template <int _veclen>
+struct Policy4x4Skinny<double, _veclen> {
+  typedef KernelPolicy<double, _veclen, 8, 4, 4, 8, 8> Policy;
+  typedef ColKernelPolicy<double, _veclen, 8, 4, 4, 8, 8> ColPolicy;
+};
+
 /**
  * @defgroup Policy2x8 16 elements per thread Policy with k-block = 16
  * @{

diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu
index 192f0c9a74..2a5b30e01f 100644
--- a/cpp/test/distance/fused_l2_nn.cu
+++ b/cpp/test/distance/fused_l2_nn.cu
@@ -23,6 +23,13 @@
 #include
 #include

+// TODO: Once fusedL2NN is specialized in the raft_distance shared library,
+// add the following:
+//
+// #if defined RAFT_NN_COMPILED
+// #include
+// #endif
+
 namespace raft {
 namespace distance {

@@ -102,6 +109,23 @@ struct Inputs {
   DataT tolerance;
   int m, n, k;
   unsigned long long int seed;
+
+  friend std::ostream& operator<<(std::ostream& os, const Inputs& p)
+  {
+    return os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", seed: " << p.seed
+              << ", tol: " << p.tolerance;
+  }
 };

 template <typename DataT, bool Sqrt>
@@ -231,19 +255,62 @@
 }

 const std::vector<Inputs<float>> inputsf = {
-  {0.001f, 32, 32, 32, 1234ULL},   {0.001f, 32, 64, 32, 1234ULL},  {0.001f, 64, 32, 32, 1234ULL},
-  {0.001f, 64, 64, 32, 1234ULL},   {0.001f, 128, 32, 32, 1234ULL}, {0.001f, 128, 64, 32, 1234ULL},
-  {0.001f, 128, 128, 64, 1234ULL}, {0.001f, 64, 128, 128, 1234ULL},
-
-  {0.001f, 32, 32, 34, 1234ULL},   {0.001f, 32, 64, 34, 1234ULL},  {0.001f, 64, 32, 34, 1234ULL},
-  {0.001f, 64, 64, 34, 1234ULL},   {0.001f, 128, 32, 34, 1234ULL}, {0.001f, 128, 64, 34, 1234ULL},
-  {0.001f, 128, 128, 66, 1234ULL}, {0.001f, 64, 128, 130, 1234ULL},
-
-  {0.001f, 32, 32, 33, 1234ULL},   {0.001f, 32, 64, 33, 1234ULL},  {0.001f, 64, 32, 33, 1234ULL},
-  {0.001f, 64, 64, 33, 1234ULL},   {0.001f, 128, 32, 33, 1234ULL}, {0.001f, 128, 64, 33, 1234ULL},
-  {0.001f, 128, 128, 65, 1234ULL}, {0.001f, 64, 128, 129, 1234ULL},
-
+  {0.001f, 32, 32, 32, 1234ULL},
+  {0.001f, 32, 64, 32, 1234ULL},
+  {0.001f, 64, 32, 32, 1234ULL},
+  {0.001f, 64, 64, 32, 1234ULL},
+  {0.001f, 128, 32, 32, 1234ULL},
+  {0.001f, 128, 64, 32, 1234ULL},
+  {0.001f, 128, 128, 64, 1234ULL},
+  {0.001f, 64, 128, 128, 1234ULL},
+
+  {0.001f, 32, 32, 34, 1234ULL},
+  {0.001f, 32, 64, 34, 1234ULL},
+  {0.001f, 64, 32, 34, 1234ULL},
+  {0.001f, 64, 64, 34, 1234ULL},
+  {0.001f, 128, 32, 34, 1234ULL},
+  {0.001f, 128, 64, 34, 1234ULL},
+  {0.001f, 128, 128, 66, 1234ULL},
+  {0.001f, 64, 128, 130, 1234ULL},
+
+  {0.001f, 32, 32, 33, 1234ULL},
+  {0.001f, 32, 64, 33, 1234ULL},
+  {0.001f, 64, 32, 33, 1234ULL},
+  {0.001f, 64, 64, 33, 1234ULL},
+  {0.001f, 128, 32, 33, 1234ULL},
+  {0.001f, 128, 64, 33, 1234ULL},
+  {0.001f, 128, 128, 65, 1234ULL},
+  {0.001f, 64, 128, 129, 1234ULL},
   {0.006f, 1805, 134, 2, 1234ULL},
+
+  // Repeat with smaller values of k
+  {0.006f, 32, 32, 1, 1234ULL},
+  {0.001f, 32, 64, 2, 1234ULL},
+  {0.001f, 64, 32, 3, 1234ULL},
+  {0.001f, 64, 64, 4, 1234ULL},
+  {0.001f, 128, 32, 5, 1234ULL},
+  {0.001f, 128, 64, 6, 1234ULL},
+  {0.001f, 128, 128, 7, 1234ULL},
+  {0.001f, 64, 128, 8, 1234ULL},
+
+  {0.001f, 32, 32, 9, 1234ULL},
+  {0.001f, 32, 64, 10, 1234ULL},
+  {0.001f, 64, 32, 11, 1234ULL},
+  {0.001f, 64, 64, 12, 1234ULL},
+  {0.001f, 128, 32, 13, 1234ULL},
+  {0.001f, 128, 64, 14, 1234ULL},
+  {0.001f, 128, 128, 15, 1234ULL},
+  {0.001f, 64, 128, 16, 1234ULL},
+
+  {0.001f, 32, 32, 17, 1234ULL},
+  {0.001f, 32, 64, 18, 1234ULL},
+  {0.001f, 64, 32, 19, 1234ULL},
+  {0.001f, 64, 64, 20, 1234ULL},
+  {0.001f, 128, 32, 21, 1234ULL},
+  {0.001f, 128, 64, 22, 1234ULL},
+  {0.001f, 128, 128, 23, 1234ULL},
+  {0.00001, 64, 128, 24, 1234ULL},
+  {0.001f, 1805, 134, 25, 1234ULL},
 };

 typedef FusedL2NNTest<float, false> FusedL2NNTestF_Sq;
 TEST_P(FusedL2NNTestF_Sq, Result)
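
Notes appended for review (not part of the patch):

The benchmark counters added in cpp/bench/spatial/fused_l2_nn.cu assume that
each of the k terms of one distance calculation costs three floating-point
operations (a subtraction, a multiplication, and an addition), hence
num_flops = 3 * m * n * k. The standalone sketch below works through the
counter arithmetic for one of the newly added skinny shapes; the variable
names mirror the benchmark, and the shape {16384, 16384, 16} is taken from
fused_l2_nn_input_vecs.

    // Standalone sketch: counter arithmetic for the shape {16384, 16384, 16}.
    #include <cstdint>
    #include <cstdio>

    int main()
    {
      int64_t m = 16384, n = 16384, k = 16;
      int64_t num_dist_calcs = n * m;                   // 268,435,456 pairwise distances
      int64_t num_flops      = 3 * num_dist_calcs * k;  // subtract, multiply, add per k term
      int64_t read_elts      = n * k + m * k;           // both input matrices read once
      int64_t write_elts     = n;                       // one key-value pair per output row
      std::printf("distances: %lld, flops: %lld, reads: %lld elts, writes: %lld elts\n",
                  (long long)num_dist_calcs,
                  (long long)num_flops,
                  (long long)read_elts,
                  (long long)write_elts);
      return 0;
    }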
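The updateReducedVal rewrite keeps the per-row mutex protocol but lets exactly
one lane of each warp attempt its updates per loop iteration
(lid == j * P::AccThCols), which sidesteps the pre-Volta lockstep hazard the
comment mentions: lanes spinning on a lock can starve a diverged sibling lane
that holds it. Below is a minimal sketch of the spinlock pattern the function
relies on; the function name is illustrative (not a RAFT symbol) and fminf
stands in for the patch's red_op.

    // Illustrative per-row mutex update, as performed inside updateReducedVal.
    __device__ void locked_row_update(float* out, int* mutex, int rid, float value)
    {
      // Spin until this thread flips the row's mutex from 0 to 1.
      while (atomicCAS(mutex + rid, 0, 1) == 1)
        ;
      __threadfence();                    // acquire: observe prior writes to out[rid]
      out[rid] = fminf(out[rid], value);  // the reduction step (red_op in the patch)
      __threadfence();                    // release: publish the update before unlocking
      atomicExch(mutex + rid, 0);         // unlock
    }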
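The shuffle change passes P::AccThCols as the width argument of raft::shfl,
which (like CUDA's __shfl_sync) treats the warp as independent segments of
that width and wraps the source lane modulo the width, so the explicit
"% P::AccThCols" mentioned in the comment is unnecessary. A minimal CUDA
kernel demonstrating the same segmented butterfly reduction, assuming a min
reduction and a segment width of 4; this is illustrative, not the RAFT kernel.

    // Segmented butterfly min-reduction over sub-warp segments of width 4.
    // Launch as segmented_min<<<1, 32>>>(d_in, d_out) with 32 inputs, 8 outputs.
    __global__ void segmented_min(const float* in, float* out)
    {
      constexpr int kSegWidth = 4;  // plays the role of P::AccThCols
      const int lid = threadIdx.x % 32;
      float val     = in[threadIdx.x];

    #pragma unroll
      for (int j = kSegWidth / 2; j > 0; j >>= 1) {
        // srcLane (lid + j) wraps modulo kSegWidth within each segment.
        val = fminf(val, __shfl_sync(0xffffffff, val, lid + j, kSegWidth));
      }
      // Every lane in a segment now holds the segment minimum.
      if (lid % kSegWidth == 0) { out[threadIdx.x / kSegWidth] = val; }
    }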
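Finally, the dispatch added to fusedL2NN selects a tile policy at compile
time from a runtime test on k: the kernel is instantiated once per policy,
and "k < 32" merely picks an instantiation. The standalone sketch below shows
that pattern under assumed names (TilePolicy, run, and dispatch are not RAFT
symbols; the constants mirror the Policy4x4 and Policy4x4Skinny shapes
described in the contractions.cuh comment).

    #include <cstdio>

    // Tile shape as a compile-time policy (illustrative stand-in for
    // raft::linalg::KernelPolicy).
    template <int KBlock, int ThreadsX, int ThreadsY>
    struct TilePolicy {
      static constexpr int kBlock   = KBlock;  // tile extent along k
      static constexpr int nThreads = ThreadsX * ThreadsY;
    };

    using WidePolicy   = TilePolicy<32, 16, 16>;  // like Policy4x4: k-block 32, 16x16 threads
    using SkinnyPolicy = TilePolicy<8, 8, 8>;     // like Policy4x4Skinny: k-block 8, 8x8 threads

    // One instantiation per policy; tile sizes are constants inside the kernel.
    template <typename P>
    void run(int m, int n, int k)
    {
      std::printf("m=%d n=%d k=%d -> k-block=%d, %d threads/block\n",
                  m, n, k, P::kBlock, P::nThreads);
    }

    // Runtime selection, mirroring the is_skinny test in fusedL2NN.
    void dispatch(int m, int n, int k)
    {
      if (k < 32) {
        run<SkinnyPolicy>(m, n, k);  // avoids redundant work in the k < 32 case
      } else {
        run<WidePolicy>(m, n, k);    // full 32-wide k tiles
      }
    }

    int main()
    {
      dispatch(16384, 16384, 16);   // skinny input
      dispatch(16384, 16384, 256);  // wide input
      return 0;
    }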