diff --git a/cpp/include/raft/device_atomics.cuh b/cpp/include/raft/device_atomics.cuh index dc8093ca1d..a4ebcc9900 100644 --- a/cpp/include/raft/device_atomics.cuh +++ b/cpp/include/raft/device_atomics.cuh @@ -179,10 +179,15 @@ struct genericAtomicOperationImpl { __forceinline__ __device__ T operator()(T* addr, T const& update_value, Op op) { using T_int = unsigned int; - T old_value = *addr; T assumed{old_value}; + if constexpr (std::is_same{} && (std::is_same{})) { + if (isnan(update_value)) { + return old_value; + } + } + do { assumed = old_value; const T new_value = op(old_value, update_value); @@ -191,13 +196,32 @@ struct genericAtomicOperationImpl { type_reinterpret(assumed), type_reinterpret(new_value)); old_value = type_reinterpret(ret); - } while (assumed != old_value); return old_value; } }; +// 4 bytes fp32 atomic Max operation +template <> +struct genericAtomicOperationImpl { + using T = float; + __forceinline__ __device__ T operator()(T* addr, T const& update_value, + DeviceMax op) { + if (isnan(update_value)) { + return *addr; + } + + T old = + (update_value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(update_value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(update_value))); + + return old; + } +}; + // 8 bytes atomic operation template struct genericAtomicOperationImpl { @@ -423,7 +447,6 @@ struct typesAtomicCASImpl { T_int ret = atomicCAS(reinterpret_cast(addr), type_reinterpret(compare), type_reinterpret(update_value)); - return type_reinterpret(ret); } }; diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 7f8523a587..f774d9d1ea 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -17,8 +17,9 @@ #include #include #include - +#include // TODO: Need to hide the PairwiseDistance class impl and expose to public API +#include #include #include "processing.hpp" @@ -145,21 +146,21 @@ DI void updateSortedWarpQ(myWarpSelect &heapArr, Pair *allWarpTopKs, int rowId, Pair tempKV; tempKV.value = raft::shfl(heapArr->warpK[i], srcLane); tempKV.key = raft::shfl(heapArr->warpV[i], srcLane); - const auto firstActiveLane = __ffs(activeLanes); - if (firstActiveLane == (lid + 1)) { + const auto firstActiveLane = __ffs(activeLanes) - 1; + if (firstActiveLane == lid) { heapArr->warpK[i] = KVPair.value; heapArr->warpV[i] = KVPair.key; - } else if (activeLanes & ((uint32_t)1 << lid)) { + } else if (lid > firstActiveLane) { heapArr->warpK[i] = tempKV.value; heapArr->warpV[i] = tempKV.key; } if (i == 0 && NumWarpQRegs > 1) { + heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); + heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); if (lid == 0) { heapArr->warpK[1] = tempKV.value; heapArr->warpV[1] = tempKV.key; } - heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); - heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); break; } } @@ -193,7 +194,16 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( } volatile int *mutex = mutexes; - Pair *shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + + Pair *shDumpKV = nullptr; + if (useNorms) { + shDumpKV = + (Pair *)(&smem[Policy::SmemSize + + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); + } else { + shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + } + const int lid = threadIdx.x % warpSize; const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); @@ -206,13 +216,11 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( myWarpSelect heapArr1(identity, keyMax, numOfNN); myWarpSelect heapArr2(identity, keyMax, numOfNN); myWarpSelect *heapArr[] = {&heapArr1, &heapArr2}; - __syncthreads(); + __syncwarp(); loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); while (cta_processed < gridDim.x - 1) { - Pair otherKV[Policy::AccRowsPerTh]; - if (threadIdx.x == 0) { int32_t old = -3; while (old != -1) { @@ -225,12 +233,19 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { const auto rowId = starty + i * Policy::AccThRows; - otherKV[i].value = identity; - otherKV[i].key = keyMax; - - if (lid < numOfNN && rowId < m) { - otherKV[i].value = out_dists[rowId * numOfNN + lid]; - otherKV[i].key = (uint32_t)out_inds[rowId * numOfNN + lid]; + const auto shMemRowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; +#pragma unroll + for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN && rowId < m) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } } } __threadfence(); @@ -241,14 +256,27 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( } // Perform merging of otherKV with topk's across warp. + __syncwarp(); + #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { const auto rowId = starty + i * Policy::AccThRows; + const auto shMemRowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; if (rowId < m) { - heapArr[i]->add(otherKV[i].value, otherKV[i].key); +#pragma unroll + for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); + } } } - cta_processed++; } #pragma unroll @@ -298,168 +326,176 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( }; // epilogue operation lambda for final value calculation - auto epilog_lambda = - [numOfNN, sqrt, m, n, ldd, out_dists, out_inds] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT * regxn, - DataT * regyn, IdxT gridStrideX, IdxT gridStrideY) { - if (sqrt) { + auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { + if (useNorms) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); - } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; } } - Pair *shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + } - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); + Pair *shDumpKV = nullptr; + if (useNorms) { + constexpr size_t shmemSize = + Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + shDumpKV = (Pair *)(&smem[shmemSize]); + } else { + shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + } - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect *heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, - numOfNN, starty); - } + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); + + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect *heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, + numOfNN, starty); } + } - if (gridStrideX > blockIdx.x * Policy::Nblk) { + if (gridStrideX > blockIdx.x * Policy::Nblk) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = - (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - numValsWarpTopK[i]++; - } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + numValsWarpTopK[i]++; } } - anyWarpTopKs += numValsWarpTopK[i]; } + anyWarpTopKs += numValsWarpTopK[i]; } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair *allWarpTopKs = (Pair *)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; + } + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair *allWarpTopKs = (Pair *)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { #pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = - __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { - numValsWarpTopK[i] += n; - } + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = + __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { + numValsWarpTopK[i] += n; } } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; + } - if (needScanSort[i]) { - const auto rowId = - (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { + if (needScanSort[i]) { + const auto rowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = - otherKV; - numValsWarpTopK[i]++; - } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = + otherKV; + numValsWarpTopK[i]++; } } } - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, - numOfNN); - updateSortedWarpQkNumWarpQRegisters>( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, + numOfNN); + updateSortedWarpQkNumWarpQRegisters>( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } } - __syncthreads(); + } + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = - (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, - numOfNN); - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + storeWarpQShmem(heapArr[i], shDumpKV, rowId, + numOfNN); } } } - } else { + } + } else { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = - (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; - } - heapArr[i]->add(otherKV.value, otherKV.key); + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; } + heapArr[i]->add(otherKV.value, otherKV.key); + } - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { - heapArr[i]->reduce(); - } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, - numOfNN); + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { + heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, + numOfNN); } } + } - if (((gridStrideX + Policy::Nblk * gridDim.x) > n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, - starty); - } - }; + if (((gridStrideX + Policy::Nblk * gridDim.x) > n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, + starty); + } + }; raft::distance::detail::PairwiseDistances< useNorms, DataT, AccT, OutT, IdxT, Policy, CoreLambda, @@ -472,10 +508,11 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( template -void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, - IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, OutT *out_dists, - IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, - void *workspace, size_t &worksize) { +void fusedL2UnexpKnnImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, + IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, + OutT *out_dists, IdxT *out_inds, IdxT numOfNN, + cudaStream_t stream, void *workspace, + size_t &worksize) { typedef typename raft::linalg::Policy2x8::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; @@ -495,25 +532,28 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef cub::KeyValuePair Pair; if (isRowMajor) { - constexpr auto fusedL2kNN32RowMajor = + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2kNN64RowMajor = + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; - auto fusedL2kNNRowMajor = fusedL2kNN32RowMajor; + auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; if (numOfNN <= 32) { - fusedL2kNNRowMajor = fusedL2kNN32RowMajor; + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; } else if (numOfNN <= 64) { - fusedL2kNNRowMajor = fusedL2kNN64RowMajor; + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn64RowMajor; } else { ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } + const auto sharedMemSize = + KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::detail::launchConfigGenerator( - m, n, KPolicy::SmemSize, fusedL2kNNRowMajor); + m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); + if (grid.x > 1) { const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) { @@ -525,10 +565,7 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, } } - const auto sharedMemSize = - KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - - fusedL2kNNRowMajor<<>>( + fusedL2UnexpKnnRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, core_lambda, fin_op, sqrt, (uint32_t)numOfNN, (int *)workspace, out_dists, out_inds); } else { @@ -539,29 +576,147 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, template -void fusedL2kNN(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, - const DataT *x, const DataT *y, bool sqrt, OutT *out_dists, - IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, - void *workspace, size_t &worksize) { +void fusedL2UnexpKnn(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, bool sqrt, OutT *out_dists, + IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, + void *workspace, size_t &worksize) { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - fusedL2kNNImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, - out_inds, numOfNN, stream, workspace, worksize); + fusedL2UnexpKnnImpl( + x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, + workspace, worksize); } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - fusedL2kNNImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, - out_inds, numOfNN, stream, workspace, worksize); + fusedL2UnexpKnnImpl( + x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, + workspace, worksize); } else { - fusedL2kNNImpl( + fusedL2UnexpKnnImpl( + x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, + workspace, worksize); + } +} + +template +void fusedL2ExpKnnImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, + IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, OutT *out_dists, + IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, + void *workspace, size_t &worksize) { + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + ASSERT(isRowMajor, "Only Row major inputs are allowed"); + + ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || + (worksize < m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + dim3 blk(KPolicy::Nthreads); + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + acc += x * y; + }; + + auto fin_op = [] __device__(AccT d_val, int g_d_idx) { return d_val; }; + + typedef cub::KeyValuePair Pair; + + if (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = + fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = + fusedL2kNN; + + auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; + if (numOfNN <= 32) { + fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; + } else if (numOfNN <= 64) { + fusedL2ExpKnnRowMajor = fusedL2ExpKnn64RowMajor; + } else { + ASSERT(numOfNN <= 64, + "fusedL2kNN: num of nearest neighbors must be <= 64"); + } + + const auto sharedMemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + + (KPolicy::Mblk * numOfNN * sizeof(Pair)); + dim3 grid = raft::distance::detail::launchConfigGenerator( + m, n, sharedMemSize, fusedL2ExpKnnRowMajor); + int32_t *mutexes = nullptr; + if (grid.x > 1) { + const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); + const auto normsSize = + (x != y) ? (m + n) * sizeof(DataT) : n * sizeof(DataT); + const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize; + if (worksize < requiredSize) { + worksize = requiredSize; + return; + } else { + mutexes = (int32_t *)((char *)workspace + normsSize); + CUDA_CHECK( + cudaMemsetAsync(mutexes, 0, sizeof(int32_t) * numMutexes, stream)); + } + } + + DataT *xn = (DataT *)workspace; + DataT *yn = (DataT *)workspace; + + auto norm_op = [] __device__(DataT in) { return in; }; + + if (x != y) { + yn += m; + raft::linalg::rowNorm(xn, x, k, m, raft::linalg::L2Norm, isRowMajor, + stream, norm_op); + raft::linalg::rowNorm(yn, y, k, n, raft::linalg::L2Norm, isRowMajor, + stream, norm_op); + } else { + raft::linalg::rowNorm(xn, x, k, n, raft::linalg::L2Norm, isRowMajor, + stream, norm_op); + } + fusedL2ExpKnnRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, core_lambda, fin_op, sqrt, + (uint32_t)numOfNN, mutexes, out_dists, out_inds); + } else { + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void fusedL2ExpKnn(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, bool sqrt, OutT *out_dists, + IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, + void *workspace, size_t &worksize) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + fusedL2ExpKnnImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, + out_inds, numOfNN, stream, workspace, + worksize); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + fusedL2ExpKnnImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, + out_inds, numOfNN, stream, workspace, + worksize); + } else { + fusedL2ExpKnnImpl( x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, workspace, worksize); } } /** - * Compute the k-nearest neighbors using L2 unexpanded distance. + * Compute the k-nearest neighbors using L2 expanded/unexpanded distance. * @tparam value_idx * @tparam value_t @@ -576,13 +731,12 @@ void fusedL2kNN(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, * @param[in] rowMajorQuery are the query array in row-major layout? * @param[in] stream stream to order kernel launch */ -template -void l2_unexpanded_knn(size_t D, value_idx *out_inds, value_t *out_dists, - const value_t *index, const value_t *query, - size_t n_index_rows, size_t n_query_rows, int k, - bool rowMajorIndex, bool rowMajorQuery, - cudaStream_t stream, void *workspace, size_t &worksize) { +template +void fusedL2Knn(size_t D, value_idx *out_inds, value_t *out_dists, + const value_t *index, const value_t *query, size_t n_index_rows, + size_t n_query_rows, int k, bool rowMajorIndex, + bool rowMajorQuery, cudaStream_t stream, + raft::distance::DistanceType metric) { // Validate the input data ASSERT(k > 0, "l2Knn: k must be > 0"); ASSERT(D > 0, "l2Knn: D must be > 0"); @@ -595,17 +749,53 @@ void l2_unexpanded_knn(size_t D, value_idx *out_inds, value_t *out_dists, // Currently we only support same layout for x & y inputs. ASSERT(rowMajorIndex == rowMajorQuery, "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); - - bool sqrt = (distanceType == raft::distance::DistanceType::L2SqrtUnexpanded); - - if (rowMajorIndex) { - value_idx lda = D, ldb = D, ldd = n_index_rows; - fusedL2kNN( - n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, - out_dists, out_inds, k, stream, workspace, worksize); - } else { - // TODO: Add support for column major layout - } + // TODO: Add support for column major layout + ASSERT(rowMajorIndex == true, + "l2Knn: only rowMajor inputs are supported for now."); + + // Even for L2 Sqrt distance case we use non-sqrt version as FAISS bfKNN only support + // non-sqrt metric & some tests in RAFT/cuML (like Linkage) fails if we use L2 sqrt. + constexpr bool sqrt = false; + + size_t worksize = 0, tempWorksize = 0; + rmm::device_uvector workspace(worksize, stream); + value_idx lda = D, ldb = D, ldd = n_index_rows; + + switch (metric) { + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Expanded: + tempWorksize = raft::distance::detail::getWorkspaceSize< + raft::distance::DistanceType::L2Expanded, float, float, float, + value_idx>(query, index, n_query_rows, n_index_rows, D); + worksize = tempWorksize; + workspace.resize(worksize, stream); + fusedL2ExpKnn( + n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, + out_dists, out_inds, k, stream, workspace.data(), worksize); + if (worksize > tempWorksize) { + workspace.resize(worksize, stream); + fusedL2ExpKnn( + n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, + out_dists, out_inds, k, stream, workspace.data(), worksize); + } + break; + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + fusedL2UnexpKnn( + n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, + out_dists, out_inds, k, stream, workspace.data(), worksize); + if (worksize) { + workspace.resize(worksize, stream); + fusedL2UnexpKnn(n_query_rows, n_index_rows, D, lda, ldb, ldd, + query, index, sqrt, out_dists, out_inds, k, + stream, workspace.data(), worksize); + } + break; + default: + printf("only L2 distance metric is supported\n"); + break; + }; } } // namespace detail diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 3a3f0a6513..da1217e3cf 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -271,44 +271,53 @@ void brute_force_knn_impl(std::vector &input, cudaStream_t stream = raft::select_stream(userStream, internalStreams, n_int_streams, i); - switch (metric) { - case raft::distance::DistanceType::Haversine: - - ASSERT(D == 2, - "Haversine distance requires 2 dimensions " - "(latitude / longitude)."); - - haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, - k, stream); - break; - default: - faiss::MetricType m = build_faiss_metric(metric); - - faiss::gpu::StandardGpuResources gpu_res; - - gpu_res.noTempMemory(); - gpu_res.setDefaultStream(device, stream); - - faiss::gpu::GpuDistanceParams args; - args.metric = m; - args.metricArg = metricArg; - args.k = k; - args.dims = D; - args.vectors = input[i]; - args.vectorsRowMajor = rowMajorIndex; - args.numVectors = sizes[i]; - args.queries = search_items; - args.queriesRowMajor = rowMajorQuery; - args.numQueries = n; - args.outDistances = out_d_ptr; - args.outIndices = out_i_ptr; - - /** + if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + (metric == raft::distance::DistanceType::L2Unexpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded)) { + fusedL2Knn(D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, + k, rowMajorIndex, rowMajorQuery, stream, metric); + } else { + switch (metric) { + case raft::distance::DistanceType::Haversine: + + ASSERT(D == 2, + "Haversine distance requires 2 dimensions " + "(latitude / longitude)."); + + haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], + n, k, stream); + break; + default: + faiss::MetricType m = build_faiss_metric(metric); + + faiss::gpu::StandardGpuResources gpu_res; + + gpu_res.noTempMemory(); + gpu_res.setDefaultStream(device, stream); + + faiss::gpu::GpuDistanceParams args; + args.metric = m; + args.metricArg = metricArg; + args.k = k; + args.dims = D; + args.vectors = input[i]; + args.vectorsRowMajor = rowMajorIndex; + args.numVectors = sizes[i]; + args.queries = search_items; + args.queriesRowMajor = rowMajorQuery; + args.numQueries = n; + args.outDistances = out_d_ptr; + args.outIndices = out_i_ptr; + + /** * @todo: Until FAISS supports pluggable allocation strategies, * we will not reap the benefits of the pool allocator for * avoiding device-wide synchronizations from cudaMalloc/cudaFree */ - bfKnn(&gpu_res, args); + bfKnn(&gpu_res, args); + } } CUDA_CHECK(cudaPeekAtLastError()); diff --git a/cpp/include/raft/spatial/knn/knn.hpp b/cpp/include/raft/spatial/knn/knn.hpp index 6472eaa80b..a2e9151dbc 100644 --- a/cpp/include/raft/spatial/knn/knn.hpp +++ b/cpp/include/raft/spatial/knn/knn.hpp @@ -116,7 +116,7 @@ inline void brute_force_knn( std::vector &sizes, int D, float *search_items, int n, int64_t *res_I, float *res_D, int k, bool rowMajorIndex = true, bool rowMajorQuery = true, std::vector *translations = nullptr, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + distance::DistanceType metric = distance::DistanceType::L2Expanded, float metric_arg = 2.0f) { ASSERT(input.size() == sizes.size(), "input and sizes vectors must be the same size"); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 4a89fd3273..14052293cf 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -88,6 +88,7 @@ add_executable(test_raft test/sparse/sort.cu test/sparse/symmetrize.cu test/spatial/knn.cu + test/spatial/fused_l2_knn.cu test/spatial/haversine.cu test/spatial/ball_cover.cu test/spatial/selection.cu diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index c43ce78cbf..ca30506df0 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -17,8 +17,6 @@ #include #include #include -#include -#include #include #include #include "../test_utils.h" @@ -88,30 +86,12 @@ void compute_bfknn(const raft::handle_t &handle, const value_t *X1, std::vector input_vec = {const_cast(X1)}; std::vector sizes_vec = {n}; - if (metric == raft::distance::DistanceType::Haversine) { - cudaStream_t *int_streams = nullptr; - std::vector *translations = nullptr; + cudaStream_t *int_streams = nullptr; + std::vector *translations = nullptr; - raft::spatial::knn::detail::brute_force_knn_impl( - input_vec, sizes_vec, d, const_cast(X2), n, inds, dists, k, - handle.get_stream(), int_streams, 0, true, true, translations, metric); - } else { - size_t worksize = 0; - void *workspace = nullptr; - raft::spatial::knn::detail::l2_unexpanded_knn< - raft::distance::DistanceType::L2SqrtUnexpanded, int64_t, value_t, false>( - (size_t)d, inds, dists, input_vec[0], X2, (size_t)sizes_vec[0], (size_t)n, - (int)k, true, true, handle.get_stream(), workspace, worksize); - if (worksize) { - rmm::device_uvector d_mutexes(worksize, handle.get_stream()); - workspace = d_mutexes.data(); - raft::spatial::knn::detail::l2_unexpanded_knn< - raft::distance::DistanceType::L2SqrtUnexpanded, int64_t, value_t, - false>((size_t)d, inds, dists, input_vec[0], X2, (size_t)sizes_vec[0], - (size_t)n, (int)k, true, true, handle.get_stream(), workspace, - worksize); - } - } + raft::spatial::knn::detail::brute_force_knn_impl( + input_vec, sizes_vec, d, const_cast(X2), n, inds, dists, k, + handle.get_stream(), int_streams, 0, true, true, translations, metric); } struct ToRadians { @@ -226,11 +206,16 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { d_train_inputs.data(), ToRadians()); } + cudaStream_t *int_streams = nullptr; + std::vector *translations = nullptr; + std::vector input_vec = {d_train_inputs.data()}; std::vector sizes_vec = {n}; - compute_bfknn(handle, d_train_inputs.data(), d_train_inputs.data(), n, d, k, - metric, d_ref_D.data(), d_ref_I.data()); + raft::spatial::knn::detail::brute_force_knn_impl( + input_vec, sizes_vec, d, d_train_inputs.data(), n, d_ref_I.data(), + d_ref_D.data(), k, handle.get_stream(), int_streams, 0, true, true, + translations, metric); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); diff --git a/cpp/test/spatial/fused_l2_knn.cu b/cpp/test/spatial/fused_l2_knn.cu new file mode 100644 index 0000000000..4930b47e0c --- /dev/null +++ b/cpp/test/spatial/fused_l2_knn.cu @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +struct FusedL2KNNInputs { + int num_queries; + int num_db_vecs; + int dim; + int k; + raft::distance::DistanceType metric_; +}; + +template +struct idx_dist_pair { + IdxT idx; + DistT dist; + compareDist eq_compare; + bool operator==(const idx_dist_pair &a) const { + if (idx == a.idx) return true; + if (eq_compare(dist, a.dist)) return true; + return false; + } + idx_dist_pair(IdxT x, DistT y, compareDist op) + : idx(x), dist(y), eq_compare(op) {} +}; + +template +testing::AssertionResult devArrMatchKnnPair( + const T *expected_idx, const T *actual_idx, const DistT *expected_dist, + const DistT *actual_dist, size_t rows, size_t cols, const DistT eps, + cudaStream_t stream = 0) { + size_t size = rows * cols; + std::unique_ptr exp_idx_h(new T[size]); + std::unique_ptr act_idx_h(new T[size]); + std::unique_ptr exp_dist_h(new DistT[size]); + std::unique_ptr act_dist_h(new DistT[size]); + raft::update_host(exp_idx_h.get(), expected_idx, size, stream); + raft::update_host(act_idx_h.get(), actual_idx, size, stream); + raft::update_host(exp_dist_h.get(), expected_dist, size, stream); + raft::update_host(act_dist_h.get(), actual_dist, size, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto exp_idx = exp_idx_h.get()[idx]; + auto act_idx = act_idx_h.get()[idx]; + auto exp_dist = exp_dist_h.get()[idx]; + auto act_dist = act_dist_h.get()[idx]; + idx_dist_pair exp_kvp(exp_idx, exp_dist, raft::CompareApprox(eps)); + idx_dist_pair act_kvp(act_idx, act_dist, raft::CompareApprox(eps)); + if (!(exp_kvp == act_kvp)) { + return testing::AssertionFailure() + << "actual=" << act_kvp.idx << "," << act_kvp.dist << "!=" + << "expected" << exp_kvp.idx << "," << exp_kvp.dist << " @" << i + << "," << j; + } + } + } + return testing::AssertionSuccess(); +} + +template +class FusedL2KNNTest : public ::testing::TestWithParam { + protected: + void testBruteForce() { + cudaStream_t stream = handle_.get_stream(); + + launchFaissBfknn(); + detail::fusedL2Knn(dim, raft_indices_, raft_distances_, database, + search_queries, num_db_vecs, num_queries, k_, true, true, + stream, metric); + + // verify. + devArrMatchKnnPair(faiss_indices_, raft_indices_, faiss_distances_, + raft_distances_, num_queries, k_, float(0.001), stream); + } + + void SetUp() override { + params_ = ::testing::TestWithParam::GetParam(); + num_queries = params_.num_queries; + num_db_vecs = params_.num_db_vecs; + dim = params_.dim; + k_ = params_.k; + metric = params_.metric_; + + cudaStream_t stream = handle_.get_stream(); + + raft::allocate(database, num_db_vecs * dim, stream, true); + raft::allocate(search_queries, num_queries * dim, stream, true); + + unsigned long long int seed = 1234ULL; + raft::random::Rng r(seed); + r.uniform(database, num_db_vecs * dim, T(-1.0), T(1.0), stream); + r.uniform(search_queries, num_queries * dim, T(-1.0), T(1.0), stream); + + raft::allocate(raft_indices_, num_queries * k_, stream, true); + raft::allocate(raft_distances_, num_queries * k_, stream, true); + raft::allocate(faiss_indices_, num_queries * k_, stream, true); + raft::allocate(faiss_distances_, num_queries * k_, stream, true); + } + + void TearDown() override { + cudaStream_t stream = handle_.get_stream(); + raft::deallocate_all(stream); + } + + void launchFaissBfknn() { + faiss::MetricType m = detail::build_faiss_metric(metric); + + faiss::gpu::StandardGpuResources gpu_res; + + gpu_res.noTempMemory(); + int device; + CUDA_CHECK(cudaGetDevice(&device)); + gpu_res.setDefaultStream(device, handle_.get_stream()); + + faiss::gpu::GpuDistanceParams args; + args.metric = m; + args.metricArg = 0; + args.k = k_; + args.dims = dim; + args.vectors = database; + args.vectorsRowMajor = true; + args.numVectors = num_db_vecs; + args.queries = search_queries; + args.queriesRowMajor = true; + args.numQueries = num_queries; + args.outDistances = faiss_distances_; + args.outIndices = faiss_indices_; + + bfKnn(&gpu_res, args); + } + + private: + raft::handle_t handle_; + FusedL2KNNInputs params_; + int num_queries; + int num_db_vecs; + int dim; + T *database; + T *search_queries; + int64_t *raft_indices_; + T *raft_distances_; + int64_t *faiss_indices_; + T *faiss_distances_; + int k_; + raft::distance::DistanceType metric; +}; + +const std::vector inputs = { + {100, 1000, 16, 10, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded}, + {100, 1000, 16, 50, raft::distance::DistanceType::L2Expanded}, + {20, 10000, 16, 10, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 50, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 32, 50, raft::distance::DistanceType::L2Expanded}, + {10000, 40000, 32, 30, raft::distance::DistanceType::L2Expanded}, + // L2 unexpanded + {100, 1000, 16, 10, raft::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Unexpanded}, + {100, 1000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, + {20, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 32, 50, raft::distance::DistanceType::L2Unexpanded}, + {10000, 40000, 32, 30, raft::distance::DistanceType::L2Unexpanded}}; + +typedef FusedL2KNNTest FusedL2KNNTestF; +TEST_P(FusedL2KNNTestF, FusedBruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(FusedL2KNNTest, FusedL2KNNTestF, + ::testing::ValuesIn(inputs)); + +} // namespace knn +} // namespace spatial +} // namespace raft