diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index 8d3321eb77..192d160d45 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -20,7 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include -#include +#include // raft::shfl_xor #endif namespace raft { /** diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 8fbd7a9c69..be6fed9f10 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -16,23 +16,20 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy +#include // raft::ceildiv, raft::shfl namespace raft { namespace distance { namespace detail { -#if (ENABLE_MEMCPY_ASYNC == 1) -#include -using namespace nvcuda::experimental; -#endif - template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; @@ -124,11 +121,10 @@ DI void updateReducedVal( template __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, const DataT* x, @@ -142,7 +138,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, int* mutex, ReduceOpT redOp, KVPReduceOpT pairRedOp, - CoreLambda core_op, + OpT distance_op, FinalLambda fin_op) { extern __shared__ char smem[]; @@ -163,24 +159,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (Sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto acc_ij = acc[i][j]; - acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; - } - } - } - // intra thread reduce const auto acccolid = threadIdx.x % P::AccThCols; const auto accrowid = threadIdx.x / P::AccThCols; @@ -229,18 +207,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, }; IdxT lda = k, ldb = k, ldd = n; - PairwiseDistances + row_major, + write_out> obj(x, y, m, @@ -251,9 +229,9 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, ldd, xn, yn, - nullptr, + nullptr, // Output pointer smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -289,9 +267,6 @@ void fusedL2NNImpl(OutT* min, constexpr auto maxVal = std::numeric_limits::max(); typedef KeyValuePair KVPair; - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel @@ -300,59 +275,25 @@ void fusedL2NNImpl(OutT* min, } constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } else { - auto fusedL2NN = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedL2NNkernel; + + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index a051bdf4cd..583476ede6 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -26,16 +26,12 @@ namespace detail { /** * @brief Device class for L1, L2 and cosine distance metrics. - * @tparam useNorms whether norms are needed * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) + * @tparam OpT A distance operation, e.g., cosine_distance_op. * @tparam EpilogueLambda applies an elementwise function to compute final values. Its signature is: template void epilogue_lambda @@ -53,19 +49,17 @@ namespace detail { * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine * @param[output] pD output matrix * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. - * @param core_op the core accumulation operation lambda + * @param distance_op the distance operation, e.g. cosine_distance_op * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ -template > struct PairwiseDistances : public BaseClass { + // Get accumulation type from distance_op + using AccT = typename OpT::AccT; + private: typedef Policy P; const DataT* xn; @@ -80,7 +77,7 @@ struct PairwiseDistances : public BaseClass { const DataT* const yBase; OutT* dOutput; char* smem; - CoreLambda core_op; + OpT distance_op; EpilogueLambda epilog_op; FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; @@ -106,7 +103,7 @@ struct PairwiseDistances : public BaseClass { const DataT* _yn, OutT* _dOutput, char* _smem, - CoreLambda _core_op, + OpT _distance_op, EpilogueLambda _epilog_op, FinalLambda _fin_op, rowEpilogueLambda _rowEpilog_op) @@ -116,7 +113,7 @@ struct PairwiseDistances : public BaseClass { yBase(_y), dOutput(_dOutput), smem(_smem), - core_op(_core_op), + distance_op(_distance_op), epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), @@ -156,15 +153,25 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); // Epilog: - if (useNorms) { + if (distance_op.use_norms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { store_output(tile_idx_m, tile_idx_n); } @@ -209,7 +216,7 @@ struct PairwiseDistances : public BaseClass { for (int j = 0; j < P::AccColsPerTh; ++j) { #pragma unroll for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); + distance_op.core(acc[i][j], this->regx[i][v], this->regy[j][v]); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 410dfa1080..b298391ef2 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -43,36 +43,20 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( extern __shared__ char smem[]; - using AccT = typename OpT::AccT; - - // Wrap operator back into lambdas. This is temporary and should be removed. - // See: https://github.com/rapidsai/raft/issues/1323 - auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - distance_op.core(acc, x, y); - }; - auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); - }; - + // The epilog is already provided by distance_op. Do not provide additional + // epilogs. + auto epilog_op = raft::void_op(); // No support for row_epilog_op. auto row_epilog_op = raft::void_op(); // Always write output constexpr bool write_out = true; constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances #include +#include +#include #include #include @@ -183,13 +185,11 @@ DI void updateSortedWarpQ( } } -template Pair; @@ -223,295 +223,275 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; - auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( - IdxT gridStrideY) { - if (gridDim.x == 1) { return; } - - Pair* shDumpKV = nullptr; - if (useNorms) { - shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } - - const int lid = threadIdx.x % warpSize; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - - // 0 -> consumer done consuming the buffer. - // -1 -> consumer started consuming the buffer - // -2 -> producer done filling the buffer - // 1 -> prod acquired to fill the buffer - if (blockIdx.x == 0) { - auto cta_processed = 0; - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - __syncwarp(); - - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - - while (cta_processed < gridDim.x - 1) { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) - ; - } - __threadfence(); - __syncthreads(); + auto rowEpilog_lambda = + [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { + if (gridDim.x == 1) { return; } + + int smem_offset = distance_op.template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + const int lid = threadIdx.x % warpSize; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + + // 0 -> consumer done consuming the buffer. + // -1 -> consumer started consuming the buffer + // -2 -> producer done filling the buffer + // 1 -> prod acquired to fill the buffer + if (blockIdx.x == 0) { + auto cta_processed = 0; + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + __syncwarp(); + + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + + while (cta_processed < gridDim.x - 1) { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - otherKV.value = out_dists[rowId * numOfNN + idx]; - otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } } } } - } - __threadfence(); - __syncthreads(); + __threadfence(); + __syncthreads(); - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } - __threadfence(); + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); // Perform merging of otherKV with topk's across warp. #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); } } + cta_processed++; } - cta_processed++; - } #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(0xffffffff, needSort); - if (needSort) { heapArr[i]->reduce(); } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(0xffffffff, needSort); + if (needSort) { heapArr[i]->reduce(); } + } } - } - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } else { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) - ; - } - __threadfence(); - __syncthreads(); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } else { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - for (int idx = lid; idx < numOfNN; idx += warpSize) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; - out_dists[rowId * numOfNN + idx] = KVPair.value; - out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + for (int idx = lid; idx < numOfNN; idx += warpSize) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + out_dists[rowId * numOfNN + idx] = KVPair.value; + out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + } } } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } - __threadfence(); - } - }; + __threadfence(); + __syncthreads(); - // epilogue operation lambda for final value calculation - auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - if (useNorms) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); } - } + }; - Pair* shDumpKV = nullptr; - if (useNorms) { - constexpr size_t shmemSize = - Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - shDumpKV = (Pair*)(&smem[shmemSize]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } + // epilogue operation lambda for final value calculation + auto epilog_lambda = + [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + int smem_offset = distance_op.template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); - - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + } } - } - if (gridStrideX > blockIdx.x * Policy::Nblk) { + if (gridStrideX > blockIdx.x * Policy::Nblk) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + } } + anyWarpTopKs += numValsWarpTopK[i]; } - anyWarpTopKs += numValsWarpTopK[i]; } - } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair* allWarpTopKs = (Pair*)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair* allWarpTopKs = (Pair*)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { #pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { numValsWarpTopK[i] += n; } + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { numValsWarpTopK[i] += n; } + } } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; - } - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; - numValsWarpTopK[i]++; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; + numValsWarpTopK[i]++; + } } } } + __syncwarp(); + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); + updateSortedWarpQ( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } - __syncwarp(); - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); - updateSortedWarpQ( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + } } } } - } - } else { + } else { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); - } - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { heapArr[i]->reduce(); } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + } } } - } - if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - }; + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + }; - raft::distance::detail::PairwiseDistances + write_out> obj(x, y, m, @@ -522,9 +502,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x ldd, _xn, _yn, - nullptr, + nullptr, // output ptr, can be null as write_out == false. smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -563,38 +543,32 @@ void fusedL2UnexpKnnImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = x - y; - acc += diff * diff; - }; - typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; if (numOfNN <= 32) { @@ -605,8 +579,10 @@ void fusedL2UnexpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - dim3 grid = raft::distance::detail::launchConfigGenerator( + const auto sharedMemSize = + distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); + + dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); if (grid.x > 1) { @@ -629,9 +605,8 @@ void fusedL2UnexpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, (int*)workspace, out_dists, @@ -754,36 +729,33 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(workspace != nullptr, "workspace is null"); dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; if (numOfNN <= 32) { @@ -794,9 +766,8 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + - ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + - (KPolicy::Mblk * numOfNN * sizeof(Pair)); + const auto sharedMemSize = + distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2ExpKnnRowMajor); int32_t* mutexes = nullptr; @@ -836,9 +807,8 @@ void fusedL2ExpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, mutexes, out_dists, diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index af67214193..adb73cb9b2 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -182,22 +182,20 @@ class FusedL2NNTest : public ::testing::TestWithParam> { int m = params.m; int n = params.n; int k = params.k; - MinAndDistanceReduceOp redOp; - fusedL2NN, int>( - out, - x.data(), - y.data(), - xn.data(), - yn.data(), - m, - n, - k, - (void*)workspace.data(), - redOp, - raft::distance::KVPMinReduce(), - Sqrt, - true, - stream); + + const bool init_out_buffer = true; + fusedL2NNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + Sqrt, + init_out_buffer, + stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } };