From 75c9f272b66d94aed88549e5b63a0eaff6355fac Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 24 Sep 2021 22:18:44 +0530 Subject: [PATCH 01/22] add fused L2 expanded kNN kernel, this is faster by at least 20-25% on higher dimensions than L2 unexpanded version --- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 179 +++++++++++++++++- .../knn/detail/knn_brute_force_faiss.cuh | 32 +++- cpp/include/raft/spatial/knn/knn.hpp | 2 +- 3 files changed, 202 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 9d00d9b9f4..90791fbbb2 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -18,6 +18,7 @@ #include #include #include +#include #include "processing.hpp" namespace raft { @@ -191,7 +192,16 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( } volatile int *mutex = mutexes; - Pair *shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + + Pair *shDumpKV = nullptr; + if (useNorms) { + shDumpKV = + (Pair *)(&smem[Policy::SmemSize + + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); + } else { + shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + } + const int lid = threadIdx.x % warpSize; const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); @@ -300,6 +310,16 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( [numOfNN, sqrt, m, n, ldd, out_dists, out_inds] __device__( AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT * regxn, DataT * regyn, IdxT gridStrideX, IdxT gridStrideY) { + if (useNorms) { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } + } + } + if (sqrt) { #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { @@ -309,7 +329,14 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( } } } - Pair *shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + Pair *shDumpKV = nullptr; + if (useNorms) { + shDumpKV = + (Pair *)(&smem[Policy::SmemSize + + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); + } else { + shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + } constexpr uint32_t mask = 0xffffffffu; const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); @@ -606,6 +633,154 @@ void l2_unexpanded_knn(size_t D, value_idx *out_inds, value_t *out_dists, } } +template +void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, + IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, OutT *out_dists, + IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, + void *workspace, size_t &worksize) { + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + ASSERT(isRowMajor, "Only Row major inputs are allowed"); + + ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || + (worksize < m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + dim3 blk(KPolicy::Nthreads); + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + acc += x * y; + }; + + auto fin_op = [] __device__(AccT d_val, int g_d_idx) { return d_val; }; + + typedef cub::KeyValuePair Pair; + + if (isRowMajor) { + constexpr auto fusedL2ExpkNN32RowMajor = + fusedL2kNN; + constexpr auto 
fusedL2ExpkNN64RowMajor = + fusedL2kNN; + + auto fusedL2ExpKNNRowMajor = fusedL2ExpkNN32RowMajor; + if (numOfNN <= 32) { + fusedL2ExpKNNRowMajor = fusedL2ExpkNN32RowMajor; + } else if (numOfNN <= 64) { + fusedL2ExpKNNRowMajor = fusedL2ExpkNN64RowMajor; + } else { + ASSERT(numOfNN <= 64, + "fusedL2kNN: num of nearest neighbors must be <= 64"); + } + + const auto sharedMemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + + (KPolicy::Mblk * numOfNN * sizeof(Pair)); + dim3 grid = raft::distance::launchConfigGenerator( + m, n, sharedMemSize, fusedL2ExpKNNRowMajor); + int32_t *mutexes = nullptr; + if (grid.x > 1) { + const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); + const auto normsSize = + (x != y) ? (m + n) * sizeof(DataT) : n * sizeof(DataT); + const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize; + if (worksize < requiredSize) { + worksize = requiredSize; + return; + } else { + mutexes = (int32_t *)((char *)workspace + normsSize); + CUDA_CHECK( + cudaMemsetAsync(mutexes, 0, sizeof(int32_t) * numMutexes, stream)); + } + } + + DataT *xn = (DataT *)workspace; + DataT *yn = (DataT *)workspace; + + auto norm_op = [] __device__(DataT in) { return in; }; + + if (x != y) { + yn += m; + raft::linalg::rowNorm(xn, x, k, m, raft::linalg::L2Norm, isRowMajor, + stream, norm_op); + raft::linalg::rowNorm(yn, y, k, n, raft::linalg::L2Norm, isRowMajor, + stream, norm_op); + } else { + raft::linalg::rowNorm(xn, x, k, n, raft::linalg::L2Norm, isRowMajor, + stream, norm_op); + } + fusedL2ExpKNNRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, core_lambda, fin_op, sqrt, + (uint32_t)numOfNN, mutexes, out_dists, out_inds); + } else { + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void fusedL2ExpkNN(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, bool sqrt, OutT *out_dists, + IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, + void *workspace, size_t &worksize) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + fusedL2ExpKNNImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, + out_inds, numOfNN, stream, workspace, + worksize); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + fusedL2ExpKNNImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, + out_inds, numOfNN, stream, workspace, + worksize); + } else { + fusedL2ExpKNNImpl( + x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, + workspace, worksize); + } +} + +template +void l2_expanded_knn(size_t D, value_idx *out_inds, value_t *out_dists, + const value_t *index, const value_t *query, + size_t n_index_rows, size_t n_query_rows, int k, + bool rowMajorIndex, bool rowMajorQuery, + cudaStream_t stream, void *workspace, size_t &worksize) { + // Validate the input data + ASSERT(k > 0, "l2Knn: k must be > 0"); + ASSERT(D > 0, "l2Knn: D must be > 0"); + ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0"); + ASSERT(index, "l2Knn: index must be provided (passed null)"); + ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0"); + ASSERT(query, "l2Knn: query must be provided (passed null)"); + ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)"); + ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)"); + // Currently we only support same layout for x & y inputs. 
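Stepping back to the epilog above: the 20-25% gain at higher dimensions comes from the algebraic identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, which lets the hot loop accumulate only dot products (a GEMM-like access pattern) while the row norms regxn/regyn are computed once up front by rowNorm. A minimal host-side sketch of that identity, illustrative only and not part of the patch:

  #include <cstdio>
  #include <vector>

  int main() {
    std::vector<float> x{1.f, 2.f, 3.f}, y{4.f, 5.f, 6.f};
    float xn = 0.f, yn = 0.f, dot = 0.f, direct = 0.f;
    for (size_t i = 0; i < x.size(); ++i) {
      xn += x[i] * x[i];   // ||x||^2, precomputed once per row in the kernel
      yn += y[i] * y[i];   // ||y||^2
      dot += x[i] * y[i];  // the only term accumulated in the main loop
      direct += (x[i] - y[i]) * (x[i] - y[i]);  // unexpanded form, for comparison
    }
    // matches the epilog's regxn[i] + regyn[j] - 2 * acc[i][j]
    float expanded = xn + yn - 2.f * dot;
    std::printf("unexpanded=%f expanded=%f\n", direct, expanded);  // both print 27
    return 0;
  }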
+ ASSERT(rowMajorIndex == rowMajorQuery, + "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); + + bool sqrt = (distanceType == raft::distance::DistanceType::L2SqrtExpanded); + + if (rowMajorIndex) { + value_idx lda = D, ldb = D, ldd = n_index_rows; + fusedL2ExpkNN( + n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, + out_dists, out_inds, k, stream, workspace, worksize); + } else { + // TODO: Add support for column major layout + } +} + } // namespace detail } // namespace knn } // namespace spatial diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 94ace19580..9f51f17eb0 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -34,6 +34,7 @@ #include #include +#include #include "fused_l2_knn.cuh" #include "haversine_distance.cuh" #include "processing.hpp" @@ -275,13 +276,29 @@ void brute_force_knn_impl(std::vector &input, std::vector &sizes, metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded)) { - size_t worksize = 0; - void *workspace = nullptr; - + size_t worksize = 0, tempWorksize = 0; + rmm::device_uvector workspace(worksize, stream); switch (metric) { case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2Unexpanded: case raft::distance::DistanceType::L2SqrtExpanded: + tempWorksize = raft::distance::getWorkspaceSize< + raft::distance::DistanceType::L2Expanded, float, float, float, + IntType>(search_items, input[i], n, sizes[i], D); + worksize = tempWorksize; + workspace.resize(worksize, stream); + l2_expanded_knn( + D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, + rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); + if (worksize > tempWorksize) { + workspace.resize(worksize, stream); + l2_expanded_knn( + D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, + rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); + } + break; + case raft::distance::DistanceType::L2Unexpanded: // Even for L2 Sqrt distance case we use non-sqrt version // as FAISS bfKNN only support non-sqrt metric & some tests // in RAFT/cuML (like Linkage) fails if we use L2 sqrt. 
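The call pattern above deserves a note: the first l2_expanded_knn invocation may only report the workspace it needs (fusedL2ExpKNNImpl writes the required size into worksize and returns without launching when the buffer is too small), after which the caller resizes and retries. A simplified sketch of that query-resize-retry protocol, with a hypothetical compute() standing in for l2_expanded_knn:

  #include <cstddef>
  #include <vector>

  // If the buffer is too small the callee reports the required size via the
  // in/out `worksize` parameter and returns early without doing any work.
  void compute(void* workspace, size_t& worksize) {
    const size_t required = 1024;  // e.g. grid mutexes (+ row norms in the expanded case)
    if (workspace == nullptr || worksize < required) {
      worksize = required;
      return;
    }
    // ... launch kernels that use `workspace` ...
  }

  int main() {
    std::vector<char> buf;
    size_t worksize = buf.size();
    compute(buf.data(), worksize);  // first call may just report a size
    if (worksize > buf.size()) {
      buf.resize(worksize);         // grow, then retry for real
      compute(buf.data(), worksize);
    }
    return 0;
  }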
@@ -292,14 +309,13 @@ void brute_force_knn_impl(std::vector &input, std::vector &sizes, l2_unexpanded_knn( D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, - rowMajorIndex, rowMajorQuery, stream, workspace, worksize); + rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); if (worksize) { - rmm::device_uvector d_mutexes(worksize, stream); - workspace = d_mutexes.data(); + workspace.resize(worksize, stream); l2_unexpanded_knn( D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, - rowMajorIndex, rowMajorQuery, stream, workspace, worksize); + rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); } break; default: diff --git a/cpp/include/raft/spatial/knn/knn.hpp b/cpp/include/raft/spatial/knn/knn.hpp index 71c547c281..4b7efecf5f 100644 --- a/cpp/include/raft/spatial/knn/knn.hpp +++ b/cpp/include/raft/spatial/knn/knn.hpp @@ -61,7 +61,7 @@ inline void brute_force_knn( std::vector &sizes, int D, float *search_items, int n, int64_t *res_I, float *res_D, int k, bool rowMajorIndex = true, bool rowMajorQuery = true, std::vector *translations = nullptr, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + distance::DistanceType metric = distance::DistanceType::L2Expanded, float metric_arg = 2.0f) { ASSERT(input.size() == sizes.size(), "input and sizes vectors must be the same size"); From 7a1e1e6c12ddc8cb848ec57262188828359fbc6b Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 27 Sep 2021 13:29:29 +0530 Subject: [PATCH 02/22] use lid > firsActiveLane instead of bitwise left shift and & for updateWarpQ --- cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 90791fbbb2..9e52a7c911 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -144,11 +144,11 @@ DI void updateSortedWarpQ(myWarpSelect &heapArr, Pair *allWarpTopKs, int rowId, Pair tempKV; tempKV.value = raft::shfl(heapArr->warpK[i], srcLane); tempKV.key = raft::shfl(heapArr->warpV[i], srcLane); - const auto firstActiveLane = __ffs(activeLanes); - if (firstActiveLane == (lid + 1)) { + const auto firstActiveLane = __ffs(activeLanes) - 1; + if (firstActiveLane == lid) { heapArr->warpK[i] = KVPair.value; heapArr->warpV[i] = KVPair.key; - } else if (activeLanes & ((uint32_t)1 << lid)) { + } else if (lid > firstActiveLane) { heapArr->warpK[i] = tempKV.value; heapArr->warpV[i] = tempKV.key; } From 290d28d17fdc7f58787462a47c8218384a95401e Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 28 Sep 2021 21:48:12 +0530 Subject: [PATCH 03/22] fix incorrect output for NN >32 case when taking prod-cons knn merge path --- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 9e52a7c911..d877e37e20 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -214,12 +214,11 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( myWarpSelect heapArr1(identity, keyMax, numOfNN); myWarpSelect heapArr2(identity, keyMax, numOfNN); myWarpSelect *heapArr[] = {&heapArr1, &heapArr2}; - __syncthreads(); + __syncwarp(); loadAllWarpQShmem(heapArr, &shDumpKV[0], 
m, numOfNN); while (cta_processed < gridDim.x - 1) { - Pair otherKV[Policy::AccRowsPerTh]; if (threadIdx.x == 0) { int32_t old = -3; @@ -233,12 +232,18 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { const auto rowId = starty + i * Policy::AccThRows; - otherKV[i].value = identity; - otherKV[i].key = keyMax; - - if (lid < numOfNN && rowId < m) { - otherKV[i].value = out_dists[rowId * numOfNN + lid]; - otherKV[i].key = (uint32_t)out_inds[rowId * numOfNN + lid]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; +#pragma unroll + for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN && rowId < m) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } } } __threadfence(); @@ -249,14 +254,26 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( } // Perform merging of otherKV with topk's across warp. + __syncwarp(); + #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { const auto rowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; if (rowId < m) { - heapArr[i]->add(otherKV[i].value, otherKV[i].key); + #pragma unroll + for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); + } } } - cta_processed++; } #pragma unroll From 5f3cea165129faf424635e63e5da7049245b749b Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 28 Sep 2021 21:55:23 +0530 Subject: [PATCH 04/22] fix clang format issues --- cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index d877e37e20..d285330a7a 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -219,7 +219,6 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); while (cta_processed < gridDim.x - 1) { - if (threadIdx.x == 0) { int32_t old = -3; while (old != -1) { @@ -232,7 +231,8 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { const auto rowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto shMemRowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; #pragma unroll for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) { Pair otherKV; @@ -259,9 +259,10 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { const auto rowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto shMemRowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; if (rowId < m) { - #pragma unroll +#pragma unroll for (int j = 0; j < 
heapArr[i]->kNumWarpQRegisters; ++j) { Pair otherKV; otherKV.value = identity; From 5b5f7a0a9a5e6964d56d7e1a5ca247068f4e2d8b Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 28 Sep 2021 22:00:04 +0530 Subject: [PATCH 05/22] enable testing of cuml using this raft fork --- ci/prtest.config | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ci/prtest.config b/ci/prtest.config index 08bdcaa3ab..ab46722186 100644 --- a/ci/prtest.config +++ b/ci/prtest.config @@ -1,6 +1,6 @@ RUN_CUGRAPH_LIBCUGRAPH_TESTS=OFF RUN_CUGRAPH_PYTHON_TESTS=OFF -RUN_CUML_LIBCUML_TESTS=OFF -RUN_CUML_PRIMS_TESTS=OFF -RUN_CUML_PYTHON_TESTS=OFF +RUN_CUML_LIBCUML_TESTS=ON +RUN_CUML_PRIMS_TESTS=ON +RUN_CUML_PYTHON_TESTS=ON From 738c60432247fc564ffcaee13bd0a2ae362e8bd8 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 29 Sep 2021 21:02:40 +0530 Subject: [PATCH 06/22] add custom atomicMax function which works fine if negative zeros are encountered, then the current atomicCAS based implementation --- cpp/include/raft/device_atomics.cuh | 13 ++++++++++++- cpp/include/raft/sparse/op/reduce.cuh | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/device_atomics.cuh b/cpp/include/raft/device_atomics.cuh index dc8093ca1d..a1b9594b67 100644 --- a/cpp/include/raft/device_atomics.cuh +++ b/cpp/include/raft/device_atomics.cuh @@ -423,7 +423,6 @@ struct typesAtomicCASImpl { T_int ret = atomicCAS(reinterpret_cast(addr), type_reinterpret(compare), type_reinterpret(update_value)); - return type_reinterpret(ret); } }; @@ -549,6 +548,18 @@ __forceinline__ __device__ T atomicMax(T* address, T val) { address, val, raft::device_atomics::detail::DeviceMax{}); } + +template +__forceinline__ __device__ T customAtomicMax(T* address, T val) { + float old; + val += T(0.0); + old = (val >= 0) ? 
__int_as_float(atomicMax((int *)address, __float_as_int(val))) : + __uint_as_float(atomicMin((unsigned int *)address, __float_as_uint(val))); + + return old; +} + + /** * @brief Overloads for `atomicCAS` * diff --git a/cpp/include/raft/sparse/op/reduce.cuh b/cpp/include/raft/sparse/op/reduce.cuh index 09a35720fb..57f647feaa 100644 --- a/cpp/include/raft/sparse/op/reduce.cuh +++ b/cpp/include/raft/sparse/op/reduce.cuh @@ -67,7 +67,7 @@ __global__ void max_duplicates_kernel(const value_idx *src_rows, if (tid < nnz) { value_idx idx = index[tid]; - atomicMax(&out_vals[idx], src_vals[tid]); + customAtomicMax(&out_vals[idx], src_vals[tid]); out_rows[idx] = src_rows[tid]; out_cols[idx] = src_cols[tid]; } From 352cc2dc560201ee36d87b49b1a92baac8cdcf29 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 30 Sep 2021 17:59:17 +0530 Subject: [PATCH 07/22] fix hang in raft atomicMax of fp32 when the inputs are NaNs --- cpp/include/raft/device_atomics.cuh | 21 ++++++++++++--------- cpp/include/raft/sparse/op/reduce.cuh | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/device_atomics.cuh b/cpp/include/raft/device_atomics.cuh index d8297cf622..83b54533c3 100644 --- a/cpp/include/raft/device_atomics.cuh +++ b/cpp/include/raft/device_atomics.cuh @@ -179,10 +179,15 @@ struct genericAtomicOperationImpl { __forceinline__ __device__ T operator()(T* addr, T const& update_value, Op op) { using T_int = unsigned int; - T old_value = *addr; T assumed{old_value}; + if (std::is_same{}) { + if (isnan(update_value)) { + return update_value; + } + } + do { assumed = old_value; const T new_value = op(old_value, update_value); @@ -191,7 +196,6 @@ struct genericAtomicOperationImpl { type_reinterpret(assumed), type_reinterpret(new_value)); old_value = type_reinterpret(ret); - } while (assumed != old_value); return old_value; @@ -548,18 +552,17 @@ __forceinline__ __device__ T atomicMax(T* address, T val) { address, val, raft::device_atomics::detail::DeviceMax{}); } - template __forceinline__ __device__ T customAtomicMax(T* address, T val) { - float old; - //val += T(0.0); - old = (val >= 0) ? __int_as_float(atomicMax((int *)address, __float_as_int(val))) : - __uint_as_float(atomicMin((unsigned int *)address, __float_as_uint(val))); + float old; + old = (val >= 0) + ? 
__int_as_float(atomicMax((int*)address, __float_as_int(val))) + : __uint_as_float( + atomicMin((unsigned int*)address, __float_as_uint(val))); - return old; + return old; } - /** * @brief Overloads for `atomicCAS` * diff --git a/cpp/include/raft/sparse/op/reduce.cuh b/cpp/include/raft/sparse/op/reduce.cuh index 57f647feaa..09a35720fb 100644 --- a/cpp/include/raft/sparse/op/reduce.cuh +++ b/cpp/include/raft/sparse/op/reduce.cuh @@ -67,7 +67,7 @@ __global__ void max_duplicates_kernel(const value_idx *src_rows, if (tid < nnz) { value_idx idx = index[tid]; - customAtomicMax(&out_vals[idx], src_vals[tid]); + atomicMax(&out_vals[idx], src_vals[tid]); out_rows[idx] = src_rows[tid]; out_cols[idx] = src_cols[tid]; } From aa8ef096aa6dd9915ae8a8534dbd0142c35562c5 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 5 Oct 2021 20:43:33 +0530 Subject: [PATCH 08/22] remove redundant processing.hpp included in fused_l2_knn --- cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index d285330a7a..422e248ea3 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -19,7 +19,6 @@ #include #include #include -#include "processing.hpp" namespace raft { namespace spatial { From 6072281734b7a0882b97d2c698c7ec5a01a2ba6a Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 6 Oct 2021 12:54:41 +0530 Subject: [PATCH 09/22] refactor fused L2 KNN main function to call both L2 expanded/unexpanded. make function namings consistent --- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 184 ++++++++---------- .../knn/detail/knn_brute_force_faiss.cuh | 31 ++- cpp/test/spatial/ball_cover.cu | 4 +- 3 files changed, 102 insertions(+), 117 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 422e248ea3..9412898bca 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -514,10 +514,11 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( template -void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, - IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, OutT *out_dists, - IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, - void *workspace, size_t &worksize) { +void fusedL2UnexpKnnImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, + IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, + OutT *out_dists, IdxT *out_inds, IdxT numOfNN, + cudaStream_t stream, void *workspace, + size_t &worksize) { typedef typename raft::linalg::Policy2x8::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; @@ -537,25 +538,25 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef cub::KeyValuePair Pair; if (isRowMajor) { - constexpr auto fusedL2kNN32RowMajor = + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2kNN64RowMajor = + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; - auto fusedL2kNNRowMajor = fusedL2kNN32RowMajor; + auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; if (numOfNN <= 32) { - fusedL2kNNRowMajor = fusedL2kNN32RowMajor; + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; } else if (numOfNN <= 64) { - fusedL2kNNRowMajor = fusedL2kNN64RowMajor; + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn64RowMajor; } else { 
ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } dim3 grid = raft::distance::launchConfigGenerator( - m, n, KPolicy::SmemSize, fusedL2kNNRowMajor); + m, n, KPolicy::SmemSize, fusedL2UnexpKnnRowMajor); if (grid.x > 1) { const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) { @@ -570,7 +571,7 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, const auto sharedMemSize = KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - fusedL2kNNRowMajor<<>>( + fusedL2UnexpKnnRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, core_lambda, fin_op, sqrt, (uint32_t)numOfNN, (int *)workspace, out_dists, out_inds); } else { @@ -581,78 +582,32 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, template -void fusedL2kNN(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, - const DataT *x, const DataT *y, bool sqrt, OutT *out_dists, - IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, - void *workspace, size_t &worksize) { +void fusedL2UnexpKnn(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, bool sqrt, OutT *out_dists, + IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, + void *workspace, size_t &worksize) { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - fusedL2kNNImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, - out_inds, numOfNN, stream, workspace, worksize); + fusedL2UnexpKnnImpl( + x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, + workspace, worksize); } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - fusedL2kNNImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, - out_inds, numOfNN, stream, workspace, worksize); - } else { - fusedL2kNNImpl( + fusedL2UnexpKnnImpl( x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, workspace, worksize); - } -} - -/** - * Compute the k-nearest neighbors using L2 unexpanded distance. - - * @tparam value_idx - * @tparam value_t - * @param[out] out_inds output indices array on device (size n_query_rows * k) - * @param[out] out_dists output dists array on device (size n_query_rows * k) - * @param[in] index input index array on device (size n_index_rows * D) - * @param[in] query input query array on device (size n_query_rows * D) - * @param[in] n_index_rows number of rows in index array - * @param[in] n_query_rows number of rows in query array - * @param[in] k number of closest neighbors to return - * @param[in] rowMajorIndex are the index arrays in row-major layout? - * @param[in] rowMajorQuery are the query array in row-major layout? 
- * @param[in] stream stream to order kernel launch - */ -template -void l2_unexpanded_knn(size_t D, value_idx *out_inds, value_t *out_dists, - const value_t *index, const value_t *query, - size_t n_index_rows, size_t n_query_rows, int k, - bool rowMajorIndex, bool rowMajorQuery, - cudaStream_t stream, void *workspace, size_t &worksize) { - // Validate the input data - ASSERT(k > 0, "l2Knn: k must be > 0"); - ASSERT(D > 0, "l2Knn: D must be > 0"); - ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0"); - ASSERT(index, "l2Knn: index must be provided (passed null)"); - ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0"); - ASSERT(query, "l2Knn: query must be provided (passed null)"); - ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)"); - ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)"); - // Currently we only support same layout for x & y inputs. - ASSERT(rowMajorIndex == rowMajorQuery, - "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); - - bool sqrt = (distanceType == raft::distance::DistanceType::L2SqrtUnexpanded); - - if (rowMajorIndex) { - value_idx lda = D, ldb = D, ldd = n_index_rows; - fusedL2kNN( - n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, - out_dists, out_inds, k, stream, workspace, worksize); } else { - // TODO: Add support for column major layout + fusedL2UnexpKnnImpl( + x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, + workspace, worksize); } } template -void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, +void fusedL2ExpKnnImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, OutT *out_dists, IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, void *workspace, size_t &worksize) { @@ -679,18 +634,18 @@ void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef cub::KeyValuePair Pair; if (isRowMajor) { - constexpr auto fusedL2ExpkNN32RowMajor = + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpkNN64RowMajor = + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; - auto fusedL2ExpKNNRowMajor = fusedL2ExpkNN32RowMajor; + auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; if (numOfNN <= 32) { - fusedL2ExpKNNRowMajor = fusedL2ExpkNN32RowMajor; + fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; } else if (numOfNN <= 64) { - fusedL2ExpKNNRowMajor = fusedL2ExpkNN64RowMajor; + fusedL2ExpKnnRowMajor = fusedL2ExpKnn64RowMajor; } else { ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); @@ -700,7 +655,7 @@ void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::launchConfigGenerator( - m, n, sharedMemSize, fusedL2ExpKNNRowMajor); + m, n, sharedMemSize, fusedL2ExpKnnRowMajor); int32_t *mutexes = nullptr; if (grid.x > 1) { const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); @@ -732,7 +687,7 @@ void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, raft::linalg::rowNorm(xn, x, k, n, raft::linalg::L2Norm, isRowMajor, stream, norm_op); } - fusedL2ExpKNNRowMajor<<>>( + fusedL2ExpKnnRowMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, core_lambda, fin_op, sqrt, (uint32_t)numOfNN, mutexes, out_dists, out_inds); } else { @@ -743,36 +698,52 @@ void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, template 
-void fusedL2ExpkNN(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, +void fusedL2ExpKnn(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, const DataT *x, const DataT *y, bool sqrt, OutT *out_dists, IdxT *out_inds, IdxT numOfNN, cudaStream_t stream, void *workspace, size_t &worksize) { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - fusedL2ExpKNNImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, workspace, worksize); } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - fusedL2ExpKNNImpl(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, workspace, worksize); } else { - fusedL2ExpKNNImpl( + fusedL2ExpKnnImpl( x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream, workspace, worksize); } } +/** + * Compute the k-nearest neighbors using L2 expanded/unexpanded distance. + + * @tparam value_idx + * @tparam value_t + * @param[out] out_inds output indices array on device (size n_query_rows * k) + * @param[out] out_dists output dists array on device (size n_query_rows * k) + * @param[in] index input index array on device (size n_index_rows * D) + * @param[in] query input query array on device (size n_query_rows * D) + * @param[in] n_index_rows number of rows in index array + * @param[in] n_query_rows number of rows in query array + * @param[in] k number of closest neighbors to return + * @param[in] rowMajorIndex are the index arrays in row-major layout? + * @param[in] rowMajorQuery are the query array in row-major layout? + * @param[in] stream stream to order kernel launch + */ template -void l2_expanded_knn(size_t D, value_idx *out_inds, value_t *out_dists, - const value_t *index, const value_t *query, - size_t n_index_rows, size_t n_query_rows, int k, - bool rowMajorIndex, bool rowMajorQuery, - cudaStream_t stream, void *workspace, size_t &worksize) { +void fusedL2Knn(size_t D, value_idx *out_inds, value_t *out_dists, + const value_t *index, const value_t *query, size_t n_index_rows, + size_t n_query_rows, int k, bool rowMajorIndex, + bool rowMajorQuery, cudaStream_t stream, void *workspace, + size_t &worksize) { // Validate the input data ASSERT(k > 0, "l2Knn: k must be > 0"); ASSERT(D > 0, "l2Knn: D must be > 0"); @@ -785,17 +756,34 @@ void l2_expanded_knn(size_t D, value_idx *out_inds, value_t *out_dists, // Currently we only support same layout for x & y inputs. ASSERT(rowMajorIndex == rowMajorQuery, "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); - - bool sqrt = (distanceType == raft::distance::DistanceType::L2SqrtExpanded); - - if (rowMajorIndex) { - value_idx lda = D, ldb = D, ldd = n_index_rows; - fusedL2ExpkNN( - n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, - out_dists, out_inds, k, stream, workspace, worksize); - } else { - // TODO: Add support for column major layout - } + // TODO: Add support for column major layout + ASSERT(rowMajorIndex == true, + "l2Knn: only rowMajor inputs are supported for now."); + + // Even for L2 Sqrt distance case we use non-sqrt version as FAISS bfKNN only support + // non-sqrt metric & some tests in RAFT/cuML (like Linkage) fails if we use L2 sqrt. 
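Worth spelling out why returning squared distances is safe for kNN even when the caller asked for an L2Sqrt metric, as the comment above notes: sqrt is strictly increasing on [0, inf), so ranking candidates by squared distance produces exactly the same neighbor ordering; only the reported distance values differ. A small illustrative check, not part of the patch:

  #include <algorithm>
  #include <array>
  #include <cmath>
  #include <cstdio>

  int main() {
    std::array<float, 4> sq{9.f, 1.f, 16.f, 4.f};  // squared L2 distances
    std::array<int, 4> by_sq{0, 1, 2, 3}, by_root{0, 1, 2, 3};
    std::sort(by_sq.begin(), by_sq.end(),
              [&](int a, int b) { return sq[a] < sq[b]; });
    std::sort(by_root.begin(), by_root.end(),
              [&](int a, int b) { return std::sqrt(sq[a]) < std::sqrt(sq[b]); });
    // sqrt is monotonic, so both comparators induce the same order: 1 3 0 2
    std::printf("orderings match: %s\n", by_sq == by_root ? "yes" : "no");
    return 0;
  }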
+ bool sqrt = + (distanceType == raft::distance::DistanceType::L2SqrtUnexpanded) || + (distanceType == raft::distance::DistanceType::L2SqrtExpanded); + + value_idx lda = D, ldb = D, ldd = n_index_rows; + switch (distanceType) { + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2Expanded: + fusedL2ExpKnn( + n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, + out_dists, out_inds, k, stream, workspace, worksize); + break; + case raft::distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + fusedL2UnexpKnn( + n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, + out_dists, out_inds, k, stream, workspace, worksize); + break; + default: + printf("only L2 distance metric is supported\n"); + break; + }; } } // namespace detail diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index c08dd76640..882c5e12ed 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -287,34 +287,31 @@ void brute_force_knn_impl(std::vector &input, IntType>(search_items, input[i], n, sizes[i], D); worksize = tempWorksize; workspace.resize(worksize, stream); - l2_expanded_knn( - D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, - rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); + fusedL2Knn(D, out_i_ptr, out_d_ptr, input[i], search_items, + sizes[i], n, k, rowMajorIndex, rowMajorQuery, + stream, workspace.data(), worksize); if (worksize > tempWorksize) { workspace.resize(worksize, stream); - l2_expanded_knn( - D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, - rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); + fusedL2Knn(D, out_i_ptr, out_d_ptr, input[i], search_items, + sizes[i], n, k, rowMajorIndex, rowMajorQuery, + stream, workspace.data(), worksize); } break; case raft::distance::DistanceType::L2Unexpanded: // Even for L2 Sqrt distance case we use non-sqrt version // as FAISS bfKNN only support non-sqrt metric & some tests // in RAFT/cuML (like Linkage) fails if we use L2 sqrt. - // Even for L2 Sqrt distance case we use non-sqrt version - // as FAISS bfKNN only support non-sqrt metric & some tests - // in RAFT/cuML (like Linkage) fails if we use L2 sqrt. 
case raft::distance::DistanceType::L2SqrtUnexpanded: - l2_unexpanded_knn( - D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, - rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); + fusedL2Knn(D, out_i_ptr, out_d_ptr, input[i], search_items, + sizes[i], n, k, rowMajorIndex, rowMajorQuery, + stream, workspace.data(), worksize); if (worksize) { workspace.resize(worksize, stream); - l2_unexpanded_knn( + fusedL2Knn( D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); } diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index c43ce78cbf..53c292a887 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -98,14 +98,14 @@ void compute_bfknn(const raft::handle_t &handle, const value_t *X1, } else { size_t worksize = 0; void *workspace = nullptr; - raft::spatial::knn::detail::l2_unexpanded_knn< + raft::spatial::knn::detail::fusedL2Knn< raft::distance::DistanceType::L2SqrtUnexpanded, int64_t, value_t, false>( (size_t)d, inds, dists, input_vec[0], X2, (size_t)sizes_vec[0], (size_t)n, (int)k, true, true, handle.get_stream(), workspace, worksize); if (worksize) { rmm::device_uvector d_mutexes(worksize, handle.get_stream()); workspace = d_mutexes.data(); - raft::spatial::knn::detail::l2_unexpanded_knn< + raft::spatial::knn::detail::fusedL2Knn< raft::distance::DistanceType::L2SqrtUnexpanded, int64_t, value_t, false>((size_t)d, inds, dists, input_vec[0], X2, (size_t)sizes_vec[0], (size_t)n, (int)k, true, true, handle.get_stream(), workspace, From ae14f75006938695f65af3808874488fdafe567c Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 6 Oct 2021 14:10:30 +0530 Subject: [PATCH 10/22] revert ball cover test to use brute_force_knn function instead of explicit fusedL2KNN function call --- cpp/test/spatial/ball_cover.cu | 57 ++++++++++------------------------ 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 53c292a887..66f8c95464 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -17,8 +17,6 @@ #include #include #include -#include -#include #include #include #include "../test_utils.h" @@ -80,40 +78,6 @@ uint32_t count_discrepancies(value_idx *actual_idx, value_idx *expected_idx, return result; } -template -void compute_bfknn(const raft::handle_t &handle, const value_t *X1, - const value_t *X2, uint32_t n, uint32_t d, uint32_t k, - const raft::distance::DistanceType metric, value_t *dists, - int64_t *inds) { - std::vector input_vec = {const_cast(X1)}; - std::vector sizes_vec = {n}; - - if (metric == raft::distance::DistanceType::Haversine) { - cudaStream_t *int_streams = nullptr; - std::vector *translations = nullptr; - - raft::spatial::knn::detail::brute_force_knn_impl( - input_vec, sizes_vec, d, const_cast(X2), n, inds, dists, k, - handle.get_stream(), int_streams, 0, true, true, translations, metric); - } else { - size_t worksize = 0; - void *workspace = nullptr; - raft::spatial::knn::detail::fusedL2Knn< - raft::distance::DistanceType::L2SqrtUnexpanded, int64_t, value_t, false>( - (size_t)d, inds, dists, input_vec[0], X2, (size_t)sizes_vec[0], (size_t)n, - (int)k, true, true, handle.get_stream(), workspace, worksize); - if (worksize) { - rmm::device_uvector d_mutexes(worksize, handle.get_stream()); - workspace = d_mutexes.data(); - raft::spatial::knn::detail::fusedL2Knn< - raft::distance::DistanceType::L2SqrtUnexpanded, 
int64_t, value_t, - false>((size_t)d, inds, dists, input_vec[0], X2, (size_t)sizes_vec[0], - (size_t)n, (int)k, true, true, handle.get_stream(), workspace, - worksize); - } - } -} - struct ToRadians { __device__ __host__ float operator()(float a) { return a * (CUDART_PI_F / 180.0); @@ -155,8 +119,16 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { d_train_inputs.data(), ToRadians()); } - compute_bfknn(handle, d_train_inputs.data(), d_train_inputs.data(), n, d, k, - metric, d_ref_D.data(), d_ref_I.data()); + cudaStream_t *int_streams = nullptr; + std::vector *translations = nullptr; + + std::vector input_vec = {d_train_inputs.data()}; + std::vector sizes_vec = {n}; + + raft::spatial::knn::detail::brute_force_knn_impl( + input_vec, sizes_vec, d, d_train_inputs.data(), n, d_ref_I.data(), + d_ref_D.data(), k, handle.get_stream(), int_streams, 0, true, true, + translations, metric); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); @@ -226,11 +198,16 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { d_train_inputs.data(), ToRadians()); } + cudaStream_t *int_streams = nullptr; + std::vector *translations = nullptr; + std::vector input_vec = {d_train_inputs.data()}; std::vector sizes_vec = {n}; - compute_bfknn(handle, d_train_inputs.data(), d_train_inputs.data(), n, d, k, - metric, d_ref_D.data(), d_ref_I.data()); + raft::spatial::knn::detail::brute_force_knn_impl( + input_vec, sizes_vec, d, d_train_inputs.data(), n, d_ref_I.data(), + d_ref_D.data(), k, handle.get_stream(), int_streams, 0, true, true, + translations, metric); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); From 53b6415e99ef96cc02d983596572dffa149fad35 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 7 Oct 2021 16:26:47 +0530 Subject: [PATCH 11/22] use isnan only if DeviceMax/Min operations in atomicCAS based function, make customAtomicMax float only by removing the template as it is float specific function --- cpp/include/raft/device_atomics.cuh | 7 ++++--- cpp/include/raft/sparse/op/reduce.cuh | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/device_atomics.cuh b/cpp/include/raft/device_atomics.cuh index 83b54533c3..5feb38bb03 100644 --- a/cpp/include/raft/device_atomics.cuh +++ b/cpp/include/raft/device_atomics.cuh @@ -182,7 +182,8 @@ struct genericAtomicOperationImpl { T old_value = *addr; T assumed{old_value}; - if (std::is_same{}) { + if (std::is_same{} && (std::is_same{} || + std::is_same{})) { if (isnan(update_value)) { return update_value; } @@ -552,8 +553,8 @@ __forceinline__ __device__ T atomicMax(T* address, T val) { address, val, raft::device_atomics::detail::DeviceMax{}); } -template -__forceinline__ __device__ T customAtomicMax(T* address, T val) { +// fp32 only atomicMax. +__forceinline__ __device__ float customAtomicMax(float* address, float val) { float old; old = (val >= 0) ? 
__int_as_float(atomicMax((int*)address, __float_as_int(val))) diff --git a/cpp/include/raft/sparse/op/reduce.cuh b/cpp/include/raft/sparse/op/reduce.cuh index 09a35720fb..57f647feaa 100644 --- a/cpp/include/raft/sparse/op/reduce.cuh +++ b/cpp/include/raft/sparse/op/reduce.cuh @@ -67,7 +67,7 @@ __global__ void max_duplicates_kernel(const value_idx *src_rows, if (tid < nnz) { value_idx idx = index[tid]; - atomicMax(&out_vals[idx], src_vals[tid]); + customAtomicMax(&out_vals[idx], src_vals[tid]); out_rows[idx] = src_rows[tid]; out_cols[idx] = src_cols[tid]; } From 1d9ade3937406f277e00c45fba435894a2e78c28 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 7 Oct 2021 16:55:19 +0530 Subject: [PATCH 12/22] fix clang format issues --- cpp/include/raft/device_atomics.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/device_atomics.cuh b/cpp/include/raft/device_atomics.cuh index 5feb38bb03..d706595486 100644 --- a/cpp/include/raft/device_atomics.cuh +++ b/cpp/include/raft/device_atomics.cuh @@ -182,8 +182,8 @@ struct genericAtomicOperationImpl { T old_value = *addr; T assumed{old_value}; - if (std::is_same{} && (std::is_same{} || - std::is_same{})) { + if (std::is_same{} && + (std::is_same{} || std::is_same{})) { if (isnan(update_value)) { return update_value; } From 62cff7bc81483114427cb4d5e2b233fc805b0d22 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 11 Oct 2021 21:52:20 +0530 Subject: [PATCH 13/22] revert prtest.config changes, move fusedL2kNN launch/selection code to separate function which is now part of fused_l2_knn.cuh --- ci/prtest.config | 6 +-- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 37 +++++++++++----- .../knn/detail/knn_brute_force_faiss.cuh | 44 +------------------ 3 files changed, 32 insertions(+), 55 deletions(-) diff --git a/ci/prtest.config b/ci/prtest.config index ab46722186..08bdcaa3ab 100644 --- a/ci/prtest.config +++ b/ci/prtest.config @@ -1,6 +1,6 @@ RUN_CUGRAPH_LIBCUGRAPH_TESTS=OFF RUN_CUGRAPH_PYTHON_TESTS=OFF -RUN_CUML_LIBCUML_TESTS=ON -RUN_CUML_PRIMS_TESTS=ON -RUN_CUML_PYTHON_TESTS=ON +RUN_CUML_LIBCUML_TESTS=OFF +RUN_CUML_PRIMS_TESTS=OFF +RUN_CUML_PYTHON_TESTS=OFF diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 9412898bca..49b30dfe2e 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -737,13 +737,12 @@ void fusedL2ExpKnn(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, * @param[in] rowMajorQuery are the query array in row-major layout? * @param[in] stream stream to order kernel launch */ -template +template void fusedL2Knn(size_t D, value_idx *out_inds, value_t *out_dists, const value_t *index, const value_t *query, size_t n_index_rows, size_t n_query_rows, int k, bool rowMajorIndex, - bool rowMajorQuery, cudaStream_t stream, void *workspace, - size_t &worksize) { + bool rowMajorQuery, cudaStream_t stream, + raft::distance::DistanceType metric) { // Validate the input data ASSERT(k > 0, "l2Knn: k must be > 0"); ASSERT(D > 0, "l2Knn: D must be > 0"); @@ -762,23 +761,41 @@ void fusedL2Knn(size_t D, value_idx *out_inds, value_t *out_dists, // Even for L2 Sqrt distance case we use non-sqrt version as FAISS bfKNN only support // non-sqrt metric & some tests in RAFT/cuML (like Linkage) fails if we use L2 sqrt. 
- bool sqrt = - (distanceType == raft::distance::DistanceType::L2SqrtUnexpanded) || - (distanceType == raft::distance::DistanceType::L2SqrtExpanded); + bool sqrt = (metric == raft::distance::DistanceType::L2SqrtUnexpanded) || + (metric == raft::distance::DistanceType::L2SqrtExpanded); + size_t worksize = 0, tempWorksize = 0; + rmm::device_uvector workspace(worksize, stream); value_idx lda = D, ldb = D, ldd = n_index_rows; - switch (distanceType) { + + switch (metric) { case raft::distance::DistanceType::L2SqrtExpanded: case raft::distance::DistanceType::L2Expanded: + tempWorksize = raft::distance::getWorkspaceSize< + raft::distance::DistanceType::L2Expanded, float, float, float, + value_idx>(query, index, n_query_rows, n_index_rows, D); + worksize = tempWorksize; + workspace.resize(worksize, stream); fusedL2ExpKnn( n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, - out_dists, out_inds, k, stream, workspace, worksize); + out_dists, out_inds, k, stream, workspace.data(), worksize); + if (worksize > tempWorksize) { + fusedL2ExpKnn( + n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, + out_dists, out_inds, k, stream, workspace.data(), worksize); + } break; case raft::distance::DistanceType::L2Unexpanded: case raft::distance::DistanceType::L2SqrtUnexpanded: fusedL2UnexpKnn( n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, - out_dists, out_inds, k, stream, workspace, worksize); + out_dists, out_inds, k, stream, workspace.data(), worksize); + if (worksize) { + fusedL2UnexpKnn(n_query_rows, n_index_rows, D, lda, ldb, ldd, + query, index, sqrt, out_dists, out_inds, k, + stream, workspace.data(), worksize); + } break; default: printf("only L2 distance metric is supported\n"); diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 882c5e12ed..04c16970d6 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -277,48 +277,8 @@ void brute_force_knn_impl(std::vector &input, metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded)) { - size_t worksize = 0, tempWorksize = 0; - rmm::device_uvector workspace(worksize, stream); - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: - tempWorksize = raft::distance::getWorkspaceSize< - raft::distance::DistanceType::L2Expanded, float, float, float, - IntType>(search_items, input[i], n, sizes[i], D); - worksize = tempWorksize; - workspace.resize(worksize, stream); - fusedL2Knn(D, out_i_ptr, out_d_ptr, input[i], search_items, - sizes[i], n, k, rowMajorIndex, rowMajorQuery, - stream, workspace.data(), worksize); - if (worksize > tempWorksize) { - workspace.resize(worksize, stream); - fusedL2Knn(D, out_i_ptr, out_d_ptr, input[i], search_items, - sizes[i], n, k, rowMajorIndex, rowMajorQuery, - stream, workspace.data(), worksize); - } - break; - case raft::distance::DistanceType::L2Unexpanded: - // Even for L2 Sqrt distance case we use non-sqrt version - // as FAISS bfKNN only support non-sqrt metric & some tests - // in RAFT/cuML (like Linkage) fails if we use L2 sqrt. 
- case raft::distance::DistanceType::L2SqrtUnexpanded: - fusedL2Knn(D, out_i_ptr, out_d_ptr, input[i], search_items, - sizes[i], n, k, rowMajorIndex, rowMajorQuery, - stream, workspace.data(), worksize); - if (worksize) { - workspace.resize(worksize, stream); - fusedL2Knn( - D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, - rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize); - } - break; - default: - break; - } + fusedL2Knn(D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, + k, rowMajorIndex, rowMajorQuery, stream, metric); } else { switch (metric) { case raft::distance::DistanceType::Haversine: From 9164a6403114ffc29e19019a11db7f6241528c43 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 13 Oct 2021 19:45:39 +0530 Subject: [PATCH 14/22] fix bug in updateSortedWarpQ for NN > 32, disable use of sqrt as it is not used to mimic faiss, fix issues in deviceMax atomic to filter NaNs --- cpp/include/raft/device_atomics.cuh | 8 +- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 272 +++++++++--------- 2 files changed, 138 insertions(+), 142 deletions(-) diff --git a/cpp/include/raft/device_atomics.cuh b/cpp/include/raft/device_atomics.cuh index d706595486..7fa17f602d 100644 --- a/cpp/include/raft/device_atomics.cuh +++ b/cpp/include/raft/device_atomics.cuh @@ -185,7 +185,7 @@ struct genericAtomicOperationImpl { if (std::is_same{} && (std::is_same{} || std::is_same{})) { if (isnan(update_value)) { - return update_value; + return old_value; } } @@ -556,6 +556,12 @@ __forceinline__ __device__ T atomicMax(T* address, T val) { // fp32 only atomicMax. __forceinline__ __device__ float customAtomicMax(float* address, float val) { float old; + + if (isnan(val)) { + // if NaN input, simply return value at address. + return *address; + } + old = (val >= 0) ? 
__int_as_float(atomicMax((int*)address, __float_as_int(val))) : __uint_as_float( diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 49b30dfe2e..acb63d4b4d 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -152,12 +152,12 @@ DI void updateSortedWarpQ(myWarpSelect &heapArr, Pair *allWarpTopKs, int rowId, heapArr->warpV[i] = tempKV.key; } if (i == 0 && NumWarpQRegs > 1) { + heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); + heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); if (lid == 0) { heapArr->warpK[1] = tempKV.value; heapArr->warpV[1] = tempKV.key; } - heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); - heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); break; } } @@ -323,185 +323,176 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( }; // epilogue operation lambda for final value calculation - auto epilog_lambda = - [numOfNN, sqrt, m, n, ldd, out_dists, out_inds] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT * regxn, - DataT * regyn, IdxT gridStrideX, IdxT gridStrideY) { - if (useNorms) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - } - - if (sqrt) { + auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { + if (useNorms) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); - } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; } } - Pair *shDumpKV = nullptr; - if (useNorms) { - shDumpKV = - (Pair *)(&smem[Policy::SmemSize + - ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); - } else { - shDumpKV = (Pair *)(&smem[Policy::SmemSize]); - } + } - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); + Pair *shDumpKV = nullptr; + if (useNorms) { + constexpr size_t shmemSize = + Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + shDumpKV = (Pair *)(&smem[shmemSize]); + } else { + shDumpKV = (Pair *)(&smem[Policy::SmemSize]); + } - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect *heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, - numOfNN, starty); - } + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); + + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect *heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, + 
numOfNN, starty); } + } - if (gridStrideX > blockIdx.x * Policy::Nblk) { + if (gridStrideX > blockIdx.x * Policy::Nblk) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = - (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - numValsWarpTopK[i]++; - } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + numValsWarpTopK[i]++; } } - anyWarpTopKs += numValsWarpTopK[i]; } + anyWarpTopKs += numValsWarpTopK[i]; } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair *allWarpTopKs = (Pair *)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; + } + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair *allWarpTopKs = (Pair *)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { #pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = - __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { - numValsWarpTopK[i] += n; - } + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = + __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { + numValsWarpTopK[i] += n; } } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; } + // As each thread will know its total vals to write. + // we only store its starting location. 
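The __shfl_up_sync loop just above is a warp-wide Kogge-Stone inclusive prefix sum over each lane's candidate count, and the `numValsWarpTopK[i] -= myVals;` line that follows converts it to an exclusive scan, so every lane knows the starting slot for its writes into allWarpTopKs. A standalone sketch of that building block, illustrative only:

  #include <cstdio>

  // Inclusive warp scan in log2(32) shuffle steps, then subtract the lane's
  // own contribution to get the exclusive scan, i.e. its starting offset.
  __global__ void warp_exclusive_scan(const int* in, int* out) {
    const int lid = threadIdx.x & 31;
    const int myVals = in[lid];
    int sum = myVals;
  #pragma unroll
    for (unsigned k = 1; k <= 16; k *= 2) {
      const int n = __shfl_up_sync(0xffffffffu, sum, k);
      if (lid >= k) sum += n;
    }
    out[lid] = sum - myVals;  // exclusive scan: where this lane starts writing
  }

  int main() {
    int h_in[32], h_out[32];
    for (int i = 0; i < 32; ++i) h_in[i] = 1;  // every lane has one value to write
    int *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warp_exclusive_scan<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("lane 5 offset = %d\n", h_out[5]);  // prints 5
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
  }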
+ numValsWarpTopK[i] -= myVals; + } - if (needScanSort[i]) { - const auto rowId = - (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { + if (needScanSort[i]) { + const auto rowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = - otherKV; - numValsWarpTopK[i]++; - } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = + otherKV; + numValsWarpTopK[i]++; } } } - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, - numOfNN); - updateSortedWarpQkNumWarpQRegisters>( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, + numOfNN); + updateSortedWarpQkNumWarpQRegisters>( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } } - __syncthreads(); + } + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = - (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, - numOfNN); - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + storeWarpQShmem(heapArr[i], shDumpKV, rowId, + numOfNN); } } } - } else { + } + } else { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = - (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = + (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; - } - heapArr[i]->add(otherKV.value, otherKV.key); + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; } + heapArr[i]->add(otherKV.value, otherKV.key); + } - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { - heapArr[i]->reduce(); - } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, - numOfNN); + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { + heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, + numOfNN); } } + } - if (((gridStrideX + 
Policy::Nblk * gridDim.x) > n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, - starty); - } - }; + if (((gridStrideX + Policy::Nblk * gridDim.x) > n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, + starty); + } + }; raft::distance::PairwiseDistances workspace(worksize, stream); From abc2b1196175ccd89392e357494bb40f31bf17a8 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 13 Oct 2021 22:09:33 +0530 Subject: [PATCH 15/22] allocate workspace when resize is required for using prod-cons mutexes --- cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index acb63d4b4d..600087e383 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -770,6 +770,7 @@ void fusedL2Knn(size_t D, value_idx *out_inds, value_t *out_dists, n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, out_dists, out_inds, k, stream, workspace.data(), worksize); if (worksize > tempWorksize) { + workspace.resize(worksize, stream); fusedL2ExpKnn( n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, out_dists, out_inds, k, stream, workspace.data(), worksize); @@ -781,6 +782,7 @@ void fusedL2Knn(size_t D, value_idx *out_inds, value_t *out_dists, n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, out_dists, out_inds, k, stream, workspace.data(), worksize); if (worksize) { + workspace.resize(worksize, stream); fusedL2UnexpKnn(n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt, out_dists, out_inds, k, From ec0cc32af4259f18002ae5668180ca733bef4fa1 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 2 Nov 2021 16:35:47 +0530 Subject: [PATCH 16/22] add unit test for fused L2 KNN exp/unexp cases using faiss bfknn as gold output --- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 8 +- cpp/test/CMakeLists.txt | 1 + cpp/test/spatial/fused_l2_knn.cu | 164 ++++++++++++++++++ cpp/test/test_utils.h | 31 ++++ 4 files changed, 200 insertions(+), 4 deletions(-) create mode 100644 cpp/test/spatial/fused_l2_knn.cu diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 600087e383..3b0e69a50e 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -546,8 +546,11 @@ void fusedL2UnexpKnnImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, "fusedL2kNN: num of nearest neighbors must be <= 64"); } + const auto sharedMemSize = + KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::launchConfigGenerator( - m, n, KPolicy::SmemSize, fusedL2UnexpKnnRowMajor); + m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); + if (grid.x > 1) { const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) { @@ -559,9 +562,6 @@ void fusedL2UnexpKnnImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, } } - const auto sharedMemSize = - KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - fusedL2UnexpKnnRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, 
ldb, ldd, core_lambda, fin_op, sqrt, (uint32_t)numOfNN, (int *)workspace, out_dists, out_inds); diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 43e1c65695..d77b11a7a9 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -87,6 +87,7 @@ add_executable(test_raft test/sparse/sort.cu test/sparse/symmetrize.cu test/spatial/knn.cu + test/spatial/fused_l2_knn.cu test/spatial/haversine.cu test/spatial/ball_cover.cu test/spatial/selection.cu diff --git a/cpp/test/spatial/fused_l2_knn.cu b/cpp/test/spatial/fused_l2_knn.cu new file mode 100644 index 0000000000..1ffb193bb6 --- /dev/null +++ b/cpp/test/spatial/fused_l2_knn.cu @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +struct FusedL2KNNInputs { + int num_queries; + int num_db_vecs; + int dim; + int k; + raft::distance::DistanceType metric_; +}; + +template +class FusedL2KNNTest : public ::testing::TestWithParam { + protected: + void testBruteForce() { + cudaStream_t stream = handle_.get_stream(); + + detail::fusedL2Knn(dim, raft_indices_, raft_distances_, database, + search_queries, num_db_vecs, num_queries, k_, true, true, + stream, metric); + + launchFaissBfknn(); + // Only verifying indices. 
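+    // Note: matching indices alone assumes the k nearest neighbors are
+    // unique; when two database vectors are equidistant from a query,
+    // faiss and fusedL2Knn may legitimately return different indices.
+    // A later patch in this series switches to matching
+    // (index, distance) pairs for exactly that reason.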
+ ASSERT_TRUE(devArrMatchInRange(faiss_indices_, raft_indices_, num_queries, + k_, raft::Compare(), stream)); + } + + void SetUp() override { + params_ = ::testing::TestWithParam::GetParam(); + num_queries = params_.num_queries; + num_db_vecs = params_.num_db_vecs; + dim = params_.dim; + k_ = params_.k; + metric = params_.metric_; + + cudaStream_t stream = handle_.get_stream(); + + raft::allocate(database, num_db_vecs * dim, stream, true); + raft::allocate(search_queries, num_queries * dim, stream, true); + + unsigned long long int seed = 1234ULL; + raft::random::Rng r(seed); + r.uniform(database, num_db_vecs * dim, T(-1.0), T(1.0), stream); + r.uniform(search_queries, num_queries * dim, T(-1.0), T(1.0), stream); + + raft::allocate(raft_indices_, num_queries * k_, stream, true); + raft::allocate(raft_distances_, num_queries * k_, stream, true); + raft::allocate(faiss_indices_, num_queries * k_, stream, true); + raft::allocate(faiss_distances_, num_queries * k_, stream, true); + } + + void TearDown() override { + cudaStream_t stream = handle_.get_stream(); + raft::deallocate_all(stream); + } + + void launchFaissBfknn() { + faiss::MetricType m = detail::build_faiss_metric(metric); + + faiss::gpu::StandardGpuResources gpu_res; + + gpu_res.noTempMemory(); + int device; + CUDA_CHECK(cudaGetDevice(&device)); + gpu_res.setDefaultStream(device, handle_.get_stream()); + + faiss::gpu::GpuDistanceParams args; + args.metric = m; + args.metricArg = 0; + args.k = k_; + args.dims = dim; + args.vectors = database; + args.vectorsRowMajor = true; + args.numVectors = num_db_vecs; + args.queries = search_queries; + args.queriesRowMajor = true; + args.numQueries = num_queries; + args.outDistances = faiss_distances_; + args.outIndices = faiss_indices_; + + bfKnn(&gpu_res, args); + } + + private: + raft::handle_t handle_; + FusedL2KNNInputs params_; + int num_queries; + int num_db_vecs; + int dim; + T *database; + T *search_queries; + int64_t *raft_indices_; + T *raft_distances_; + int64_t *faiss_indices_; + T *faiss_distances_; + int k_; + raft::distance::DistanceType metric; +}; + +const std::vector inputs = { + {100, 1000, 16, 10, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Expanded}, + {100, 1000, 16, 50, raft::distance::DistanceType::L2Expanded}, + {20, 10000, 16, 10, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 16, 50, raft::distance::DistanceType::L2Expanded}, + {1000, 10000, 32, 50, raft::distance::DistanceType::L2Expanded}, + {10000, 40000, 32, 30, raft::distance::DistanceType::L2Expanded}, + // L2 unexpanded + {100, 1000, 16, 10, raft::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 16, 10, raft::distance::DistanceType::L2Unexpanded}, + {100, 1000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, + {20, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 16, 50, raft::distance::DistanceType::L2Unexpanded}, + {1000, 10000, 32, 50, raft::distance::DistanceType::L2Unexpanded}, + {10000, 40000, 32, 30, raft::distance::DistanceType::L2Unexpanded}}; + +typedef FusedL2KNNTest FusedL2KNNTestF; +TEST_P(FusedL2KNNTestF, FusedBruteForce) { this->testBruteForce(); } + +INSTANTIATE_TEST_CASE_P(FusedL2KNNTest, FusedL2KNNTestF, + ::testing::ValuesIn(inputs)); + +} // namespace knn +} // namespace spatial +} // namespace raft diff --git a/cpp/test/test_utils.h b/cpp/test/test_utils.h index 0f135c0121..c0545e3bb1 100644 --- a/cpp/test/test_utils.h +++ b/cpp/test/test_utils.h @@ -141,6 +141,37 @@ 
testing::AssertionResult devArrMatch(const T *expected, const T *actual, return testing::AssertionSuccess(); } +// Match unsorted outputs within a range/col +template +testing::AssertionResult devArrMatchInRange(const T *expected, const T *actual, + size_t rows, size_t cols, + L eq_compare, + cudaStream_t stream = 0) { + size_t size = rows * cols; + std::unique_ptr exp_h(new T[size]); + std::unique_ptr act_h(new T[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + std::set setOfNumbers; + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto exp = exp_h.get()[idx]; + setOfNumbers.insert(exp); + } + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto act = act_h.get()[idx]; + if (!setOfNumbers.count(act)) { + return testing::AssertionFailure() << "actual=" << act << " @" << i + << "," << j << "not valid output"; + } + } + } + return testing::AssertionSuccess(); +} + template testing::AssertionResult devArrMatch(T expected, const T *actual, size_t rows, size_t cols, L eq_compare, From 2b64775e8aacddebad4896c624134cfd72be6849 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 2 Nov 2021 19:16:07 +0530 Subject: [PATCH 17/22] move customAtomicMax to generic atomicMax specialization, and remove redunant header --- cpp/include/raft/device_atomics.cuh | 40 ++++++++++--------- cpp/include/raft/sparse/op/reduce.cuh | 2 +- .../knn/detail/knn_brute_force_faiss.cuh | 1 - 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/cpp/include/raft/device_atomics.cuh b/cpp/include/raft/device_atomics.cuh index 7fa17f602d..6da594497a 100644 --- a/cpp/include/raft/device_atomics.cuh +++ b/cpp/include/raft/device_atomics.cuh @@ -182,8 +182,8 @@ struct genericAtomicOperationImpl { T old_value = *addr; T assumed{old_value}; - if (std::is_same{} && - (std::is_same{} || std::is_same{})) { + if constexpr(std::is_same{} && + (std::is_same{})) { if (isnan(update_value)) { return old_value; } @@ -203,6 +203,25 @@ struct genericAtomicOperationImpl { } }; +// 4 bytes fp32 atomic Max operation +template <> +struct genericAtomicOperationImpl { + using T = float; + __forceinline__ __device__ T operator()(T* addr, T const& update_value, + DeviceMax op) { + if (isnan(update_value)) { + return *addr; + } + + T old = (update_value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(update_value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(update_value))); + + return old; + } +}; + // 8 bytes atomic operation template struct genericAtomicOperationImpl { @@ -553,23 +572,6 @@ __forceinline__ __device__ T atomicMax(T* address, T val) { address, val, raft::device_atomics::detail::DeviceMax{}); } -// fp32 only atomicMax. -__forceinline__ __device__ float customAtomicMax(float* address, float val) { - float old; - - if (isnan(val)) { - // if NaN input, simply return value at address. - return *address; - } - - old = (val >= 0) - ? 
__int_as_float(atomicMax((int*)address, __float_as_int(val))) - : __uint_as_float( - atomicMin((unsigned int*)address, __float_as_uint(val))); - - return old; -} - /** * @brief Overloads for `atomicCAS` * diff --git a/cpp/include/raft/sparse/op/reduce.cuh b/cpp/include/raft/sparse/op/reduce.cuh index 57f647feaa..09a35720fb 100644 --- a/cpp/include/raft/sparse/op/reduce.cuh +++ b/cpp/include/raft/sparse/op/reduce.cuh @@ -67,7 +67,7 @@ __global__ void max_duplicates_kernel(const value_idx *src_rows, if (tid < nnz) { value_idx idx = index[tid]; - customAtomicMax(&out_vals[idx], src_vals[tid]); + atomicMax(&out_vals[idx], src_vals[tid]); out_rows[idx] = src_rows[tid]; out_cols[idx] = src_cols[tid]; } diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 04c16970d6..da1217e3cf 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -35,7 +35,6 @@ #include #include -#include #include "fused_l2_knn.cuh" #include "haversine_distance.cuh" #include "processing.hpp" From ef9a89887872b252084fb9fd134ad12216a7a238 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 2 Nov 2021 19:18:03 +0530 Subject: [PATCH 18/22] fix clang format errors --- cpp/include/raft/device_atomics.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/device_atomics.cuh b/cpp/include/raft/device_atomics.cuh index 6da594497a..a4ebcc9900 100644 --- a/cpp/include/raft/device_atomics.cuh +++ b/cpp/include/raft/device_atomics.cuh @@ -182,8 +182,7 @@ struct genericAtomicOperationImpl { T old_value = *addr; T assumed{old_value}; - if constexpr(std::is_same{} && - (std::is_same{})) { + if constexpr (std::is_same{} && (std::is_same{})) { if (isnan(update_value)) { return old_value; } @@ -213,10 +212,11 @@ struct genericAtomicOperationImpl { return *addr; } - T old = (update_value >= 0) - ? __int_as_float(atomicMax((int*)addr, __float_as_int(update_value))) - : __uint_as_float( - atomicMin((unsigned int*)addr, __float_as_uint(update_value))); + T old = + (update_value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(update_value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(update_value))); return old; } From b317a1205a9bfd10b404fe70d71994281cdb9d33 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 3 Nov 2021 14:12:16 +0530 Subject: [PATCH 19/22] call faiss before fusedL2knn kernel in the test --- cpp/test/spatial/fused_l2_knn.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/spatial/fused_l2_knn.cu b/cpp/test/spatial/fused_l2_knn.cu index 1ffb193bb6..cbd42e0333 100644 --- a/cpp/test/spatial/fused_l2_knn.cu +++ b/cpp/test/spatial/fused_l2_knn.cu @@ -54,11 +54,11 @@ class FusedL2KNNTest : public ::testing::TestWithParam { void testBruteForce() { cudaStream_t stream = handle_.get_stream(); + launchFaissBfknn(); detail::fusedL2Knn(dim, raft_indices_, raft_distances_, database, search_queries, num_db_vecs, num_queries, k_, true, true, stream, metric); - launchFaissBfknn(); // Only verifying indices. 
ASSERT_TRUE(devArrMatchInRange(faiss_indices_, raft_indices_, num_queries, k_, raft::Compare(), stream)); From 9e2e19e9b09562a86da6a5ba28cb473c261f8695 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 3 Nov 2021 20:46:54 +0530 Subject: [PATCH 20/22] fix issues in verification function as it can happen that 2 vectors with same distance value exists and faiss picks one vs fusedL2KNN another, so we verify both vec index as well as distance val --- cpp/test/spatial/fused_l2_knn.cu | 55 ++++++++++++++++++++++++++++++-- cpp/test/test_utils.h | 31 ------------------ 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/cpp/test/spatial/fused_l2_knn.cu b/cpp/test/spatial/fused_l2_knn.cu index cbd42e0333..70ae4d0bc0 100644 --- a/cpp/test/spatial/fused_l2_knn.cu +++ b/cpp/test/spatial/fused_l2_knn.cu @@ -48,6 +48,55 @@ struct FusedL2KNNInputs { raft::distance::DistanceType metric_; }; +template +struct idx_dist_pair { + IdxT idx; + DistT dist; + compareDist eq_compare; + bool operator==(const idx_dist_pair &a) const { + if (idx == a.idx) return true; + if (eq_compare(dist, a.dist)) return true; + return false; + } + idx_dist_pair(IdxT x, DistT y, compareDist op) + : idx(x), dist(y), eq_compare(op) {} +}; + +template +testing::AssertionResult devArrMatchKnnPair( + const T *expected_idx, const T *actual_idx, const DistT *expected_dist, + const DistT *actual_dist, size_t rows, size_t cols, const DistT eps, + cudaStream_t stream = 0) { + size_t size = rows * cols; + std::unique_ptr exp_idx_h(new T[size]); + std::unique_ptr act_idx_h(new T[size]); + std::unique_ptr exp_dist_h(new DistT[size]); + std::unique_ptr act_dist_h(new DistT[size]); + raft::update_host(exp_idx_h.get(), expected_idx, size, stream); + raft::update_host(act_idx_h.get(), actual_idx, size, stream); + raft::update_host(exp_dist_h.get(), expected_dist, size, stream); + raft::update_host(act_dist_h.get(), actual_dist, size, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + for (size_t i(0); i < rows; ++i) { + for (size_t j(0); j < cols; ++j) { + auto idx = i * cols + j; // row major assumption! + auto exp_idx = exp_idx_h.get()[idx]; + auto act_idx = act_idx_h.get()[idx]; + auto exp_dist = exp_dist_h.get()[idx]; + auto act_dist = act_dist_h.get()[idx]; + idx_dist_pair exp_kvp(exp_idx, exp_dist, raft::CompareApprox(eps)); + idx_dist_pair act_kvp(act_idx, act_dist, raft::CompareApprox(eps)); + if (!(exp_kvp == act_kvp)) { + return testing::AssertionFailure() + << "actual=" << act_kvp.idx << "," << act_kvp.dist << "!=" + << "expected" << exp_kvp.idx << "," << exp_kvp.dist << " @" << i + << "," << j; + } + } + } + return testing::AssertionSuccess(); +} + template class FusedL2KNNTest : public ::testing::TestWithParam { protected: @@ -59,9 +108,9 @@ class FusedL2KNNTest : public ::testing::TestWithParam { search_queries, num_db_vecs, num_queries, k_, true, true, stream, metric); - // Only verifying indices. - ASSERT_TRUE(devArrMatchInRange(faiss_indices_, raft_indices_, num_queries, - k_, raft::Compare(), stream)); + // verify. 
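+    // An (index, distance) pair is accepted when either the neighbor
+    // index matches exactly or the distance matches within eps, so ties
+    // resolved differently by faiss and fusedL2Knn still pass. The
+    // returned testing::AssertionResult should ideally be wrapped in
+    // ASSERT_TRUE so that a mismatch actually fails the test.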
+ devArrMatchKnnPair(faiss_indices_, raft_indices_, faiss_distances_, + raft_distances_, num_queries, k_, float(0.001), stream); } void SetUp() override { diff --git a/cpp/test/test_utils.h b/cpp/test/test_utils.h index c0545e3bb1..0f135c0121 100644 --- a/cpp/test/test_utils.h +++ b/cpp/test/test_utils.h @@ -141,37 +141,6 @@ testing::AssertionResult devArrMatch(const T *expected, const T *actual, return testing::AssertionSuccess(); } -// Match unsorted outputs within a range/col -template -testing::AssertionResult devArrMatchInRange(const T *expected, const T *actual, - size_t rows, size_t cols, - L eq_compare, - cudaStream_t stream = 0) { - size_t size = rows * cols; - std::unique_ptr exp_h(new T[size]); - std::unique_ptr act_h(new T[size]); - raft::update_host(exp_h.get(), expected, size, stream); - raft::update_host(act_h.get(), actual, size, stream); - CUDA_CHECK(cudaStreamSynchronize(stream)); - for (size_t i(0); i < rows; ++i) { - std::set setOfNumbers; - for (size_t j(0); j < cols; ++j) { - auto idx = i * cols + j; // row major assumption! - auto exp = exp_h.get()[idx]; - setOfNumbers.insert(exp); - } - for (size_t j(0); j < cols; ++j) { - auto idx = i * cols + j; // row major assumption! - auto act = act_h.get()[idx]; - if (!setOfNumbers.count(act)) { - return testing::AssertionFailure() << "actual=" << act << " @" << i - << "," << j << "not valid output"; - } - } - } - return testing::AssertionSuccess(); -} - template testing::AssertionResult devArrMatch(T expected, const T *actual, size_t rows, size_t cols, L eq_compare, From f0fd7b48fc28f2e9048cc4013b0bc06549751f8d Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 17 Nov 2021 20:57:22 +0530 Subject: [PATCH 21/22] revert ball_cover test to use compute_bfknn which is wrapper for brute_force_knn --- cpp/test/spatial/ball_cover.cu | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 66f8c95464..ca30506df0 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -78,6 +78,22 @@ uint32_t count_discrepancies(value_idx *actual_idx, value_idx *expected_idx, return result; } +template +void compute_bfknn(const raft::handle_t &handle, const value_t *X1, + const value_t *X2, uint32_t n, uint32_t d, uint32_t k, + const raft::distance::DistanceType metric, value_t *dists, + int64_t *inds) { + std::vector input_vec = {const_cast(X1)}; + std::vector sizes_vec = {n}; + + cudaStream_t *int_streams = nullptr; + std::vector *translations = nullptr; + + raft::spatial::knn::detail::brute_force_knn_impl( + input_vec, sizes_vec, d, const_cast(X2), n, inds, dists, k, + handle.get_stream(), int_streams, 0, true, true, translations, metric); +} + struct ToRadians { __device__ __host__ float operator()(float a) { return a * (CUDART_PI_F / 180.0); @@ -119,16 +135,8 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { d_train_inputs.data(), ToRadians()); } - cudaStream_t *int_streams = nullptr; - std::vector *translations = nullptr; - - std::vector input_vec = {d_train_inputs.data()}; - std::vector sizes_vec = {n}; - - raft::spatial::knn::detail::brute_force_knn_impl( - input_vec, sizes_vec, d, d_train_inputs.data(), n, d_ref_I.data(), - d_ref_D.data(), k, handle.get_stream(), int_streams, 0, true, true, - translations, metric); + compute_bfknn(handle, d_train_inputs.data(), d_train_inputs.data(), n, d, k, + metric, d_ref_D.data(), d_ref_I.data()); 
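+    // compute_bfknn (added above) wraps
+    // raft::spatial::knn::detail::brute_force_knn_impl so the ball cover
+    // test builds its reference k-NN results through a single helper.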
CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); From bdce263195d8ae948b54c14397175c91b073201c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 23 Nov 2021 14:08:12 -0500 Subject: [PATCH 22/22] Adjusting rng.cuh --- cpp/test/spatial/fused_l2_knn.cu | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/cpp/test/spatial/fused_l2_knn.cu b/cpp/test/spatial/fused_l2_knn.cu index 70ae4d0bc0..4930b47e0c 100644 --- a/cpp/test/spatial/fused_l2_knn.cu +++ b/cpp/test/spatial/fused_l2_knn.cu @@ -17,15 +17,11 @@ #include "../test_utils.h" #include -#include #include -#include -#include -#include #include #include -#include +#include #include #include
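A closing note on the float atomicMax specialization introduced above: for
IEEE-754 floats, non-negative values order exactly like their bit patterns
read as signed ints, while negative values order in reverse when read as
unsigned ints, so a float max can be composed from the integer atomicMax /
atomicMin primitives. Below is a minimal standalone check of that bit
trick; it is an illustrative sketch only (atomicMaxFloat, maxKernel, and
the host scaffolding are hypothetical names, not part of these patches):

#include <cstdio>
#include <cuda_runtime.h>

// Same reinterpretation trick as the DeviceMax specialization: signed-int
// ordering for non-negative floats, reversed unsigned ordering for
// negative floats (NaN handling omitted for brevity).
__device__ float atomicMaxFloat(float* addr, float val) {
  return (val >= 0)
           ? __int_as_float(atomicMax((int*)addr, __float_as_int(val)))
           : __uint_as_float(
               atomicMin((unsigned int*)addr, __float_as_uint(val)));
}

__global__ void maxKernel(float* out, const float* in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicMaxFloat(out, in[i]);
}

int main() {
  const int n = 4;
  float h_in[n] = {-3.5f, -0.25f, 2.0f, 1.5f};
  float h_out = -1e30f;  // seed the running maximum with a tiny value
  float *d_in, *d_out;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, &h_out, sizeof(float), cudaMemcpyHostToDevice);
  maxKernel<<<1, 32>>>(d_out, d_in, n);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("max = %f\n", h_out);  // expected: 2.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}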