From 26213062e82f84a5fd889e12d90c27cf5b998439 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 20:42:40 +0100 Subject: [PATCH] Take distance_op in pairwise_distance_base Fixes issue #1323 --- cpp/include/raft/core/kvp.hpp | 2 +- .../raft/distance/detail/fused_l2_nn.cuh | 133 +++++------------ .../detail/pairwise_distance_base.cuh | 37 +++-- .../detail/pairwise_matrix/kernel_sm60.cuh | 28 +--- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 135 +++++++----------- cpp/test/distance/fused_l2_nn.cu | 9 +- 6 files changed, 124 insertions(+), 220 deletions(-) diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index 8d3321eb77..8abc379792 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -20,7 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include -#include +#include // raft::shfl_xor #endif namespace raft { /** diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 8fbd7a9c69..be6fed9f10 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -16,23 +16,20 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy +#include // raft::ceildiv, raft::shfl namespace raft { namespace distance { namespace detail { -#if (ENABLE_MEMCPY_ASYNC == 1) -#include -using namespace nvcuda::experimental; -#endif - template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; @@ -124,11 +121,10 @@ DI void updateReducedVal( template __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, const DataT* x, @@ -142,7 +138,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, int* mutex, ReduceOpT redOp, KVPReduceOpT pairRedOp, - CoreLambda core_op, + OpT distance_op, FinalLambda fin_op) { extern __shared__ char smem[]; @@ -163,24 +159,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (Sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto acc_ij = acc[i][j]; - acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; - } - } - } - // intra thread reduce const auto acccolid = threadIdx.x % P::AccThCols; const auto accrowid = threadIdx.x / P::AccThCols; @@ -229,18 +207,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, }; IdxT lda = k, ldb = k, ldd = n; - PairwiseDistances + row_major, + write_out> obj(x, y, m, @@ -251,9 +229,9 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, ldd, xn, yn, - nullptr, + nullptr, // Output pointer smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -289,9 +267,6 @@ void fusedL2NNImpl(OutT* min, constexpr auto maxVal = std::numeric_limits::max(); typedef KeyValuePair KVPair; - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel @@ -300,59 +275,25 @@ void fusedL2NNImpl(OutT* min, } constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } else { - auto fusedL2NN = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedL2NNkernel; + + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index a051bdf4cd..583476ede6 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -26,16 +26,12 @@ namespace detail { /** * @brief Device class for L1, L2 and cosine distance metrics. - * @tparam useNorms whether norms are needed * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) + * @tparam OpT A distance operation, e.g., cosine_distance_op. * @tparam EpilogueLambda applies an elementwise function to compute final values. Its signature is: template void epilogue_lambda @@ -53,19 +49,17 @@ namespace detail { * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine * @param[output] pD output matrix * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. - * @param core_op the core accumulation operation lambda + * @param distance_op the distance operation, e.g. cosine_distance_op * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ -template > struct PairwiseDistances : public BaseClass { + // Get accumulation type from distance_op + using AccT = typename OpT::AccT; + private: typedef Policy P; const DataT* xn; @@ -80,7 +77,7 @@ struct PairwiseDistances : public BaseClass { const DataT* const yBase; OutT* dOutput; char* smem; - CoreLambda core_op; + OpT distance_op; EpilogueLambda epilog_op; FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; @@ -106,7 +103,7 @@ struct PairwiseDistances : public BaseClass { const DataT* _yn, OutT* _dOutput, char* _smem, - CoreLambda _core_op, + OpT _distance_op, EpilogueLambda _epilog_op, FinalLambda _fin_op, rowEpilogueLambda _rowEpilog_op) @@ -116,7 +113,7 @@ struct PairwiseDistances : public BaseClass { yBase(_y), dOutput(_dOutput), smem(_smem), - core_op(_core_op), + distance_op(_distance_op), epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), @@ -156,15 +153,25 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); // Epilog: - if (useNorms) { + if (distance_op.use_norms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { store_output(tile_idx_m, tile_idx_n); } @@ -209,7 +216,7 @@ struct PairwiseDistances : public BaseClass { for (int j = 0; j < P::AccColsPerTh; ++j) { #pragma unroll for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); + distance_op.core(acc[i][j], this->regx[i][v], this->regy[j][v]); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 410dfa1080..b298391ef2 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -43,36 +43,20 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( extern __shared__ char smem[]; - using AccT = typename OpT::AccT; - - // Wrap operator back into lambdas. This is temporary and should be removed. - // See: https://github.com/rapidsai/raft/issues/1323 - auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - distance_op.core(acc, x, y); - }; - auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); - }; - + // The epilog is already provided by distance_op. Do not provide additional + // epilogs. + auto epilog_op = raft::void_op(); // No support for row_epilog_op. auto row_epilog_op = raft::void_op(); // Always write output constexpr bool write_out = true; constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances #include #include +#include +#include #include namespace raft { @@ -183,13 +185,11 @@ DI void updateSortedWarpQ( } } -template Pair; @@ -223,16 +223,12 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; - auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( + auto rowEpilog_lambda = [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__( IdxT gridStrideY) { if (gridDim.x == 1) { return; } - Pair* shDumpKV = nullptr; - if (useNorms) { - shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } + int smem_offset = distance_op.template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); const int lid = threadIdx.x % warpSize; const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); @@ -343,30 +339,16 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + auto epilog_lambda = [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT * regxn, DataT * regyn, IdxT gridStrideX, IdxT gridStrideY) { - if (useNorms) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - } - Pair* shDumpKV = nullptr; - if (useNorms) { - constexpr size_t shmemSize = - Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - shDumpKV = (Pair*)(&smem[shmemSize]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } + + int smem_offset = distance_op.template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); constexpr uint32_t mask = 0xffffffffu; const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); @@ -500,18 +482,17 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x } }; - raft::distance::detail::PairwiseDistances + write_out> obj(x, y, m, @@ -522,9 +503,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x ldd, _xn, _yn, - nullptr, + nullptr, // output ptr, can be null as write_out == false. smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -563,38 +544,32 @@ void fusedL2UnexpKnnImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = x - y; - acc += diff * diff; - }; - typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; if (numOfNN <= 32) { @@ -605,7 +580,10 @@ void fusedL2UnexpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); + const auto sharedMemSize = + distance_op.template shared_mem_size() + + KPolicy::Mblk * numOfNN * sizeof(Pair); + dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); @@ -629,9 +607,8 @@ void fusedL2UnexpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, (int*)workspace, out_dists, @@ -754,36 +731,33 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(workspace != nullptr, "workspace is null"); dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; if (numOfNN <= 32) { @@ -794,9 +768,9 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + - ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + - (KPolicy::Mblk * numOfNN * sizeof(Pair)); + const auto sharedMemSize = + distance_op.template shared_mem_size() + + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2ExpKnnRowMajor); int32_t* mutexes = nullptr; @@ -836,9 +810,8 @@ void fusedL2ExpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, mutexes, out_dists, diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index af67214193..c9a81455ac 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -182,8 +182,9 @@ class FusedL2NNTest : public ::testing::TestWithParam> { int m = params.m; int n = params.n; int k = params.k; - MinAndDistanceReduceOp redOp; - fusedL2NN, int>( + + const bool init_out_buffer = true; + fusedL2NNMinReduce, int>( out, x.data(), y.data(), @@ -193,10 +194,8 @@ class FusedL2NNTest : public ::testing::TestWithParam> { n, k, (void*)workspace.data(), - redOp, - raft::distance::KVPMinReduce(), Sqrt, - true, + init_out_buffer, stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); }