From a3b8587978db1556c27d41d4d2a30ba2ac75e91e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 2022 22:21:22 +0200 Subject: [PATCH 01/49] contractions: Concentrate tile index calculations The calculation of the tile indices are now performed in ldgXY(). This will make it possible to remove all state related to the tile index out of the class in the next commit. Note that the calculation of the tile index can depend on which overloaded constructor is called(!) --- .../detail/pairwise_distance_base.cuh | 27 ++---- .../raft/linalg/detail/contractions.cuh | 84 +++++++++++++------ .../knn/detail/epsilon_neighborhood.cuh | 10 +-- 3 files changed, 72 insertions(+), 49 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 69bb83d29a..f6e66d068e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -138,27 +138,14 @@ struct PairwiseDistances : public BaseClass { DI void updateIndicesY() { const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; + this->increment_grid_idx_n(stride); } DI void updateIndicesXY() { const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; + this->increment_grid_idx_m(stride); + this->reset_grid_idx_n(); } DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) @@ -187,7 +174,7 @@ struct PairwiseDistances : public BaseClass { this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() @@ -197,15 
+184,15 @@ struct PairwiseDistances : public BaseClass { accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration // This is needed for making sure next grid stride of // non-norm based metrics uses previously accumulated buffer so // it doesn't make shmem dirty until previous iteration // is complete. - this->pageRd ^= 1; + this->switch_read_buffer(); } DI void accumulate() diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 5d83f88e71..a6efdec49e 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -40,14 +40,15 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; - /** current thread's global mem row id for X data */ - IdxT xrowid; - /** current thread's global mem row id for Y data */ - IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x; + const DataT* x_base; /** global memory pointer to Y matrix */ - const DataT* y; + const DataT* y_base; + + /** Support variables to provide backward compatibility **/ + IdxT grid_idx_m = 0; + IdxT grid_idx_n = 0; + bool first_constructor_called; /** current thread's smem row id */ int srowid; @@ -94,10 +95,8 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), - yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), - x(_x + xrowid * lda), - y(_y + yrowid * ldb), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -105,7 +104,8 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0) + pageRd(0), + first_constructor_called(true) { } @@ -133,6 +133,8 @@ struct Contractions_NT { lda(_lda), 
ldb(_ldb), ldd(_ldd), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -140,19 +142,9 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0) + pageRd(0), + first_constructor_called(false) { - if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; - x = _x + xrowid * lda; - y = _y + yrowid * ldb; - } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * P::Nblk; - x = _x + xrowid + srowid * lda; - y = _y + yrowid + srowid * ldb; - } } protected: @@ -166,6 +158,12 @@ struct Contractions_NT { ldgY(kidx); } + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) + { + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx); + } + /** * @brief Store current block of X/Y from registers to smem * @param[in] kidx current start index of k to be loaded @@ -186,9 +184,35 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } + DI void increment_grid_idx_m(IdxT by) { grid_idx_m += by; } + + DI void increment_grid_idx_n(IdxT by) { grid_idx_n += by; } + + DI void reset_grid_idx_n() { grid_idx_n = 0; } + + DI void switch_read_buffer() { this->pageRd ^= 1; } + + DI void switch_write_buffer() { this->pageWr ^= 1; } + private: DI void ldgX(IdxT kidx) { + // Backward compatible way to determine the tile index. This depends on + // whether the first or the second constructor was called. The first + // constructor is called in epsilon_neighborhood.cuh and the second + // constructor is called in pairwise_distance_base.cuh. + if (first_constructor_called) { + ldgX(IdxT(blockIdx.x) * P::Mblk, kidx); + } else { + ldgX(grid_idx_m + IdxT(blockIdx.y) * P::Mblk, kidx); + } + } + + DI void ldgX(IdxT tile_idx_m, IdxT kidx) + { + IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; + auto x = isRowMajor ? 
x_base + xrowid * lda : x_base + xrowid + srowid * lda; + if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -222,6 +246,18 @@ struct Contractions_NT { DI void ldgY(IdxT kidx) { + if (first_constructor_called) { + ldgY(IdxT(blockIdx.y) * P::Nblk, kidx); + } else { + ldgY(grid_idx_n + IdxT(blockIdx.x) * P::Nblk, kidx); + } + } + + DI void ldgY(IdxT tile_idx_n, IdxT kidx) + { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + if (isRowMajor) { auto numRows = n; auto koffset = kidx + scolid; @@ -315,4 +351,4 @@ struct Contractions_NT { } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index 19862d743d..cd0e005921 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { private: DI void prolog() { - this->ldgXY(0); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() { for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx); accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration } From 99e65a5d93c8fbca1afea76fbcda019346d5d1df Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 
2022 22:51:49 +0200 Subject: [PATCH 02/49] pairwise_distance_base: Remove all ldgXY(0) calls This commit moves all grid and tile indexing logic into the caller. Contractions_NT is now only responsible for *intra*-tile indexing. Due to the complexity of the epilog function, the ldgNextGridStride function is not yet called from within the main loop. That is the next goal so that we have all the grid and tile indexing localized in the loop. --- .../detail/pairwise_distance_base.cuh | 121 ++++++++++-------- .../raft/linalg/detail/contractions.cuh | 45 +------ 2 files changed, 67 insertions(+), 99 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index f6e66d068e..fefb964f3d 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -87,6 +87,12 @@ struct PairwiseDistances : public BaseClass { FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; + + const IdxT grid_stride_m; + const IdxT grid_stride_n; + const IdxT grid_offset_m; + const IdxT grid_offset_n; + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; public: @@ -116,53 +122,63 @@ struct PairwiseDistances : public BaseClass { core_op(_core_op), epilog_op(_epilog_op), fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op) + rowEpilog_op(_rowEpilog_op), + grid_stride_m(P::Nblk * gridDim.y), + grid_stride_n(P::Mblk * gridDim.x), + grid_offset_m(P::Mblk * blockIdx.y), + grid_offset_n(P::Nblk * blockIdx.x) { } DI void run() { - for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; - gridStrideY += P::Mblk * gridDim.y) { - for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { - prolog(gridStrideX, gridStrideY); - loop(); - epilog(gridStrideX, gridStrideY); + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + this->ldgXY(tile_idx_m, grid_offset_n, 
0); + for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + accumulate(); // last iteration + // This is needed for making sure next grid stride of + // non-norm based metrics uses previously accumulated buffer so + // it doesn't make shmem dirty until previous iteration + // is complete. + this->switch_read_buffer(); + + epilog(tile_idx_n, tile_idx_m); } - rowEpilog_op(gridStrideY); + rowEpilog_op(tile_idx_m); } } private: - DI void updateIndicesY() - { - const auto stride = P::Nblk * gridDim.x; - this->increment_grid_idx_n(stride); - } - - DI void updateIndicesXY() - { - const auto stride = P::Mblk * gridDim.y; - this->increment_grid_idx_m(stride); - this->reset_grid_idx_n(); - } - - DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) + DI void ldgNextGridStride(IdxT tile_idx_n, IdxT tile_idx_m) { // Fetch next grid stride ldg if within range - if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { - updateIndicesY(); - this->ldgXY(0); - } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { - updateIndicesXY(); - this->ldgXY(0); + const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; + const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; + if ((next_tile_tile_idx_n) < this->n) { + this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); + } else if ((next_tile_tile_idx_m) < this->m) { + this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); } } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) + DI void prolog(IdxT tile_idx_n, IdxT tile_idx_m) { - if (gridStrideX == blockIdx.x * P::Nblk) { 
this->ldgXY(0); } + if (tile_idx_n == blockIdx.x * P::Nblk) { this->ldgXY(0); } #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -177,22 +193,15 @@ struct PairwiseDistances : public BaseClass { this->switch_write_buffer(); } - DI void loop() - { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); + DI void reset_accumulator() { + // Reset accumulator registers to zero. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } } - accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. - this->switch_read_buffer(); } DI void accumulate() @@ -213,22 +222,22 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT gridStrideX, IdxT gridStrideY) + DI void epilog(IdxT tile_idx_n, IdxT tile_idx_m) { if (useNorms) { DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); DataT* syNorm = (&sxNorm[P::Mblk]); // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * P::Nblk) { + if (tile_idx_n == blockIdx.x * P::Nblk) { for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; + auto idx = tile_idx_m + i; sxNorm[i] = idx < this->m ? xn[idx] : 0; } } for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = gridStrideX + i; + auto idx = tile_idx_n + i; syNorm[i] = idx < this->n ? 
yn[idx] : 0; } @@ -245,17 +254,17 @@ struct PairwiseDistances : public BaseClass { } // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); + ldgNextGridStride(tile_idx_n, tile_idx_m); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + ldgNextGridStride(tile_idx_n, tile_idx_m); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { - IdxT starty = gridStrideY + this->accrowid; - IdxT startx = gridStrideX + this->acccolid; + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index a6efdec49e..6d7a8e2292 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -45,11 +45,6 @@ struct Contractions_NT { /** global memory pointer to Y matrix */ const DataT* y_base; - /** Support variables to provide backward compatibility **/ - IdxT grid_idx_m = 0; - IdxT grid_idx_n = 0; - bool first_constructor_called; - /** current thread's smem row id */ int srowid; /** current thread's smem column id */ @@ -104,8 +99,7 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0), - first_constructor_called(true) + pageRd(0) { } @@ -142,8 +136,7 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0), - first_constructor_called(false) + pageRd(0) { } @@ -152,12 +145,6 @@ struct Contractions_NT { * @brief Load current block of X/Y from global memory to registers * @param[in] kidx current start index of k to be loaded */ - DI void ldgXY(IdxT kidx) - { - ldgX(kidx); - ldgY(kidx); - 
} - DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) { ldgX(tile_idx_m, kidx); @@ -184,30 +171,11 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } - DI void increment_grid_idx_m(IdxT by) { grid_idx_m += by; } - - DI void increment_grid_idx_n(IdxT by) { grid_idx_n += by; } - - DI void reset_grid_idx_n() { grid_idx_n = 0; } - DI void switch_read_buffer() { this->pageRd ^= 1; } DI void switch_write_buffer() { this->pageWr ^= 1; } private: - DI void ldgX(IdxT kidx) - { - // Backward compatible way to determine the tile index. This depends on - // whether the first or the second constructor was called. The first - // constructor is called in epsilon_neighborhood.cuh and the second - // constructor is called in pairwise_distance_base.cuh. - if (first_constructor_called) { - ldgX(IdxT(blockIdx.x) * P::Mblk, kidx); - } else { - ldgX(grid_idx_m + IdxT(blockIdx.y) * P::Mblk, kidx); - } - } - DI void ldgX(IdxT tile_idx_m, IdxT kidx) { IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; @@ -244,15 +212,6 @@ struct Contractions_NT { } } - DI void ldgY(IdxT kidx) - { - if (first_constructor_called) { - ldgY(IdxT(blockIdx.y) * P::Nblk, kidx); - } else { - ldgY(grid_idx_n + IdxT(blockIdx.x) * P::Nblk, kidx); - } - } - DI void ldgY(IdxT tile_idx_n, IdxT kidx) { IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; From e6d5078aa126f5612525a548a063610136f98293 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 2022 23:40:32 +0200 Subject: [PATCH 03/49] pairwise_distance_base: Move all logic into run loop This commit removes the epilog function and moves its functionality into the run loop. The next step might be to see if the ldgNextGridStride() method has to be called the current location, or if performance is the same if its called at the start of the loop. 
--- .../detail/pairwise_distance_base.cuh | 128 ++++++++---------- 1 file changed, 57 insertions(+), 71 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index fefb964f3d..a2dffad808 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -87,7 +87,6 @@ struct PairwiseDistances : public BaseClass { FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; - const IdxT grid_stride_m; const IdxT grid_stride_n; const IdxT grid_offset_m; @@ -141,14 +140,14 @@ struct PairwiseDistances : public BaseClass { this->switch_write_buffer(); for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - accumulate(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration // This is needed for making sure next grid stride of @@ -157,14 +156,25 @@ struct PairwiseDistances : public BaseClass { // is complete. 
this->switch_read_buffer(); - epilog(tile_idx_n, tile_idx_m); + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, regxn, regyn); + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } + if (writeOut) { store_output(tile_idx_m, tile_idx_n); } } rowEpilog_op(tile_idx_m); } } private: - DI void ldgNextGridStride(IdxT tile_idx_n, IdxT tile_idx_m) + DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) { // Fetch next grid stride ldg if within range const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; @@ -176,24 +186,8 @@ struct PairwiseDistances : public BaseClass { } } - DI void prolog(IdxT tile_idx_n, IdxT tile_idx_m) + DI void reset_accumulator() { - if (tile_idx_n == blockIdx.x * P::Nblk) { this->ldgXY(0); } - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; - } - } - - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - } - - DI void reset_accumulator() { // Reset accumulator registers to zero. #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -222,60 +216,52 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT tile_idx_n, IdxT tile_idx_m) + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) { - if (useNorms) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (tile_idx_n == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? 
xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (tile_idx_n == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; } + } - __syncthreads(); + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < this->n ? yn[idx] : 0; + } + __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_n, tile_idx_m); - epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_n, tile_idx_m); - epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + } - if (writeOut) { - IdxT starty = tile_idx_m + this->accrowid; - IdxT startx = tile_idx_n + this->acccolid; + DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) + { + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma 
unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); } } } From 995d2ae5a550060c27e3426570a7f9e8e7addc01 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 5 Oct 2022 16:17:56 +0200 Subject: [PATCH 04/49] pairwise_distance_base: Fix typo This results in subtle issues with non-square KernelPolicy, as found in fusedL2KNN. --- cpp/include/raft/distance/detail/pairwise_distance_base.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index a2dffad808..b28c3a3de4 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -122,8 +122,8 @@ struct PairwiseDistances : public BaseClass { epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), - grid_stride_m(P::Nblk * gridDim.y), - grid_stride_n(P::Mblk * gridDim.x), + grid_stride_m(P::Mblk * gridDim.y), + grid_stride_n(P::Nblk * gridDim.x), grid_offset_m(P::Mblk * blockIdx.y), grid_offset_n(P::Nblk * blockIdx.x) { From e6976c53ab559befef9019123ae33379bb54733e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 15:40:04 +0100 Subject: [PATCH 05/49] Implement reviewer feedback --- .../raft/distance/detail/pairwise_distance_base.cuh | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git 
a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index b28c3a3de4..d849b23999 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ namespace detail { * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ template m; tile_idx_m += grid_stride_m) { this->ldgXY(tile_idx_m, grid_offset_n, 0); for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + // Prolog: reset_accumulator(); this->stsXY(); __syncthreads(); this->switch_write_buffer(); + // Main loop: for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { this->ldgXY(tile_idx_m, tile_idx_n, kidx); // Process all data in shared memory (previous k-block) and @@ -150,12 +153,12 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); } accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer back + // so that it satisfies the pre-condition of this loop. 
this->switch_read_buffer(); + // Epilog: if (useNorms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); From 8385f2f784f85b559277de5d2ea5c85aa109ee10 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 22 Sep 2022 11:38:16 +0200 Subject: [PATCH 06/49] Add sparseL2NN initial implementation --- .../raft/distance/detail/compress_to_bits.cuh | 49 ++ .../distance/detail/sparse_distance_base.cuh | 362 +++++++++++++ .../raft/distance/detail/sparse_l2_nn.cuh | 303 +++++++++++ cpp/include/raft/distance/sparse_l2_nn.cuh | 114 ++++ .../raft/linalg/detail/contractions.cuh | 42 ++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/sparse_l2_nn.cu | 494 ++++++++++++++++++ 7 files changed, 1365 insertions(+) create mode 100644 cpp/include/raft/distance/detail/compress_to_bits.cuh create mode 100644 cpp/include/raft/distance/detail/sparse_distance_base.cuh create mode 100644 cpp/include/raft/distance/detail/sparse_l2_nn.cuh create mode 100644 cpp/include/raft/distance/sparse_l2_nn.cuh create mode 100644 cpp/test/distance/sparse_l2_nn.cu diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh new file mode 100644 index 0000000000..e9a60154a3 --- /dev/null +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include + +namespace raft { +namespace distance { +namespace detail { + +template ::value>> +__global__ void compress_to_bits_naive(const bool* in, int in_rows, int in_cols, T* out) +{ + constexpr int bits_per_element = 8 * sizeof(T); + + const size_t i = threadIdx.y + blockIdx.y * blockDim.y; + const size_t j = threadIdx.x + blockIdx.x * blockDim.x; + + if (in_rows <= i || in_cols <= j) { return; } + + bool bit = in[i * in_cols + j]; + int bitpos = j % bits_per_element; + + T bitfield = bit ? T(1) << bitpos : 0; + + const size_t out_rows = raft::ceildiv(in_cols, bits_per_element); + const size_t out_cols = in_rows; + const size_t out_j = i; + const size_t out_i = j / bits_per_element; + if (out_i < out_rows && out_j < out_cols) { atomicOr(&out[out_i * out_cols + out_j], bitfield); } +} + +}; // namespace detail +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/detail/sparse_distance_base.cuh b/cpp/include/raft/distance/detail/sparse_distance_base.cuh new file mode 100644 index 0000000000..6e51ccbab3 --- /dev/null +++ b/cpp/include/raft/distance/detail/sparse_distance_base.cuh @@ -0,0 +1,362 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include + +#include + +namespace raft { +namespace distance { +namespace detail { + +/** + * @brief Device class for L1, L2 and cosine distance metrics. 
+ * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda tells how to accumulate an x and y into + acc. its signature: + template void core_lambda(AccT& acc, + const DataT& x, const DataT& y) + * @tparam EpilogueLambda applies an elementwise function to compute final + values. Its signature is: + template void epilogue_lambda + (AccT acc[][], DataT* regxn, DataT* regyn); + * @tparam FinalLambda the final lambda called on final distance value + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine + * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine + * @param[output] pD output matrix + * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. 
+ * @param core_op the core accumulation operation lambda + * @param epilog_op the epilog operation lambda + * @param fin_op the final gemm epilogue lambda + */ + +template > +struct SparseDistances : public BaseClass { + private: + typedef Policy P; + const DataT* xn; + const DataT* yn; + const DataT* const yBase; + const uint64_t* adj; + const IdxT* group_idxs; + IdxT num_groups; + char* smem; + CoreLambda core_op; + EpilogueLambda epilog_op; + FinalLambda fin_op; + rowEpilogueLambda rowEpilog_op; + + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; + + public: + // Constructor + DI SparseDistances(const DataT* _x, + const DataT* _y, + IdxT _m, + IdxT _n, + IdxT _k, + IdxT _lda, + IdxT _ldb, + IdxT _ldd, + const DataT* _xn, + const DataT* _yn, + const uint64_t* _adj, + const IdxT* _group_idxs, + IdxT _num_groups, + char* _smem, + CoreLambda _core_op, + EpilogueLambda _epilog_op, + FinalLambda _fin_op, + rowEpilogueLambda _rowEpilog_op) + : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), + xn(_xn), + yn(_yn), + yBase(_y), + adj(_adj), + group_idxs(_group_idxs), + num_groups(_num_groups), + smem(_smem), + core_op(_core_op), + epilog_op(_epilog_op), + fin_op(_fin_op), + rowEpilog_op(_rowEpilog_op) + { + } + + DI void run() + { + const auto grid_stride_m = (P::Mblk * gridDim.y); + const auto grid_offset_m = (P::Mblk * blockIdx.y); + + const auto grid_stride_g = gridDim.x; + const auto grid_offset_g = blockIdx.x; + + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + // Start loop over groups + for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { + // The __syncthreads() ensures that loading the block flag occurs at + // the same time in all threads of the block. Since all threads load + // the same address, this speeds up the code. 
+ __syncthreads(); + const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); + // block_adj is a bitfield that contains a 1 if a row is adjacent to the + // current group. All zero means we can skip this group. + if (block_adj == 0) { continue; } + + // Determine which results, that are computed by this thread, have to + // be taken into account. This information is stored in a bitfield, + // thread_adj. If all results computed by this thread can be ignored, + // then we can also skip some computations (thread_adj == 0). + + // We precompute this information because it is used in various + // locations to skip thread-local computations. + int thread_adj = compute_thread_adjacency(block_adj); + + auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; + const auto tile_end_n = group_idxs[idx_g]; + for (; tile_idx_n < tile_end_n; tile_idx_n += P::Nblk) { + // We provide tile_end_n to limit the number of unnecessary data + // points that are loaded from y. + // TODO: determine if this actually improves performance. + this->ldgXY(tile_idx_m, tile_idx_n, 0, tile_end_n); + + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx, tile_end_n); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + if (thread_adj != 0) { accumulate(); } + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + if (thread_adj != 0) { + accumulate(); // last iteration + } + // This is needed for making sure next grid stride of + // non-norm based metrics uses previously accumulated buffer so + // it doesn't make shmem dirty until previous iteration + // is complete. 
+ this->switch_read_buffer(); + + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, tile_end_n, regxn, regyn); + if (thread_adj != 0) { + epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, tile_end_n); + } + } else { + if (thread_adj != 0) { + epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, tile_end_n); + } + } + } // tile_idx_n + } // idx_g + rowEpilog_op(tile_idx_m); + } // tile_idx_n + } + + private: + DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) + { + IdxT block_flag_idx = tile_idx_m / P::Mblk; + return adj[block_flag_idx * this->num_groups + idx_group]; + } + + DI uint32_t compute_thread_adjacency(const uint64_t block_adj) + { + uint32_t thread_adj = 0; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + const uint64_t read_mask = 1ull << (this->accrowid + i * P::AccThRows); + const uint32_t write_mask = 1 << i; + if ((block_adj & read_mask) != 0) { thread_adj |= write_mask; } + } + return thread_adj; + } + + DI void reset_accumulator() + { + // Reset accumulator registers to zero. 
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } + } + } + + DI void accumulate() + { +#pragma unroll + for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { + this->ldsXY(ki); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { +#pragma unroll + for (int v = 0; v < P::Veclen; ++v) { + core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); + } + } + } + } + } + + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + IdxT tile_end_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) + { + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; + } + + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < tile_end_n ? yn[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; + } + } +}; // struct SparseDistances + +/** + * @brief the distance matrix calculation kernel for L1, L2 and cosine + * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda lambda which implements accumulation operation + * @tparam EpilogueLambda lambda which implements operation for calculating + final value. 
+ * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major(default), + false for column major + * + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] xn row norms of input matrix A. + * @param[in] yn row norms of input matrix B. + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] pD output matrix + * @param core_op the core lambda + * @param epilog_op the epilogue lambda + * @param fin_op the final gemm epilogue lambda + */ + +template +__global__ __launch_bounds__(Policy::Nthreads, 2) + + void sparseDistanceMatKernel(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + const bool* adj, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + CoreLambda core_op, + EpilogueLambda epilog_op, + FinalLambda fin_op) +{ + extern __shared__ char smem[]; + auto rowEpilog = [] __device__(IdxT starty) { return; }; + + SparseDistances + obj(x, y, m, n, k, lda, ldb, ldd, _xn, _yn, smem, core_op, epilog_op, fin_op, rowEpilog); + obj.run(); +} + +}; // namespace detail +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/detail/sparse_l2_nn.cuh b/cpp/include/raft/distance/detail/sparse_l2_nn.cuh new file mode 100644 index 0000000000..acc66b3837 --- /dev/null +++ b/cpp/include/raft/distance/detail/sparse_l2_nn.cuh @@ -0,0 +1,303 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { +namespace detail { + +#if (ENABLE_MEMCPY_ASYNC == 1) +#include +using namespace nvcuda::experimental; +#endif + +template +__global__ __launch_bounds__(P::Nthreads, 2) void sparseL2NNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const uint64_t* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + CoreLambda core_op, + FinalLambda fin_op) +{ + extern __shared__ char smem[]; + + typedef cub::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + int acc_adj, + DataT* regxn, + DataT* regyn, + IdxT tile_idx_n, + IdxT tile_idx_m, + IdxT tile_end_n) { + KVPReduceOpT pairRed_op(pairRedOp); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } + } + if (Sqrt) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = raft::mySqrt(acc[i][j]); + } + } + } + + // intra thread reduce + 
const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + const bool ignore = (acc_adj & (1 << i)) == 0; + if (ignore) { continue; } +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; + if (tile_end_n <= tmpkey) { + // Do not process beyond end of tile. + continue; + } + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < tile_end_n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + auto tmpkey = raft::shfl(val[i].key, lid + j); + auto tmpvalue = raft::shfl(val[i].value, lid + j); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); + + // reset the val array. 
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + SparseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + adj, + group_idxs, + num_groups, + smem, + core_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +} + +template +void sparseL2NNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const bool* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + typedef typename linalg::Policy4x4::Policy P; + + static_assert(P::Mblk == 64, "sparseL2NNImpl only supports a policy with 64 rows per block."); + + // First, compress boolean to bitfield. + + // TODO 1: Remove allocation; use workspace instead(?) + // TODO 2: Use a faster compress_to_bits implementation that does not require a pre-zeroed output. 
+ rmm::device_uvector adj64(raft::ceildiv(m, IdxT(64)) * num_groups, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(adj64.data(), 0, adj64.size() * sizeof(uint64_t), stream)); + dim3 compress_grid(raft::ceildiv(m, 32), raft::ceildiv(num_groups, 32)); + compress_to_bits_naive<<>>( + adj, num_groups, m, adj64.data()); + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef cub::KeyValuePair KVPair; + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + // TODO 3: remove fin_op + auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; + + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + if (sqrt) { + auto sparseL2NNSqrt = sparseL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, sparseL2NNSqrt); + + sparseL2NNSqrt<<>>(min, + x, + y, + xn, + yn, + adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + fin_op); + } else { + auto sparseL2NN = sparseL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, sparseL2NN); + sparseL2NN<<>>(min, + x, + y, + xn, + yn, + adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + maxVal, + workspace, + redOp, + pairRedOp, + core_lambda, + fin_op); + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/sparse_l2_nn.cuh b/cpp/include/raft/distance/sparse_l2_nn.cuh new file mode 100644 index 0000000000..c690702cb4 --- /dev/null +++ b/cpp/include/raft/distance/sparse_l2_nn.cuh @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __SPARSE_L2_NN_H +#define __SPARSE_L2_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +/** + * @brief Sparse L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. 
Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void sparseL2NN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const bool* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // TODO: decide on kernel policy based on skinniness of the matrices. If k is + // low, it may make sense to use another kernel policy, like in + // fused_l2_nn.cuh. 
+ detail::sparseL2NNImpl(min, + x, + y, + xn, + yn, + adj, + group_idxs, + num_groups, + m, + n, + k, + (int*)workspace, + redOp, + pairRedOp, + sqrt, + initOutBuffer, + stream); + // } +} + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 6d7a8e2292..4c5a43cd57 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -151,6 +151,12 @@ struct Contractions_NT { ldgY(tile_idx_n, kidx); } + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) + { + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx, tile_end_n); + } + /** * @brief Store current block of X/Y from registers to smem * @param[in] kidx current start index of k to be loaded @@ -248,6 +254,42 @@ struct Contractions_NT { } } + DI void ldgY(IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) + { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? 
y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + + if (isRowMajor) { + auto numRows = tile_end_n; + auto koffset = kidx + scolid; +#pragma unroll + for (int i = 0; i < P::LdgPerThY; ++i) { + if (koffset < ldb && (yrowid + i * P::LdgRowsY) < numRows) { + ldg(ldgDataY[i], y + i * P::LdgRowsY * ldb + koffset); + } else { +#pragma unroll + for (int j = 0; j < P::Veclen; ++j) { + ldgDataY[i][j] = Zero; + } + } + } + } else { + auto numRows = k; + auto koffset = scolid; +#pragma unroll + for (int i = 0; i < P::LdgPerThY; ++i) { + if ((koffset + yrowid) < tile_end_n && (srowid + kidx + i * P::LdgRowsY) < numRows) { + ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset); + } else { +#pragma unroll + for (int j = 0; j < P::Veclen; ++j) { + ldgDataY[i][j] = Zero; + } + } + } + } + } + DI void stsX(DataT* smem) { auto* saddr = smem + srowid * P::SmemStride + scolid; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6c7ca11d86..8039e0277e 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -121,6 +121,7 @@ if(BUILD_TESTS) test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu + test/distance/sparse_l2_nn.cu test/distance/gram.cu OPTIONAL DIST diff --git a/cpp/test/distance/sparse_l2_nn.cu b/cpp/test/distance/sparse_l2_nn.cu new file mode 100644 index 0000000000..293c78ddee --- /dev/null +++ b/cpp/test/distance/sparse_l2_nn.cu @@ -0,0 +1,494 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { +namespace distance { +namespace sparse_l2_nn { + +template +struct CubKVPMinReduce { + typedef cub::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + +}; // KVPMinReduce + +template +__global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(cub::KeyValuePair* min, + DataT* x, + DataT* y, + bool* adj, + int* group_idxs, + int m, + int n, + int k, + int num_groups, + int* workspace, + DataT maxVal) +{ + const int m_stride = blockDim.y * gridDim.y; + const int m_offset = threadIdx.y + blockIdx.y * blockDim.y; + const int n_stride = blockDim.x * gridDim.x; + const int n_offset = threadIdx.x + blockIdx.x * blockDim.x; + + for (int m_grid = 0; m_grid < m; m_grid += m_stride) { + for (int n_grid = 0; n_grid < n; n_grid += n_stride) { + int midx = m_grid + m_offset; + int nidx = n_grid + n_offset; + + // Do a reverse linear search to determine the group index. + int group_idx = 0; + for (int i = num_groups; 0 <= i; --i) { + if (nidx < group_idxs[i]) { group_idx = i; } + } + const bool include_dist = adj[group_idx * m + midx] && midx < m && nidx < n; + + // Compute L2 metric. + DataT acc = DataT(0); + for (int i = 0; i < k; ++i) { + int xidx = i + midx * k; + int yidx = i + nidx * k; + auto diff = x[xidx] - y[yidx]; + acc += diff * diff; + } + if (Sqrt) { acc = raft::mySqrt(acc); } + ReduceOpT redOp; + typedef cub::WarpReduce> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp[NWARPS]; + int warpId = threadIdx.x / raft::WarpSize; + cub::KeyValuePair tmp; + tmp.key = include_dist ? nidx : -1; + tmp.value = include_dist ? 
acc : maxVal; + tmp = WarpReduce(temp[warpId]).Reduce(tmp, CubKVPMinReduce()); + if (threadIdx.x % raft::WarpSize == 0 && midx < m) { + while (atomicCAS(workspace + midx, 0, 1) == 1) + ; + __threadfence(); + redOp(midx, min + midx, tmp); + __threadfence(); + atomicCAS(workspace + midx, 1, 0); + } + __syncthreads(); + } + } +} + +template +void naive(cub::KeyValuePair* min, + DataT* x, + DataT* y, + bool* adj, + int* group_idxs, + int m, + int n, + int k, + int num_groups, + int* workspace, + cudaStream_t stream) +{ + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + raft::distance::detail::initKernel, int> + <<>>(min, m, std::numeric_limits::max(), op); + RAFT_CUDA_TRY(cudaGetLastError()); + + const int nwarps = 16; + static const dim3 TPB(32, nwarps, 1); + dim3 nblks(1, 200, 1); + naiveKernel, nwarps><<>>( + min, x, y, adj, group_idxs, m, n, k, num_groups, workspace, std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +enum AdjacencyPattern { + checkerboard = 0, + checkerboard_4 = 1, + checkerboard_64 = 2, + all_true = 3, + all_false = 4 +}; + +template +struct Inputs { + DataT tolerance; + int m, n, k, num_groups; + unsigned long long int seed; + + AdjacencyPattern pattern; + + friend std::ostream& operator<<(std::ostream& os, const Inputs& p) + { + return os << "m: " << p.m + << ", " + "n: " + << p.n + << ", " + "k: " + << p.k + << ", " + "num_groups: " + << p.num_groups + << ", " + "seed: " + << p.seed + << ", " + "tol: " + << p.tolerance; + } +}; + +__global__ void init_adj( + int m, int n, int num_groups, AdjacencyPattern pattern, bool* adj, int* group_idxs) +{ + for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_groups; i += blockDim.y * gridDim.y) { + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < m; j += blockDim.x * gridDim.x) { + switch (pattern) { + case checkerboard: adj[i * m + j] = (i + j) % 2; break; + case checkerboard_4: 
adj[i * m + j] = (i + (j / 4)) % 2; break; + case checkerboard_64: adj[i * m + j] = (i + (j / 64)) % 2; break; + case all_true: adj[i * m + j] = true; break; + case all_false: adj[i * m + j] = false; break; + default: assert(false && "unknown pattern"); + } + } + } + // Each group is of size n / num_groups. + // + // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive + // scan of the group lengths) + // + // - The first group always starts at index zero, so we do not store it. + // + // - The group_idxs[num_groups - 1] should always equal n. + + if (blockIdx.y == 0 && threadIdx.y == 0) { + const int j_stride = blockDim.x * gridDim.x; + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_groups; j += j_stride) { + group_idxs[j] = (j + 1) * (n / num_groups); + } + group_idxs[num_groups - 1] = n; + } +} + +template +class SparseL2NNTest : public ::testing::TestWithParam> { + public: + SparseL2NNTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + x(params.m * params.k, stream), + y(params.n * params.k, stream), + adj(params.m * params.num_groups, stream), + group_idxs(params.num_groups, stream), + xn(params.m, stream), + yn(params.n, stream), + min(params.m, stream), + min_ref(params.m, stream), + workspace(params.m * sizeof(int), stream) + { + } + + protected: + void SetUp() override + { + raft::random::RngState r(params.seed); + int m = params.m; + int n = params.n; + int k = params.k; + int num_groups = params.num_groups; + uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); + + dim3 block(32, 32); + dim3 grid(10, 10); + init_adj<<>>( + m, n, num_groups, params.pattern, adj.data(), group_idxs.data()); + RAFT_CUDA_TRY(cudaGetLastError()); + + generateGoldenResult(); + raft::linalg::rowNorm(xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm(yn.data(), y.data(), k, n, raft::linalg::L2Norm, 
true, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + protected: + Inputs params; + rmm::device_uvector x; + rmm::device_uvector y; + rmm::device_uvector adj; + rmm::device_uvector group_idxs; + rmm::device_uvector xn; + rmm::device_uvector yn; + rmm::device_uvector> min; + rmm::device_uvector> min_ref; + rmm::device_uvector workspace; + raft::handle_t handle; + cudaStream_t stream; + + virtual void generateGoldenResult() + { + int m = params.m; + int n = params.n; + int k = params.k; + int num_groups = params.num_groups; + + naive(min_ref.data(), + x.data(), + y.data(), + adj.data(), + group_idxs.data(), + m, + n, + k, + num_groups, + (int*)workspace.data(), + stream); + } + + void runTest(cub::KeyValuePair* out) + { + int m = params.m; + int n = params.n; + int k = params.k; + int num_groups = params.num_groups; + + MinAndDistanceReduceOp redOp; + sparseL2NN, int>( + out, + x.data(), + y.data(), + xn.data(), + yn.data(), + adj.data(), + group_idxs.data(), + num_groups, + m, + n, + k, + (void*)workspace.data(), + redOp, + raft::distance::KVPMinReduce(), + Sqrt, + true, + stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } +}; + +template +struct CompareApproxAbsKVP { + typedef typename cub::KeyValuePair KVP; + CompareApproxAbsKVP(T eps_) : eps(eps_) {} + bool operator()(const KVP& a, const KVP& b) const + { + T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); + T m = std::max(raft::abs(a.value), raft::abs(b.value)); + T ratio = m >= eps ? 
diff / m : diff; + return (ratio <= eps); + } + + private: + T eps; +}; + +template +struct CompareExactKVP { + typedef typename cub::KeyValuePair KVP; + bool operator()(const KVP& a, const KVP& b) const + { + if (a.value != b.value) return false; + return true; + } +}; + +template +::testing::AssertionResult devArrMatch(const cub::KeyValuePair* expected, + const cub::KeyValuePair* actual, + size_t size, + L eq_compare, + cudaStream_t stream = 0) +{ + typedef typename cub::KeyValuePair KVP; + std::shared_ptr exp_h(new KVP[size]); + std::shared_ptr act_h(new KVP[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + return ::testing::AssertionFailure() + << "actual=" << act.key << "," << act.value << " != expected=" << exp.key << "," + << exp.value << " @" << i; + } + } + return ::testing::AssertionSuccess(); +} + +const std::vector> inputsf = { + {0.001f, 32, 32, 32, 2, 1234ULL, AdjacencyPattern::all_true}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_true}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_false}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_4}, + {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_64}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_true}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_false}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_4}, + {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_64}, + {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, 
AdjacencyPattern::all_true}, + {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::all_false}, + {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, +}; + +typedef SparseL2NNTest SparseL2NNTestF_Sq; +TEST_P(SparseL2NNTestF_Sq, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestF_Sq, ::testing::ValuesIn(inputsf)); +typedef SparseL2NNTest SparseL2NNTestF_Sqrt; +TEST_P(SparseL2NNTestF_Sqrt, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestF_Sqrt, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.00001, 32, 32, 32, 2, 1234ULL, AdjacencyPattern::all_true}, + + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_true}, + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_false}, + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_4}, + {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_64}, + + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_true}, + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_false}, + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard}, + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_4}, + {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_64}, +}; +typedef SparseL2NNTest SparseL2NNTestD_Sq; +TEST_P(SparseL2NNTestD_Sq, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(SparseL2NNTests, 
SparseL2NNTestD_Sq, ::testing::ValuesIn(inputsd)); +typedef SparseL2NNTest SparseL2NNTestD_Sqrt; +TEST_P(SparseL2NNTestD_Sqrt, Result) +{ + runTest(min.data()); + ASSERT_TRUE(devArrMatch( + min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestD_Sqrt, ::testing::ValuesIn(inputsd)); + +/// This is to test output determinism of the prim +template +class SparseL2NNDetTest : public SparseL2NNTest { + public: + SparseL2NNDetTest() : stream(handle.get_stream()), min1(0, stream) {} + + void SetUp() override + { + SparseL2NNTest::SetUp(); + int m = this->params.m; + min1.resize(m, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + } + + void TearDown() override { SparseL2NNTest::TearDown(); } + + protected: + raft::handle_t handle; + cudaStream_t stream; + + rmm::device_uvector> min1; + + static const int NumRepeats = 100; + + void generateGoldenResult() override {} +}; + +typedef SparseL2NNDetTest SparseL2NNDetTestF_Sq; +TEST_P(SparseL2NNDetTestF_Sq, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestF_Sq, ::testing::ValuesIn(inputsf)); +typedef SparseL2NNDetTest SparseL2NNDetTestF_Sqrt; +TEST_P(SparseL2NNDetTestF_Sqrt, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestF_Sqrt, ::testing::ValuesIn(inputsf)); + +typedef SparseL2NNDetTest SparseL2NNDetTestD_Sq; +TEST_P(SparseL2NNDetTestD_Sq, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + 
ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestD_Sq, ::testing::ValuesIn(inputsd)); +typedef SparseL2NNDetTest SparseL2NNDetTestD_Sqrt; +TEST_P(SparseL2NNDetTestD_Sqrt, Result) +{ + runTest(min.data()); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1.data()); + ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); + } +} +INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestD_Sqrt, ::testing::ValuesIn(inputsd)); + +} // end namespace sparse_l2_nn +} // end namespace distance +} // end namespace raft From 69687d6378408330f894d3621df7485df13a7548 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 22 Sep 2022 14:39:53 +0200 Subject: [PATCH 07/49] Add sparseL2NN benchmarks --- cpp/bench/CMakeLists.txt | 1 + cpp/bench/distance/sparse_l2_nn.cu | 194 +++++++++++++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 cpp/bench/distance/sparse_l2_nn.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 6b985acfc3..8808393db7 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -82,6 +82,7 @@ if(BUILD_BENCH) bench/distance/distance_l1.cu bench/distance/distance_unexp_l2.cu bench/distance/fused_l2_nn.cu + bench/spatial/sparse_l2_nn.cu bench/distance/kernels.cu bench/main.cpp OPTIONAL diff --git a/cpp/bench/distance/sparse_l2_nn.cu b/cpp/bench/distance/sparse_l2_nn.cu new file mode 100644 index 0000000000..a505bcc996 --- /dev/null +++ b/cpp/bench/distance/sparse_l2_nn.cu @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#if defined RAFT_NN_COMPILED +#include +#endif + +namespace raft::bench::spatial::sparse { + +// Introduce various sparsity patterns +enum SparsityPattern { + checkerboard = 0, + checkerboard_4 = 1, + checkerboard_64 = 2, + all_true = 3, + all_false = 4 +}; + +struct sparse_l2_nn_inputs { + int m, n, k, num_groups; + SparsityPattern pattern; +}; // struct sparse_l2_nn_inputs + +__global__ void init_adj( + int m, int n, int num_groups, SparsityPattern pattern, bool* adj, int* group_idxs) +{ + for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_groups; i += blockDim.y * gridDim.y) { + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < m; j += blockDim.x * gridDim.x) { + switch (pattern) { + case checkerboard: adj[i * m + j] = (i + j) % 2; break; + case checkerboard_4: adj[i * m + j] = (i + (j / 4)) % 2; break; + case checkerboard_64: adj[i * m + j] = (i + (j / 64)) % 2; break; + case all_true: adj[i * m + j] = true; break; + case all_false: adj[i * m + j] = false; break; + default: assert(false && "unknown pattern"); + } + } + } + // Each group is of size n / num_groups. + // + // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive + // scan of the group lengths) + // + // - The first group always starts at index zero, so we do not store it. + // + // - The group_idxs[num_groups - 1] should always equal n. 
+ + if (blockIdx.y == 0 && threadIdx.y == 0) { + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_groups; + j += blockDim.x * gridDim.x) { + group_idxs[j] = (j + 1) * (n / num_groups); + } + group_idxs[num_groups - 1] = n; + } +} + +template +struct sparse_l2_nn : public fixture { + sparse_l2_nn(const sparse_l2_nn_inputs& p) + : params(p), + out(p.m, stream), + x(p.m * p.k, stream), + y(p.n * p.k, stream), + xn(p.m, stream), + yn(p.n, stream), + adj(p.m * p.num_groups, stream), + group_idxs(p.num_groups, stream), + workspace(p.m, stream) + { + raft::handle_t handle{stream}; + raft::random::RngState r(123456ULL); + + uniform(handle, r, x.data(), p.m * p.k, T(-1.0), T(1.0)); + uniform(handle, r, y.data(), p.n * p.k, T(-1.0), T(1.0)); + raft::linalg::rowNorm(xn.data(), x.data(), p.k, p.m, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm(yn.data(), y.data(), p.k, p.n, raft::linalg::L2Norm, true, stream); + raft::distance::initialize, int>( + handle, out.data(), p.m, std::numeric_limits::max(), op); + + dim3 block(32, 32); + dim3 grid(10, 10); + init_adj<<>>( + p.m, p.n, p.num_groups, p.pattern, adj.data(), group_idxs.data()); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + void run_benchmark(::benchmark::State& state) override + { + loop_on_state(state, [this]() { + // It is sufficient to only benchmark the L2-squared metric + raft::distance::sparseL2NN, int>(out.data(), + x.data(), + y.data(), + xn.data(), + yn.data(), + adj.data(), + group_idxs.data(), + params.num_groups, + params.m, + params.n, + params.k, + (void*)workspace.data(), + op, + pairRedOp, + false, + false, + stream); + }); + } + + private: + sparse_l2_nn_inputs params; + rmm::device_uvector x, y, xn, yn; + rmm::device_uvector adj; + rmm::device_uvector group_idxs; + rmm::device_uvector> out; + rmm::device_uvector workspace; + raft::distance::KVPMinReduce pairRedOp; + raft::distance::MinAndDistanceReduceOp op; +}; // struct SparseL2NN + +// TODO: Consider thinning the list of 
benchmark cases.. +const std::vector sparse_l2_nn_input_vecs = { + // Very fat matrices... + {32, 16384, 16384, 32, SparsityPattern::checkerboard}, + {64, 16384, 16384, 32, SparsityPattern::checkerboard}, + {128, 16384, 16384, 32, SparsityPattern::checkerboard}, + {256, 16384, 16384, 32, SparsityPattern::checkerboard}, + {512, 16384, 16384, 32, SparsityPattern::checkerboard}, + {1024, 16384, 16384, 32, SparsityPattern::checkerboard}, + {16384, 32, 16384, 32, SparsityPattern::checkerboard}, + {16384, 64, 16384, 32, SparsityPattern::checkerboard}, + {16384, 128, 16384, 32, SparsityPattern::checkerboard}, + {16384, 256, 16384, 32, SparsityPattern::checkerboard}, + {16384, 512, 16384, 32, SparsityPattern::checkerboard}, + {16384, 1024, 16384, 32, SparsityPattern::checkerboard}, + + // Representative matrices... + {16384, 16384, 32, 32, SparsityPattern::checkerboard}, + {16384, 16384, 64, 32, SparsityPattern::checkerboard}, + {16384, 16384, 128, 32, SparsityPattern::checkerboard}, + {16384, 16384, 256, 32, SparsityPattern::checkerboard}, + {16384, 16384, 512, 32, SparsityPattern::checkerboard}, + {16384, 16384, 1024, 32, SparsityPattern::checkerboard}, + {16384, 16384, 16384, 32, SparsityPattern::checkerboard}, + + {16384, 16384, 32, 32, SparsityPattern::checkerboard_4}, + {16384, 16384, 64, 32, SparsityPattern::checkerboard_4}, + {16384, 16384, 128, 32, SparsityPattern::checkerboard_4}, + {16384, 16384, 256, 32, SparsityPattern::checkerboard_4}, + {16384, 16384, 512, 32, SparsityPattern::checkerboard_4}, + {16384, 16384, 1024, 32, SparsityPattern::checkerboard_4}, + {16384, 16384, 16384, 32, SparsityPattern::checkerboard_4}, + + {16384, 16384, 32, 32, SparsityPattern::checkerboard_64}, + {16384, 16384, 64, 32, SparsityPattern::checkerboard_64}, + {16384, 16384, 128, 32, SparsityPattern::checkerboard_64}, + {16384, 16384, 256, 32, SparsityPattern::checkerboard_64}, + {16384, 16384, 512, 32, SparsityPattern::checkerboard_64}, + {16384, 16384, 1024, 32, 
SparsityPattern::checkerboard_64}, + {16384, 16384, 16384, 32, SparsityPattern::checkerboard_64}, +}; + +RAFT_BENCH_REGISTER(sparse_l2_nn, "", sparse_l2_nn_input_vecs); +// Do not benchmark double. + +} // namespace raft::bench::spatial::sparse From e74b5f40ec61bef81bca0d3cae3d1bd154d4cbd4 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 12:13:55 +0100 Subject: [PATCH 08/49] Rename files sparse_* => masked_* --- cpp/bench/CMakeLists.txt | 2 +- cpp/bench/distance/{sparse_l2_nn.cu => masked_l2_nn.cu} | 0 .../{sparse_distance_base.cuh => masked_distance_base.cuh} | 0 .../raft/distance/detail/{sparse_l2_nn.cuh => masked_l2_nn.cuh} | 0 .../raft/distance/{sparse_l2_nn.cuh => masked_l2_nn.cuh} | 0 cpp/test/CMakeLists.txt | 2 +- cpp/test/distance/{sparse_l2_nn.cu => masked_l2_nn.cu} | 0 7 files changed, 2 insertions(+), 2 deletions(-) rename cpp/bench/distance/{sparse_l2_nn.cu => masked_l2_nn.cu} (100%) rename cpp/include/raft/distance/detail/{sparse_distance_base.cuh => masked_distance_base.cuh} (100%) rename cpp/include/raft/distance/detail/{sparse_l2_nn.cuh => masked_l2_nn.cuh} (100%) rename cpp/include/raft/distance/{sparse_l2_nn.cuh => masked_l2_nn.cuh} (100%) rename cpp/test/distance/{sparse_l2_nn.cu => masked_l2_nn.cu} (100%) diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 8808393db7..b167305dad 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -82,7 +82,7 @@ if(BUILD_BENCH) bench/distance/distance_l1.cu bench/distance/distance_unexp_l2.cu bench/distance/fused_l2_nn.cu - bench/spatial/sparse_l2_nn.cu + bench/spatial/masked_l2_nn.cu bench/distance/kernels.cu bench/main.cpp OPTIONAL diff --git a/cpp/bench/distance/sparse_l2_nn.cu b/cpp/bench/distance/masked_l2_nn.cu similarity index 100% rename from cpp/bench/distance/sparse_l2_nn.cu rename to cpp/bench/distance/masked_l2_nn.cu diff --git a/cpp/include/raft/distance/detail/sparse_distance_base.cuh 
b/cpp/include/raft/distance/detail/masked_distance_base.cuh similarity index 100% rename from cpp/include/raft/distance/detail/sparse_distance_base.cuh rename to cpp/include/raft/distance/detail/masked_distance_base.cuh diff --git a/cpp/include/raft/distance/detail/sparse_l2_nn.cuh b/cpp/include/raft/distance/detail/masked_l2_nn.cuh similarity index 100% rename from cpp/include/raft/distance/detail/sparse_l2_nn.cuh rename to cpp/include/raft/distance/detail/masked_l2_nn.cuh diff --git a/cpp/include/raft/distance/sparse_l2_nn.cuh b/cpp/include/raft/distance/masked_l2_nn.cuh similarity index 100% rename from cpp/include/raft/distance/sparse_l2_nn.cuh rename to cpp/include/raft/distance/masked_l2_nn.cuh diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 8039e0277e..ee8359d9ce 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -121,7 +121,7 @@ if(BUILD_TESTS) test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu - test/distance/sparse_l2_nn.cu + test/distance/masked_l2_nn.cu test/distance/gram.cu OPTIONAL DIST diff --git a/cpp/test/distance/sparse_l2_nn.cu b/cpp/test/distance/masked_l2_nn.cu similarity index 100% rename from cpp/test/distance/sparse_l2_nn.cu rename to cpp/test/distance/masked_l2_nn.cu From 7bf5801e1ccdbe4fbec27babadcaf0bb45886ade Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 16:23:20 +0100 Subject: [PATCH 09/49] Rename functions sparse_* => masked_* - Rename functions, classes, tests, and benchmarks. 
- Add raft::handle parameter to public API - Remove stream parameter --- cpp/bench/distance/masked_l2_nn.cu | 32 ++--- .../distance/detail/masked_distance_base.cuh | 11 +- .../raft/distance/detail/masked_l2_nn.cuh | 41 +++---- cpp/include/raft/distance/masked_l2_nn.cuh | 70 +++++++---- cpp/test/distance/masked_l2_nn.cu | 113 +++++++++--------- 5 files changed, 145 insertions(+), 122 deletions(-) diff --git a/cpp/bench/distance/masked_l2_nn.cu b/cpp/bench/distance/masked_l2_nn.cu index a505bcc996..c2ab9750b7 100644 --- a/cpp/bench/distance/masked_l2_nn.cu +++ b/cpp/bench/distance/masked_l2_nn.cu @@ -22,7 +22,7 @@ #include #include -#include +#include #include #include #include @@ -32,7 +32,7 @@ #include #endif -namespace raft::bench::spatial::sparse { +namespace raft::bench::spatial::masked { // Introduce various sparsity patterns enum SparsityPattern { @@ -43,10 +43,10 @@ enum SparsityPattern { all_false = 4 }; -struct sparse_l2_nn_inputs { +struct masked_l2_nn_inputs { int m, n, k, num_groups; SparsityPattern pattern; -}; // struct sparse_l2_nn_inputs +}; // struct masked_l2_nn_inputs __global__ void init_adj( int m, int n, int num_groups, SparsityPattern pattern, bool* adj, int* group_idxs) @@ -82,8 +82,8 @@ __global__ void init_adj( } template -struct sparse_l2_nn : public fixture { - sparse_l2_nn(const sparse_l2_nn_inputs& p) +struct masked_l2_nn : public fixture { + masked_l2_nn(const masked_l2_nn_inputs& p) : params(p), out(p.m, stream), x(p.m * p.k, stream), @@ -101,7 +101,7 @@ struct sparse_l2_nn : public fixture { uniform(handle, r, y.data(), p.n * p.k, T(-1.0), T(1.0)); raft::linalg::rowNorm(xn.data(), x.data(), p.k, p.m, raft::linalg::L2Norm, true, stream); raft::linalg::rowNorm(yn.data(), y.data(), p.k, p.n, raft::linalg::L2Norm, true, stream); - raft::distance::initialize, int>( + raft::distance::initialize, int>( handle, out.data(), p.m, std::numeric_limits::max(), op); dim3 block(32, 32); @@ -115,7 +115,8 @@ struct sparse_l2_nn : public fixture { { 
loop_on_state(state, [this]() { // It is sufficient to only benchmark the L2-squared metric - raft::distance::sparseL2NN, int>(out.data(), + raft::distance::maskedL2NN, int>(handle, + out.data(), x.data(), y.data(), xn.data(), @@ -130,24 +131,23 @@ struct sparse_l2_nn : public fixture { op, pairRedOp, false, - false, - stream); + false); }); } private: - sparse_l2_nn_inputs params; + masked_l2_nn_inputs params; rmm::device_uvector x, y, xn, yn; rmm::device_uvector adj; rmm::device_uvector group_idxs; - rmm::device_uvector> out; + rmm::device_uvector> out; rmm::device_uvector workspace; raft::distance::KVPMinReduce pairRedOp; raft::distance::MinAndDistanceReduceOp op; -}; // struct SparseL2NN +}; // struct MaskedL2NN // TODO: Consider thinning the list of benchmark cases.. -const std::vector sparse_l2_nn_input_vecs = { +const std::vector masked_l2_nn_input_vecs = { // Very fat matrices... {32, 16384, 16384, 32, SparsityPattern::checkerboard}, {64, 16384, 16384, 32, SparsityPattern::checkerboard}, @@ -188,7 +188,7 @@ const std::vector sparse_l2_nn_input_vecs = { {16384, 16384, 16384, 32, SparsityPattern::checkerboard_64}, }; -RAFT_BENCH_REGISTER(sparse_l2_nn, "", sparse_l2_nn_input_vecs); +RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); // Do not benchmark double. 
-} // namespace raft::bench::spatial::sparse +} // namespace raft::bench::spatial::masked diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index 6e51ccbab3..4112916c71 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -70,7 +70,7 @@ template > -struct SparseDistances : public BaseClass { +struct MaskedDistances : public BaseClass { private: typedef Policy P; const DataT* xn; @@ -89,7 +89,7 @@ struct SparseDistances : public BaseClass { public: // Constructor - DI SparseDistances(const DataT* _x, + DI MaskedDistances(const DataT* _x, const DataT* _y, IdxT _m, IdxT _n, @@ -156,7 +156,6 @@ struct SparseDistances : public BaseClass { for (; tile_idx_n < tile_end_n; tile_idx_n += P::Nblk) { // We provide tile_end_n to limit the number of unnecessary data // points that are loaded from y. - // TODO: determine if this actually improves performance. 
this->ldgXY(tile_idx_m, tile_idx_n, 0, tile_end_n); reset_accumulator(); @@ -279,7 +278,7 @@ struct SparseDistances : public BaseClass { regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } } -}; // struct SparseDistances +}; // struct MaskedDistances /** * @brief the distance matrix calculation kernel for L1, L2 and cosine @@ -324,7 +323,7 @@ template __global__ __launch_bounds__(Policy::Nthreads, 2) - void sparseDistanceMatKernel(const DataT* x, + void maskedDistanceMatKernel(const DataT* x, const DataT* y, const DataT* _xn, const DataT* _yn, @@ -342,7 +341,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) extern __shared__ char smem[]; auto rowEpilog = [] __device__(IdxT starty) { return; }; - SparseDistances #include #include #include #include -#include +#include #include #include #include @@ -45,7 +44,7 @@ template -__global__ __launch_bounds__(P::Nthreads, 2) void sparseL2NNkernel(OutT* min, +__global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, const DataT* x, const DataT* y, const DataT* xn, @@ -65,7 +64,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void sparseL2NNkernel(OutT* min, { extern __shared__ char smem[]; - typedef cub::KeyValuePair KVPair; + typedef raft::KeyValuePair KVPair; KVPair val[P::AccRowsPerTh]; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -95,7 +94,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void sparseL2NNkernel(OutT* min, for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } @@ -152,7 +151,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void sparseL2NNkernel(OutT* min, }; IdxT lda = k, ldb = k, ldd = n; - SparseDistances -void sparseL2NNImpl(OutT* min, +void maskedL2NNImpl(raft::handle_t const& handle, + OutT* min, const DataT* x, const DataT* y, const DataT* xn, @@ -200,17 +200,15 @@ void sparseL2NNImpl(OutT* min, ReduceOpT redOp, 
KVPReduceOpT pairRedOp, bool sqrt, - bool initOutBuffer, - cudaStream_t stream) + bool initOutBuffer) { typedef typename linalg::Policy4x4::Policy P; - static_assert(P::Mblk == 64, "sparseL2NNImpl only supports a policy with 64 rows per block."); + static_assert(P::Mblk == 64, "maskedL2NNImpl only supports a policy with 64 rows per block."); - // First, compress boolean to bitfield. + cudaStream_t stream = handle.get_stream(); - // TODO 1: Remove allocation; use workspace instead(?) - // TODO 2: Use a faster compress_to_bits implementation that does not require a pre-zeroed output. + // first, compress boolean to bitfield. rmm::device_uvector adj64(raft::ceildiv(m, IdxT(64)) * num_groups, stream); RAFT_CUDA_TRY(cudaMemsetAsync(adj64.data(), 0, adj64.size() * sizeof(uint64_t), stream)); dim3 compress_grid(raft::ceildiv(m, 32), raft::ceildiv(num_groups, 32)); @@ -220,7 +218,7 @@ void sparseL2NNImpl(OutT* min, dim3 blk(P::Nthreads); auto nblks = raft::ceildiv(m, P::Nthreads); constexpr auto maxVal = std::numeric_limits::max(); - typedef cub::KeyValuePair KVPair; + typedef raft::KeyValuePair KVPair; // Accumulation operation lambda auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; @@ -232,12 +230,11 @@ void sparseL2NNImpl(OutT* min, RAFT_CUDA_TRY(cudaGetLastError()); } - // TODO 3: remove fin_op - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; + auto fin_op = raft::identity_op{}; constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { - auto sparseL2NNSqrt = sparseL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, sparseL2NNSqrt); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, maskedL2NNSqrt); - sparseL2NNSqrt<<>>(min, + maskedL2NNSqrt<<>>(min, x, y, xn, @@ -266,7 +263,7 @@ void sparseL2NNImpl(OutT* min, core_lambda, fin_op); } else { - auto sparseL2NN = sparseL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, sparseL2NN); - sparseL2NN<<>>(min, + dim3 grid = launchConfigGenerator

(m, n, shmemSize, maskedL2NN); + maskedL2NN<<>>(min, x, y, xn, diff --git a/cpp/include/raft/distance/masked_l2_nn.cuh b/cpp/include/raft/distance/masked_l2_nn.cuh index c690702cb4..ff8c84d5ad 100644 --- a/cpp/include/raft/distance/masked_l2_nn.cuh +++ b/cpp/include/raft/distance/masked_l2_nn.cuh @@ -14,14 +14,13 @@ * limitations under the License. */ -#ifndef __SPARSE_L2_NN_H -#define __SPARSE_L2_NN_H +#ifndef __MASKED_L2_NN_H +#define __MASKED_L2_NN_H #pragma once -#include #include -#include +#include #include #include #include @@ -31,12 +30,28 @@ namespace raft { namespace distance { /** - * @brief Sparse L2 distance and 1-nearest-neighbor computation in a single call. + * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. + * This function enables faster computation of nearest neighbors if the + * computation of distances between certain point pairs can be skipped. + * + * To avoid using a full adjacency matrix between all points in `x` and `y`, the + * points in `y` are divided into groups. An adjacency matrix describes for each + * point in `x` and each group whether to compute the distance. + * + * **Performance considerations** + * + * The points in `x` are grouped into tiles of `M` points (`M` is currently 64, + * but may change in the future). As a result, the largest compute time + * reduction occurs if all `M` points can skip a group. If only part of the `M` + * points can skip a group, then at most a minor compute time reduction and a + * modest energy use reduction can be expected. + * + * The points in `y` are also grouped into tiles of `N` points (`N` is currently + * 64, but may change in the future). 
As a result, group sizes should be larger + * than `N` to avoid wasting computational resources. If the group sizes are + * evenly divisible by `N`, then the computation is most efficient, although for + * larger group sizes this effect is minor. * * @tparam DataT data type * @tparam OutT output type to either store 1-NN indices and their minimum @@ -47,6 +62,7 @@ namespace distance { * and also to initialize the output array elements with the * appropriate initial value needed for reduction. * + * @param handle RAFT handle for managing expensive resources * @param[out] min will contain the reduced output (Length = `m`) * (on device) * @param[in] x first matrix. Row major. Dim = `m x k`. @@ -55,19 +71,34 @@ namespace distance { * (on device). * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] adj A boolean adjacency matrix indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `m x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups The number of groups in `y`. * @param[in] m gemm m * @param[in] n gemm n * @param[in] k gemm k * @param[in] workspace temp workspace. Size = sizeof(int)*m. 
(on device) * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] pairRedOp reduction operation on key value pairs * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt * @param[in] initOutBuffer whether to initialize the output buffer before the * main kernel launch - * @param[in] stream cuda stream */ -template -void sparseL2NN(OutT* min, +template +void maskedL2NN(raft::handle_t const& handle, + OutT* min, const DataT* x, const DataT* y, const DataT* xn, @@ -82,13 +113,10 @@ void sparseL2NN(OutT* min, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, - bool initOutBuffer, - cudaStream_t stream) + bool initOutBuffer) { - // TODO: decide on kernel policy based on skinniness of the matrices. If k is - // low, it may make sense to use another kernel policy, like in - // fused_l2_nn.cuh. - detail::sparseL2NNImpl(min, + detail::maskedL2NNImpl(handle, + min, x, y, xn, @@ -103,9 +131,7 @@ void sparseL2NN(OutT* min, redOp, pairRedOp, sqrt, - initOutBuffer, - stream); - // } + initOutBuffer); } } // namespace distance diff --git a/cpp/test/distance/masked_l2_nn.cu b/cpp/test/distance/masked_l2_nn.cu index 293c78ddee..e4d44e8408 100644 --- a/cpp/test/distance/masked_l2_nn.cu +++ b/cpp/test/distance/masked_l2_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,8 +16,9 @@ #include "../test_utils.h" #include -#include -#include +#include +#include +#include #include #include #include @@ -28,11 +29,11 @@ namespace raft { namespace distance { -namespace sparse_l2_nn { +namespace masked_l2_nn { template -struct CubKVPMinReduce { - typedef cub::KeyValuePair KVP; +struct RaftKVPMinReduce { + typedef raft::KeyValuePair KVP; DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } @@ -41,7 +42,7 @@ struct CubKVPMinReduce { }; // KVPMinReduce template -__global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(cub::KeyValuePair* min, +__global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(raft::KeyValuePair* min, DataT* x, DataT* y, bool* adj, @@ -78,15 +79,15 @@ __global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(cub::KeyValuePair< auto diff = x[xidx] - y[yidx]; acc += diff * diff; } - if (Sqrt) { acc = raft::mySqrt(acc); } + if (Sqrt) { acc = raft::sqrt(acc); } ReduceOpT redOp; - typedef cub::WarpReduce> WarpReduce; + typedef cub::WarpReduce> WarpReduce; __shared__ typename WarpReduce::TempStorage temp[NWARPS]; int warpId = threadIdx.x / raft::WarpSize; - cub::KeyValuePair tmp; + raft::KeyValuePair tmp; tmp.key = include_dist ? nidx : -1; tmp.value = include_dist ? 
acc : maxVal; - tmp = WarpReduce(temp[warpId]).Reduce(tmp, CubKVPMinReduce()); + tmp = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce()); if (threadIdx.x % raft::WarpSize == 0 && midx < m) { while (atomicCAS(workspace + midx, 0, 1) == 1) ; @@ -101,7 +102,7 @@ __global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(cub::KeyValuePair< } template -void naive(cub::KeyValuePair* min, +void naive(raft::KeyValuePair* min, DataT* x, DataT* y, bool* adj, @@ -116,7 +117,7 @@ void naive(cub::KeyValuePair* min, RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); auto blks = raft::ceildiv(m, 256); MinAndDistanceReduceOp op; - raft::distance::detail::initKernel, int> + raft::distance::detail::initKernel, int> <<>>(min, m, std::numeric_limits::max(), op); RAFT_CUDA_TRY(cudaGetLastError()); @@ -199,9 +200,9 @@ __global__ void init_adj( } template -class SparseL2NNTest : public ::testing::TestWithParam> { +class MaskedL2NNTest : public ::testing::TestWithParam> { public: - SparseL2NNTest() + MaskedL2NNTest() : params(::testing::TestWithParam>::GetParam()), stream(handle.get_stream()), x(params.m * params.k, stream), @@ -247,8 +248,8 @@ class SparseL2NNTest : public ::testing::TestWithParam> { rmm::device_uvector group_idxs; rmm::device_uvector xn; rmm::device_uvector yn; - rmm::device_uvector> min; - rmm::device_uvector> min_ref; + rmm::device_uvector> min; + rmm::device_uvector> min_ref; rmm::device_uvector workspace; raft::handle_t handle; cudaStream_t stream; @@ -273,7 +274,7 @@ class SparseL2NNTest : public ::testing::TestWithParam> { stream); } - void runTest(cub::KeyValuePair* out) + void runTest(raft::KeyValuePair* out) { int m = params.m; int n = params.n; @@ -281,7 +282,8 @@ class SparseL2NNTest : public ::testing::TestWithParam> { int num_groups = params.num_groups; MinAndDistanceReduceOp redOp; - sparseL2NN, int>( + maskedL2NN, int>( + handle, out, x.data(), y.data(), @@ -297,15 +299,14 @@ class SparseL2NNTest : public 
::testing::TestWithParam> { redOp, raft::distance::KVPMinReduce(), Sqrt, - true, - stream); + true); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } }; template struct CompareApproxAbsKVP { - typedef typename cub::KeyValuePair KVP; + typedef typename raft::KeyValuePair KVP; CompareApproxAbsKVP(T eps_) : eps(eps_) {} bool operator()(const KVP& a, const KVP& b) const { @@ -321,7 +322,7 @@ struct CompareApproxAbsKVP { template struct CompareExactKVP { - typedef typename cub::KeyValuePair KVP; + typedef typename raft::KeyValuePair KVP; bool operator()(const KVP& a, const KVP& b) const { if (a.value != b.value) return false; @@ -330,13 +331,13 @@ struct CompareExactKVP { }; template -::testing::AssertionResult devArrMatch(const cub::KeyValuePair* expected, - const cub::KeyValuePair* actual, +::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, + const raft::KeyValuePair* actual, size_t size, L eq_compare, cudaStream_t stream = 0) { - typedef typename cub::KeyValuePair KVP; + typedef typename raft::KeyValuePair KVP; std::shared_ptr exp_h(new KVP[size]); std::shared_ptr act_h(new KVP[size]); raft::update_host(exp_h.get(), expected, size, stream); @@ -371,22 +372,22 @@ const std::vector> inputsf = { {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, }; -typedef SparseL2NNTest SparseL2NNTestF_Sq; -TEST_P(SparseL2NNTestF_Sq, Result) +typedef MaskedL2NNTest MaskedL2NNTestF_Sq; +TEST_P(MaskedL2NNTestF_Sq, Result) { runTest(min.data()); ASSERT_TRUE(devArrMatch( min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); } -INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestF_Sq, ::testing::ValuesIn(inputsf)); -typedef SparseL2NNTest SparseL2NNTestF_Sqrt; -TEST_P(SparseL2NNTestF_Sqrt, Result) +INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestF_Sq, ::testing::ValuesIn(inputsf)); +typedef MaskedL2NNTest MaskedL2NNTestF_Sqrt; +TEST_P(MaskedL2NNTestF_Sqrt, Result) { runTest(min.data()); 
ASSERT_TRUE(devArrMatch( min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); } -INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestF_Sqrt, ::testing::ValuesIn(inputsf)); +INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestF_Sqrt, ::testing::ValuesIn(inputsf)); const std::vector> inputsd = { {0.00001, 32, 32, 32, 2, 1234ULL, AdjacencyPattern::all_true}, @@ -403,52 +404,52 @@ const std::vector> inputsd = { {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_4}, {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_64}, }; -typedef SparseL2NNTest SparseL2NNTestD_Sq; -TEST_P(SparseL2NNTestD_Sq, Result) +typedef MaskedL2NNTest MaskedL2NNTestD_Sq; +TEST_P(MaskedL2NNTestD_Sq, Result) { runTest(min.data()); ASSERT_TRUE(devArrMatch( min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); } -INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestD_Sq, ::testing::ValuesIn(inputsd)); -typedef SparseL2NNTest SparseL2NNTestD_Sqrt; -TEST_P(SparseL2NNTestD_Sqrt, Result) +INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestD_Sq, ::testing::ValuesIn(inputsd)); +typedef MaskedL2NNTest MaskedL2NNTestD_Sqrt; +TEST_P(MaskedL2NNTestD_Sqrt, Result) { runTest(min.data()); ASSERT_TRUE(devArrMatch( min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); } -INSTANTIATE_TEST_CASE_P(SparseL2NNTests, SparseL2NNTestD_Sqrt, ::testing::ValuesIn(inputsd)); +INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestD_Sqrt, ::testing::ValuesIn(inputsd)); /// This is to test output determinism of the prim template -class SparseL2NNDetTest : public SparseL2NNTest { +class MaskedL2NNDetTest : public MaskedL2NNTest { public: - SparseL2NNDetTest() : stream(handle.get_stream()), min1(0, stream) {} + MaskedL2NNDetTest() : stream(handle.get_stream()), min1(0, stream) {} void SetUp() override { - SparseL2NNTest::SetUp(); + MaskedL2NNTest::SetUp(); int m = 
this->params.m; min1.resize(m, stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } - void TearDown() override { SparseL2NNTest::TearDown(); } + void TearDown() override { MaskedL2NNTest::TearDown(); } protected: raft::handle_t handle; cudaStream_t stream; - rmm::device_uvector> min1; + rmm::device_uvector> min1; static const int NumRepeats = 100; void generateGoldenResult() override {} }; -typedef SparseL2NNDetTest SparseL2NNDetTestF_Sq; -TEST_P(SparseL2NNDetTestF_Sq, Result) +typedef MaskedL2NNDetTest MaskedL2NNDetTestF_Sq; +TEST_P(MaskedL2NNDetTestF_Sq, Result) { runTest(min.data()); // assumed to be golden for (int i = 0; i < NumRepeats; ++i) { @@ -456,9 +457,9 @@ TEST_P(SparseL2NNDetTestF_Sq, Result) ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); } } -INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestF_Sq, ::testing::ValuesIn(inputsf)); -typedef SparseL2NNDetTest SparseL2NNDetTestF_Sqrt; -TEST_P(SparseL2NNDetTestF_Sqrt, Result) +INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestF_Sq, ::testing::ValuesIn(inputsf)); +typedef MaskedL2NNDetTest MaskedL2NNDetTestF_Sqrt; +TEST_P(MaskedL2NNDetTestF_Sqrt, Result) { runTest(min.data()); // assumed to be golden for (int i = 0; i < NumRepeats; ++i) { @@ -466,10 +467,10 @@ TEST_P(SparseL2NNDetTestF_Sqrt, Result) ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); } } -INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestF_Sqrt, ::testing::ValuesIn(inputsf)); +INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestF_Sqrt, ::testing::ValuesIn(inputsf)); -typedef SparseL2NNDetTest SparseL2NNDetTestD_Sq; -TEST_P(SparseL2NNDetTestD_Sq, Result) +typedef MaskedL2NNDetTest MaskedL2NNDetTestD_Sq; +TEST_P(MaskedL2NNDetTestD_Sq, Result) { runTest(min.data()); // assumed to be golden for (int i = 0; i < NumRepeats; ++i) { @@ -477,9 +478,9 @@ TEST_P(SparseL2NNDetTestD_Sq, Result) ASSERT_TRUE(devArrMatch(min.data(), 
min1.data(), params.m, CompareExactKVP(), stream)); } } -INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestD_Sq, ::testing::ValuesIn(inputsd)); -typedef SparseL2NNDetTest SparseL2NNDetTestD_Sqrt; -TEST_P(SparseL2NNDetTestD_Sqrt, Result) +INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestD_Sq, ::testing::ValuesIn(inputsd)); +typedef MaskedL2NNDetTest MaskedL2NNDetTestD_Sqrt; +TEST_P(MaskedL2NNDetTestD_Sqrt, Result) { runTest(min.data()); // assumed to be golden for (int i = 0; i < NumRepeats; ++i) { @@ -487,8 +488,8 @@ TEST_P(SparseL2NNDetTestD_Sqrt, Result) ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); } } -INSTANTIATE_TEST_CASE_P(SparseL2NNDetTests, SparseL2NNDetTestD_Sqrt, ::testing::ValuesIn(inputsd)); +INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestD_Sqrt, ::testing::ValuesIn(inputsd)); -} // end namespace sparse_l2_nn +} // end namespace masked_l2_nn } // end namespace distance } // end namespace raft From 96ff79d92145b77508f205f6c613e0651db12c0a Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 17:23:31 +0100 Subject: [PATCH 10/49] Remove workspace parameter --- cpp/bench/distance/masked_l2_nn.cu | 5 +-- .../raft/distance/detail/masked_l2_nn.cuh | 31 ++++++++++++------- cpp/include/raft/distance/masked_l2_nn.cuh | 5 +-- cpp/test/distance/masked_l2_nn.cu | 1 - 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/cpp/bench/distance/masked_l2_nn.cu b/cpp/bench/distance/masked_l2_nn.cu index c2ab9750b7..80de97f31f 100644 --- a/cpp/bench/distance/masked_l2_nn.cu +++ b/cpp/bench/distance/masked_l2_nn.cu @@ -91,8 +91,7 @@ struct masked_l2_nn : public fixture { xn(p.m, stream), yn(p.n, stream), adj(p.m * p.num_groups, stream), - group_idxs(p.num_groups, stream), - workspace(p.m, stream) + group_idxs(p.num_groups, stream) { raft::handle_t handle{stream}; raft::random::RngState r(123456ULL); @@ -127,7 +126,6 @@ struct masked_l2_nn : public fixture { params.m, 
params.n, params.k, - (void*)workspace.data(), op, pairRedOp, false, @@ -141,7 +139,6 @@ struct masked_l2_nn : public fixture { rmm::device_uvector adj; rmm::device_uvector group_idxs; rmm::device_uvector> out; - rmm::device_uvector workspace; raft::distance::KVPMinReduce pairRedOp; raft::distance::MinAndDistanceReduceOp op; }; // struct MaskedL2NN diff --git a/cpp/include/raft/distance/detail/masked_l2_nn.cuh b/cpp/include/raft/distance/detail/masked_l2_nn.cuh index c8b5821d8e..e9226e0180 100644 --- a/cpp/include/raft/distance/detail/masked_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_l2_nn.cuh @@ -184,7 +184,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, } template -void maskedL2NNImpl(raft::handle_t const& handle, +void maskedL2NNImpl(raft::handle_t& handle, OutT* min, const DataT* x, const DataT* y, @@ -196,7 +196,6 @@ void maskedL2NNImpl(raft::handle_t const& handle, IdxT m, IdxT n, IdxT k, - int* workspace, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, @@ -206,14 +205,23 @@ void maskedL2NNImpl(raft::handle_t const& handle, static_assert(P::Mblk == 64, "maskedL2NNImpl only supports a policy with 64 rows per block."); - cudaStream_t stream = handle.get_stream(); + // Get stream and workspace memory resource + rmm::mr::device_memory_resource * ws_mr = dynamic_cast(handle.get_workspace_resource()); + auto stream = handle.get_stream(); - // first, compress boolean to bitfield. 
- rmm::device_uvector adj64(raft::ceildiv(m, IdxT(64)) * num_groups, stream); - RAFT_CUDA_TRY(cudaMemsetAsync(adj64.data(), 0, adj64.size() * sizeof(uint64_t), stream)); + // Acquire temporary buffers and initialize to zero: + // 1) Adjacency matrix bitfield + // 2) Workspace for fused nearest neighbor operation + size_t m_div_64 = raft::ceildiv(m, IdxT(64)); + rmm::device_uvector ws_adj64{m_div_64 * num_groups, stream, ws_mr}; + rmm::device_uvector ws_fused_nn{size_t(m), stream, ws_mr}; + RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); + + // Compress boolean adjacency matrix to bitfield. dim3 compress_grid(raft::ceildiv(m, 32), raft::ceildiv(num_groups, 32)); compress_to_bits_naive<<>>( - adj, num_groups, m, adj64.data()); + adj, num_groups, m, ws_adj64.data()); dim3 blk(P::Nthreads); auto nblks = raft::ceildiv(m, P::Nthreads); @@ -223,7 +231,6 @@ void maskedL2NNImpl(raft::handle_t const& handle, // Accumulation operation lambda auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel <<>>(min, m, maxVal, redOp); @@ -250,14 +257,14 @@ void maskedL2NNImpl(raft::handle_t const& handle, y, xn, yn, - adj64.data(), + ws_adj64.data(), group_idxs, num_groups, m, n, k, maxVal, - workspace, + ws_fused_nn.data(), redOp, pairRedOp, core_lambda, @@ -278,14 +285,14 @@ void maskedL2NNImpl(raft::handle_t const& handle, y, xn, yn, - adj64.data(), + ws_adj64.data(), group_idxs, num_groups, m, n, k, maxVal, - workspace, + ws_fused_nn.data(), redOp, pairRedOp, core_lambda, diff --git a/cpp/include/raft/distance/masked_l2_nn.cuh b/cpp/include/raft/distance/masked_l2_nn.cuh index ff8c84d5ad..9fae354c12 100644 --- a/cpp/include/raft/distance/masked_l2_nn.cuh +++ b/cpp/include/raft/distance/masked_l2_nn.cuh 
@@ -85,7 +85,6 @@ namespace distance { * @param[in] m gemm m * @param[in] n gemm n * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) * @param[in] redOp reduction operator in the epilogue * @param[in] pairRedOp reduction operation on key value pairs * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt @@ -97,7 +96,7 @@ template -void maskedL2NN(raft::handle_t const& handle, +void maskedL2NN(raft::handle_t& handle, OutT* min, const DataT* x, const DataT* y, @@ -109,7 +108,6 @@ void maskedL2NN(raft::handle_t const& handle, IdxT m, IdxT n, IdxT k, - void* workspace, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, @@ -127,7 +125,6 @@ void maskedL2NN(raft::handle_t const& handle, m, n, k, - (int*)workspace, redOp, pairRedOp, sqrt, diff --git a/cpp/test/distance/masked_l2_nn.cu b/cpp/test/distance/masked_l2_nn.cu index e4d44e8408..2e97e1278a 100644 --- a/cpp/test/distance/masked_l2_nn.cu +++ b/cpp/test/distance/masked_l2_nn.cu @@ -295,7 +295,6 @@ class MaskedL2NNTest : public ::testing::TestWithParam> { m, n, k, - (void*)workspace.data(), redOp, raft::distance::KVPMinReduce(), Sqrt, From be5682681118d74fd42bbfafdbfb1242d3372bee Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 18:09:39 +0100 Subject: [PATCH 11/49] Use mdspan parameters --- cpp/bench/distance/masked_l2_nn.cu | 45 ++++++++++++------- cpp/include/raft/distance/masked_l2_nn.cuh | 51 ++++++++++++---------- cpp/test/distance/masked_l2_nn.cu | 31 ++++++++----- 3 files changed, 75 insertions(+), 52 deletions(-) diff --git a/cpp/bench/distance/masked_l2_nn.cu b/cpp/bench/distance/masked_l2_nn.cu index 80de97f31f..db6bf91936 100644 --- a/cpp/bench/distance/masked_l2_nn.cu +++ b/cpp/bench/distance/masked_l2_nn.cu @@ -113,23 +113,36 @@ struct masked_l2_nn : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { + using DataT = T; + using IdxT = int; + using OutT = 
raft::KeyValuePair; + + IdxT m = params.m; + IdxT n = params.n; + IdxT k = params.k; + IdxT num_groups = params.num_groups; + + auto out_view = raft::make_device_vector_view(out.data(), m); + auto x_view = raft::make_device_matrix_view(x.data(), m, k); + auto y_view = raft::make_device_matrix_view(y.data(), n, k); + auto x_norm = raft::make_device_vector_view(xn.data(), m); + auto y_norm = raft::make_device_vector_view(yn.data(), n); + auto adj_view = raft::make_device_matrix_view(adj.data(), m, num_groups); + auto group_idxs_view = raft::make_device_vector_view(group_idxs.data(), num_groups); + // It is sufficient to only benchmark the L2-squared metric - raft::distance::maskedL2NN, int>(handle, - out.data(), - x.data(), - y.data(), - xn.data(), - yn.data(), - adj.data(), - group_idxs.data(), - params.num_groups, - params.m, - params.n, - params.k, - op, - pairRedOp, - false, - false); + raft::distance::maskedL2NN(handle, + out_view, + x_view, + y_view, + x_norm, + y_norm, + adj_view, + group_idxs_view, + op, + pairRedOp, + false, + false); }); } diff --git a/cpp/include/raft/distance/masked_l2_nn.cuh b/cpp/include/raft/distance/masked_l2_nn.cuh index 9fae354c12..b19ed981c0 100644 --- a/cpp/include/raft/distance/masked_l2_nn.cuh +++ b/cpp/include/raft/distance/masked_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -91,36 +91,39 @@ namespace distance { * @param[in] initOutBuffer whether to initialize the output buffer before the * main kernel launch */ -template +template void maskedL2NN(raft::handle_t& handle, - OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const bool* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, + raft::device_vector_view out, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_vector_view const x_norm, + raft::device_vector_view const y_norm, + raft::device_matrix_view const adj, + raft::device_vector_view const group_idxs, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, bool initOutBuffer) { + // TODO: add more assertions. + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + + IdxT m = x.extent(0); + IdxT n = y.extent(0); + IdxT k = x.extent(1); + IdxT num_groups = group_idxs.extent(0); + detail::maskedL2NNImpl(handle, - min, - x, - y, - xn, - yn, - adj, - group_idxs, + out.data_handle(), + x.data_handle(), + y.data_handle(), + x_norm.data_handle(), + y_norm.data_handle(), + adj.data_handle(), + group_idxs.data_handle(), num_groups, m, n, diff --git a/cpp/test/distance/masked_l2_nn.cu b/cpp/test/distance/masked_l2_nn.cu index 2e97e1278a..2072ce2731 100644 --- a/cpp/test/distance/masked_l2_nn.cu +++ b/cpp/test/distance/masked_l2_nn.cu @@ -16,6 +16,7 @@ #include "../test_utils.h" #include +#include #include #include #include @@ -281,20 +282,26 @@ class MaskedL2NNTest : public ::testing::TestWithParam> { int k = params.k; int num_groups = params.num_groups; + auto out_view = raft::make_device_vector_view(out, m); + auto x_view = raft::make_device_matrix_view(x.data(), m, k); + auto y_view = raft::make_device_matrix_view(y.data(), n, k); + auto x_norm = 
raft::make_device_vector_view(xn.data(), m); + auto y_norm = raft::make_device_vector_view(yn.data(), n); + auto adj_view = raft::make_device_matrix_view(adj.data(), m, num_groups); + auto group_idxs_view = raft::make_device_vector_view(group_idxs.data(), num_groups); + MinAndDistanceReduceOp redOp; - maskedL2NN, int>( + using IdxT = int; + + maskedL2NN, IdxT>( handle, - out, - x.data(), - y.data(), - xn.data(), - yn.data(), - adj.data(), - group_idxs.data(), - num_groups, - m, - n, - k, + out_view, + x_view, + y_view, + x_norm, + y_norm, + adj_view, + group_idxs_view, redOp, raft::distance::KVPMinReduce(), Sqrt, From 07c2f53ffb02210c007885d0039786aaf04ebd62 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 19:01:52 +0100 Subject: [PATCH 12/49] Remove uvectors from benchmark Use mdarray/mdspans directly. --- cpp/bench/distance/masked_l2_nn.cu | 85 +++++++++++++----------------- 1 file changed, 38 insertions(+), 47 deletions(-) diff --git a/cpp/bench/distance/masked_l2_nn.cu b/cpp/bench/distance/masked_l2_nn.cu index db6bf91936..434662fb09 100644 --- a/cpp/bench/distance/masked_l2_nn.cu +++ b/cpp/bench/distance/masked_l2_nn.cu @@ -22,6 +22,8 @@ #include #include +#include +#include #include #include #include @@ -83,77 +85,66 @@ __global__ void init_adj( template struct masked_l2_nn : public fixture { + using DataT = T; + using IdxT = int; + using OutT = raft::KeyValuePair; + + // Parameters + masked_l2_nn_inputs params; + // Data + raft::device_vector out; + raft::device_matrix x, y; + raft::device_vector xn, yn; + raft::device_matrix adj; + raft::device_vector group_idxs; + // Reduction operators + raft::distance::KVPMinReduce pairRedOp; + raft::distance::MinAndDistanceReduceOp op; + masked_l2_nn(const masked_l2_nn_inputs& p) : params(p), - out(p.m, stream), - x(p.m * p.k, stream), - y(p.n * p.k, stream), - xn(p.m, stream), - yn(p.n, stream), - adj(p.m * p.num_groups, stream), - group_idxs(p.num_groups, stream) + 
out{raft::make_device_vector(handle, p.m)}, + x{raft::make_device_matrix(handle, p.m, p.k)}, + y{raft::make_device_matrix(handle, p.n, p.k)}, + xn{raft::make_device_vector(handle, p.m)}, + yn{raft::make_device_vector(handle, p.n)}, + adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, + group_idxs{raft::make_device_vector(handle, p.num_groups)} { - raft::handle_t handle{stream}; raft::random::RngState r(123456ULL); - uniform(handle, r, x.data(), p.m * p.k, T(-1.0), T(1.0)); - uniform(handle, r, y.data(), p.n * p.k, T(-1.0), T(1.0)); - raft::linalg::rowNorm(xn.data(), x.data(), p.k, p.m, raft::linalg::L2Norm, true, stream); - raft::linalg::rowNorm(yn.data(), y.data(), p.k, p.n, raft::linalg::L2Norm, true, stream); + uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0)); + uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0)); + raft::linalg::rowNorm(xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm(yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); raft::distance::initialize, int>( - handle, out.data(), p.m, std::numeric_limits::max(), op); + handle, out.data_handle(), p.m, std::numeric_limits::max(), op); dim3 block(32, 32); dim3 grid(10, 10); init_adj<<>>( - p.m, p.n, p.num_groups, p.pattern, adj.data(), group_idxs.data()); + p.m, p.n, p.num_groups, p.pattern, adj.data_handle(), group_idxs.data_handle()); RAFT_CUDA_TRY(cudaGetLastError()); } void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - using DataT = T; - using IdxT = int; - using OutT = raft::KeyValuePair; - - IdxT m = params.m; - IdxT n = params.n; - IdxT k = params.k; - IdxT num_groups = params.num_groups; - - auto out_view = raft::make_device_vector_view(out.data(), m); - auto x_view = raft::make_device_matrix_view(x.data(), m, k); - auto y_view = raft::make_device_matrix_view(y.data(), n, k); - auto x_norm = raft::make_device_vector_view(xn.data(), m); - 
auto y_norm = raft::make_device_vector_view(yn.data(), n); - auto adj_view = raft::make_device_matrix_view(adj.data(), m, num_groups); - auto group_idxs_view = raft::make_device_vector_view(group_idxs.data(), num_groups); - // It is sufficient to only benchmark the L2-squared metric raft::distance::maskedL2NN(handle, - out_view, - x_view, - y_view, - x_norm, - y_norm, - adj_view, - group_idxs_view, + out.view(), + x.view(), + y.view(), + xn.view(), + yn.view(), + adj.view(), + group_idxs.view(), op, pairRedOp, false, false); }); } - - private: - masked_l2_nn_inputs params; - rmm::device_uvector x, y, xn, yn; - rmm::device_uvector adj; - rmm::device_uvector group_idxs; - rmm::device_uvector> out; - raft::distance::KVPMinReduce pairRedOp; - raft::distance::MinAndDistanceReduceOp op; }; // struct MaskedL2NN // TODO: Consider thinning the list of benchmark cases.. From 2f234816206d041e63a784d06dc02f676b17d19b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 19:49:16 +0100 Subject: [PATCH 13/49] Use parameter struct in public API --- cpp/bench/distance/masked_l2_nn.cu | 21 ++++----- cpp/include/raft/distance/masked_l2_nn.cuh | 52 +++++++++++++--------- cpp/test/distance/masked_l2_nn.cu | 24 +++++----- 3 files changed, 57 insertions(+), 40 deletions(-) diff --git a/cpp/bench/distance/masked_l2_nn.cu b/cpp/bench/distance/masked_l2_nn.cu index 434662fb09..69f4487500 100644 --- a/cpp/bench/distance/masked_l2_nn.cu +++ b/cpp/bench/distance/masked_l2_nn.cu @@ -88,6 +88,9 @@ struct masked_l2_nn : public fixture { using DataT = T; using IdxT = int; using OutT = raft::KeyValuePair; + using RedOpT = raft::distance::MinAndDistanceReduceOp; + using PairRedOpT = raft::distance::KVPMinReduce; + using ParamT = raft::distance::MaskedL2NNParams; // Parameters masked_l2_nn_inputs params; @@ -97,9 +100,6 @@ struct masked_l2_nn : public fixture { raft::device_vector xn, yn; raft::device_matrix adj; raft::device_vector group_idxs; - // Reduction operators - 
raft::distance::KVPMinReduce pairRedOp; - raft::distance::MinAndDistanceReduceOp op; masked_l2_nn(const masked_l2_nn_inputs& p) : params(p), @@ -118,7 +118,7 @@ struct masked_l2_nn : public fixture { raft::linalg::rowNorm(xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); raft::linalg::rowNorm(yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); raft::distance::initialize, int>( - handle, out.data_handle(), p.m, std::numeric_limits::max(), op); + handle, out.data_handle(), p.m, std::numeric_limits::max(), RedOpT{}); dim3 block(32, 32); dim3 grid(10, 10); @@ -129,20 +129,21 @@ struct masked_l2_nn : public fixture { void run_benchmark(::benchmark::State& state) override { - loop_on_state(state, [this]() { + bool init_out = false; + bool sqrt = false; + ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; + + loop_on_state(state, [this, masked_l2_params]() { // It is sufficient to only benchmark the L2-squared metric raft::distance::maskedL2NN(handle, - out.view(), + masked_l2_params, x.view(), y.view(), xn.view(), yn.view(), adj.view(), group_idxs.view(), - op, - pairRedOp, - false, - false); + out.view()); }); } }; // struct MaskedL2NN diff --git a/cpp/include/raft/distance/masked_l2_nn.cuh b/cpp/include/raft/distance/masked_l2_nn.cuh index b19ed981c0..60380e7493 100644 --- a/cpp/include/raft/distance/masked_l2_nn.cuh +++ b/cpp/include/raft/distance/masked_l2_nn.cuh @@ -29,6 +29,29 @@ namespace raft { namespace distance { +/** + * @brief Parameters for maskedL2NN function + * + * Prescribes how to reduce a distance to an intermediate type (`redOp`), and + * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is + * mapped to an (index, value) pair and (index, value) pair with the lowest + * value (distance) is selected. 
+ * + * In addition, prescribes whether to compute the square root of the distance + * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). + */ +template +struct MaskedL2NNParams { + /** Reduction operator in the epilogue */ + ReduceOpT redOp; + /** Reduction operation on key value pairs */ + KVPReduceOpT pairRedOp; + /** Whether the output `minDist` should contain L2-sqrt */ + bool sqrt; + /** Whether to initialize the output buffer before the main kernel launch */ + bool initOutBuffer; +}; + /** * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. * @@ -63,8 +86,7 @@ namespace distance { * appropriate initial value needed for reduction. * * @param handle RAFT handle for managing expensive resources - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) + * @param params Parameter struct specifying the reduction operations. * @param[in] x first matrix. Row major. Dim = `m x k`. * (on device). * @param[in] y second matrix. Row major. Dim = `n x k`. @@ -81,29 +103,19 @@ namespace distance { * always assumed to start at index 0 and the last * group typically ends at index `n`. Length = * `num_groups`. - * @param[in] num_groups The number of groups in `y`. 
- * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) */ template void maskedL2NN(raft::handle_t& handle, - raft::device_vector_view out, + MaskedL2NNParams params, raft::device_matrix_view const x, raft::device_matrix_view const y, raft::device_vector_view const x_norm, raft::device_vector_view const y_norm, raft::device_matrix_view const adj, raft::device_vector_view const group_idxs, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer) + raft::device_vector_view out) { // TODO: add more assertions. RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); @@ -128,10 +140,10 @@ void maskedL2NN(raft::handle_t& handle, m, n, k, - redOp, - pairRedOp, - sqrt, - initOutBuffer); + params.redOp, + params.pairRedOp, + params.sqrt, + params.initOutBuffer); } } // namespace distance diff --git a/cpp/test/distance/masked_l2_nn.cu b/cpp/test/distance/masked_l2_nn.cu index 2072ce2731..e92eed3e9b 100644 --- a/cpp/test/distance/masked_l2_nn.cu +++ b/cpp/test/distance/masked_l2_nn.cu @@ -277,35 +277,39 @@ class MaskedL2NNTest : public ::testing::TestWithParam> { void runTest(raft::KeyValuePair* out) { + using IdxT = int; + using OutT = raft::KeyValuePair; + using RedOpT = MinAndDistanceReduceOp; + using PairRedOpT = raft::distance::KVPMinReduce; + using ParamT = MaskedL2NNParams; + + bool init_out = true; + ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, Sqrt, init_out}; + int m = params.m; int n = params.n; int k = params.k; int num_groups = params.num_groups; - auto out_view = raft::make_device_vector_view(out, m); auto x_view = 
raft::make_device_matrix_view(x.data(), m, k); auto y_view = raft::make_device_matrix_view(y.data(), n, k); auto x_norm = raft::make_device_vector_view(xn.data(), m); auto y_norm = raft::make_device_vector_view(yn.data(), n); auto adj_view = raft::make_device_matrix_view(adj.data(), m, num_groups); auto group_idxs_view = raft::make_device_vector_view(group_idxs.data(), num_groups); + auto out_view = raft::make_device_vector_view(out, m); - MinAndDistanceReduceOp redOp; - using IdxT = int; - - maskedL2NN, IdxT>( + maskedL2NN( handle, - out_view, + masked_l2_params, x_view, y_view, x_norm, y_norm, adj_view, group_idxs_view, - redOp, - raft::distance::KVPMinReduce(), - Sqrt, - true); + out_view); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } }; From 7df1fd2b3988eea235fbadbfc8074335ab97b85d Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 20:12:44 +0100 Subject: [PATCH 14/49] Clean up minor nits in tests and benchmarks --- cpp/bench/distance/masked_l2_nn.cu | 2 +- cpp/test/distance/masked_l2_nn.cu | 14 +++++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/cpp/bench/distance/masked_l2_nn.cu b/cpp/bench/distance/masked_l2_nn.cu index 69f4487500..33a50ab633 100644 --- a/cpp/bench/distance/masked_l2_nn.cu +++ b/cpp/bench/distance/masked_l2_nn.cu @@ -34,7 +34,7 @@ #include #endif -namespace raft::bench::spatial::masked { +namespace raft::bench::distance::masked_nn { // Introduce various sparsity patterns enum SparsityPattern { diff --git a/cpp/test/distance/masked_l2_nn.cu b/cpp/test/distance/masked_l2_nn.cu index e92eed3e9b..baf5947770 100644 --- a/cpp/test/distance/masked_l2_nn.cu +++ b/cpp/test/distance/masked_l2_nn.cu @@ -28,9 +28,7 @@ #include #include -namespace raft { -namespace distance { -namespace masked_l2_nn { +namespace raft::distance::masked_l2_nn { template struct RaftKVPMinReduce { @@ -238,7 +236,7 @@ class MaskedL2NNTest : public ::testing::TestWithParam> { generateGoldenResult(); 
raft::linalg::rowNorm(xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream); raft::linalg::rowNorm(yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } protected: @@ -310,7 +308,7 @@ class MaskedL2NNTest : public ::testing::TestWithParam> { group_idxs_view, out_view); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } }; @@ -442,7 +440,7 @@ class MaskedL2NNDetTest : public MaskedL2NNTest { MaskedL2NNTest::SetUp(); int m = this->params.m; min1.resize(m, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + handle.sync_stream(stream); } void TearDown() override { MaskedL2NNTest::TearDown(); } @@ -500,6 +498,4 @@ TEST_P(MaskedL2NNDetTestD_Sqrt, Result) } INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestD_Sqrt, ::testing::ValuesIn(inputsd)); -} // end namespace masked_l2_nn -} // end namespace distance -} // end namespace raft +} // end namespace raft::distance::masked_l2_nn From b5906849f1d67174be247641bc028b4e1da67de0 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 20:25:55 +0100 Subject: [PATCH 15/49] Update copyright years --- cpp/bench/distance/masked_l2_nn.cu | 2 +- cpp/include/raft/distance/detail/compress_to_bits.cuh | 2 +- cpp/include/raft/distance/detail/masked_distance_base.cuh | 2 +- cpp/include/raft/distance/detail/masked_l2_nn.cuh | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/bench/distance/masked_l2_nn.cu b/cpp/bench/distance/masked_l2_nn.cu index 33a50ab633..eba69e3511 100644 --- a/cpp/bench/distance/masked_l2_nn.cu +++ b/cpp/bench/distance/masked_l2_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh index e9a60154a3..444c7b005e 100644 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index 4112916c71..661bf5a86b 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/masked_l2_nn.cuh b/cpp/include/raft/distance/detail/masked_l2_nn.cuh index e9226e0180..0e5c0ddb4e 100644 --- a/cpp/include/raft/distance/detail/masked_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
From 9fcc24d88d3fe279acc530e62fd43ffeae225791 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 20:37:50 +0100 Subject: [PATCH 16/49] Rename masked_l2_nn => masked_nn --- cpp/bench/distance/{masked_l2_nn.cu => masked_nn.cu} | 2 +- .../raft/distance/detail/{masked_l2_nn.cuh => masked_nn.cuh} | 0 cpp/include/raft/distance/{masked_l2_nn.cuh => masked_nn.cuh} | 2 +- cpp/test/distance/{masked_l2_nn.cu => masked_nn.cu} | 4 ++-- 4 files changed, 4 insertions(+), 4 deletions(-) rename cpp/bench/distance/{masked_l2_nn.cu => masked_nn.cu} (99%) rename cpp/include/raft/distance/detail/{masked_l2_nn.cuh => masked_nn.cuh} (100%) rename cpp/include/raft/distance/{masked_l2_nn.cuh => masked_nn.cuh} (99%) rename cpp/test/distance/{masked_l2_nn.cu => masked_nn.cu} (99%) diff --git a/cpp/bench/distance/masked_l2_nn.cu b/cpp/bench/distance/masked_nn.cu similarity index 99% rename from cpp/bench/distance/masked_l2_nn.cu rename to cpp/bench/distance/masked_nn.cu index eba69e3511..92bf350437 100644 --- a/cpp/bench/distance/masked_l2_nn.cu +++ b/cpp/bench/distance/masked_nn.cu @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/cpp/include/raft/distance/detail/masked_l2_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh similarity index 100% rename from cpp/include/raft/distance/detail/masked_l2_nn.cuh rename to cpp/include/raft/distance/detail/masked_nn.cuh diff --git a/cpp/include/raft/distance/masked_l2_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh similarity index 99% rename from cpp/include/raft/distance/masked_l2_nn.cuh rename to cpp/include/raft/distance/masked_nn.cuh index 60380e7493..bbabbe8d84 100644 --- a/cpp/include/raft/distance/masked_l2_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -20,7 +20,7 @@ #pragma once #include -#include +#include #include #include #include diff --git a/cpp/test/distance/masked_l2_nn.cu b/cpp/test/distance/masked_nn.cu similarity index 99% rename from 
cpp/test/distance/masked_l2_nn.cu rename to cpp/test/distance/masked_nn.cu index baf5947770..de45e3b5e8 100644 --- a/cpp/test/distance/masked_l2_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -18,8 +18,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include From 953b33db37d04ebf89cd725a347c930f47262a30 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 20:38:11 +0100 Subject: [PATCH 17/49] cmake: rename masked_l2_nn => masked_nn --- cpp/bench/CMakeLists.txt | 2 +- cpp/test/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index b167305dad..05c85bedee 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -82,7 +82,7 @@ if(BUILD_BENCH) bench/distance/distance_l1.cu bench/distance/distance_unexp_l2.cu bench/distance/fused_l2_nn.cu - bench/spatial/masked_l2_nn.cu + bench/distance/masked_nn.cu bench/distance/kernels.cu bench/main.cpp OPTIONAL diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ee8359d9ce..ddce9b0ee7 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -121,7 +121,7 @@ if(BUILD_TESTS) test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu - test/distance/masked_l2_nn.cu + test/distance/masked_nn.cu test/distance/gram.cu OPTIONAL DIST From 724e874428aab2146ea2d4e4899a0558f2a6a899 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 20:54:56 +0100 Subject: [PATCH 18/49] Clang-format --- cpp/bench/distance/masked_nn.cu | 20 +++++++------- .../raft/distance/detail/masked_nn.cuh | 3 ++- cpp/include/raft/distance/masked_nn.cuh | 2 +- cpp/test/distance/masked_nn.cu | 27 +++++++++---------- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/cpp/bench/distance/masked_nn.cu b/cpp/bench/distance/masked_nn.cu index 92bf350437..073b672e59 100644 --- a/cpp/bench/distance/masked_nn.cu +++ 
b/cpp/bench/distance/masked_nn.cu @@ -85,12 +85,12 @@ __global__ void init_adj( template struct masked_l2_nn : public fixture { - using DataT = T; - using IdxT = int; - using OutT = raft::KeyValuePair; - using RedOpT = raft::distance::MinAndDistanceReduceOp; + using DataT = T; + using IdxT = int; + using OutT = raft::KeyValuePair; + using RedOpT = raft::distance::MinAndDistanceReduceOp; using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = raft::distance::MaskedL2NNParams; + using ParamT = raft::distance::MaskedL2NNParams; // Parameters masked_l2_nn_inputs params; @@ -115,8 +115,10 @@ struct masked_l2_nn : public fixture { uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0)); uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0)); - raft::linalg::rowNorm(xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); - raft::linalg::rowNorm(yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm( + xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm( + yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); raft::distance::initialize, int>( handle, out.data_handle(), p.m, std::numeric_limits::max(), RedOpT{}); @@ -130,7 +132,7 @@ struct masked_l2_nn : public fixture { void run_benchmark(::benchmark::State& state) override { bool init_out = false; - bool sqrt = false; + bool sqrt = false; ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; loop_on_state(state, [this, masked_l2_params]() { @@ -193,4 +195,4 @@ const std::vector masked_l2_nn_input_vecs = { RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); // Do not benchmark double. 
-} // namespace raft::bench::spatial::masked +} // namespace raft::bench::distance::masked_nn diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 0e5c0ddb4e..7079774226 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -206,7 +206,8 @@ void maskedL2NNImpl(raft::handle_t& handle, static_assert(P::Mblk == 64, "maskedL2NNImpl only supports a policy with 64 rows per block."); // Get stream and workspace memory resource - rmm::mr::device_memory_resource * ws_mr = dynamic_cast(handle.get_workspace_resource()); + rmm::mr::device_memory_resource* ws_mr = + dynamic_cast(handle.get_workspace_resource()); auto stream = handle.get_stream(); // Acquire temporary buffers and initialize to zero: diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index bbabbe8d84..ce1c9ba52c 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -39,7 +39,7 @@ namespace distance { * * In addition, prescribes whether to compute the square root of the distance * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). 
- */ + */ template struct MaskedL2NNParams { /** Reduction operator in the epilogue */ diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index de45e3b5e8..4d25d2dde2 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -275,11 +275,11 @@ class MaskedL2NNTest : public ::testing::TestWithParam> { void runTest(raft::KeyValuePair* out) { - using IdxT = int; - using OutT = raft::KeyValuePair; - using RedOpT = MinAndDistanceReduceOp; + using IdxT = int; + using OutT = raft::KeyValuePair; + using RedOpT = MinAndDistanceReduceOp; using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = MaskedL2NNParams; + using ParamT = MaskedL2NNParams; bool init_out = true; ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, Sqrt, init_out}; @@ -297,16 +297,15 @@ class MaskedL2NNTest : public ::testing::TestWithParam> { auto group_idxs_view = raft::make_device_vector_view(group_idxs.data(), num_groups); auto out_view = raft::make_device_vector_view(out, m); - maskedL2NN( - handle, - masked_l2_params, - x_view, - y_view, - x_norm, - y_norm, - adj_view, - group_idxs_view, - out_view); + maskedL2NN(handle, + masked_l2_params, + x_view, + y_view, + x_norm, + y_norm, + adj_view, + group_idxs_view, + out_view); handle.sync_stream(stream); } From 4ea41c43d7fd4340a1316b3226d5245da175ade4 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 23 Jan 2023 21:58:11 +0100 Subject: [PATCH 19/49] Fix docs for masked NN --- cpp/include/raft/distance/masked_nn.cuh | 35 ++++++++++++++++++---- docs/source/cpp_api/distance.rst | 1 + docs/source/cpp_api/distance_masked_nn.rst | 16 ++++++++++ 3 files changed, 47 insertions(+), 5 deletions(-) create mode 100644 docs/source/cpp_api/distance_masked_nn.rst diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index ce1c9ba52c..9f833e54e0 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -28,9 +28,32 @@ 
namespace raft { namespace distance { +/** + * \defgroup masked_nn Masked 1-nearest neighbors + * @{ + */ /** - * @brief Parameters for maskedL2NN function + * @brief Parameter struct for maskedL2NN function + * + * @tparam ReduceOpT Type of reduction operator in the epilogue. + * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. + * + * Usage example: + * @code{.cpp} + * #include + * + * using IdxT = int; + * using DataT = float; + * using RedOpT = raft::distance::MinAndDistanceReduceOp; + * using PairRedOpT = raft::distance::KVPMinReduce; + * using ParamT = raft::distance::MaskedL2NNParams; + * + * bool init_out = true; + * bool sqrt = false; + * + * ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; + * @endcode * * Prescribes how to reduce a distance to an intermediate type (`redOp`), and * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is @@ -91,8 +114,8 @@ struct MaskedL2NNParams { * (on device). * @param[in] y second matrix. Row major. Dim = `n x k`. * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device) * @param[in] adj A boolean adjacency matrix indicating for each * row of `x` and each group in `y` whether to compute the * distance. Dim = `m x num_groups`. @@ -103,12 +126,12 @@ struct MaskedL2NNParams { * always assumed to start at index 0 and the last * group typically ends at index `n`. Length = * `num_groups`. 
- * @param[out] min will contain the reduced output (Length = `m`) + * @param[out] out will contain the reduced output (Length = `m`) * (on device) */ template void maskedL2NN(raft::handle_t& handle, - MaskedL2NNParams params, + raft::distance::MaskedL2NNParams params, raft::device_matrix_view const x, raft::device_matrix_view const y, raft::device_vector_view const x_norm, @@ -146,6 +169,8 @@ void maskedL2NN(raft::handle_t& handle, params.initOutBuffer); } +/** @} */ + } // namespace distance } // namespace raft diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst index eb9bc6255d..1632f19fba 100644 --- a/docs/source/cpp_api/distance.rst +++ b/docs/source/cpp_api/distance.rst @@ -25,3 +25,4 @@ namespace *raft::distance* distance_pairwise.rst distance_1nn.rst + distance_masked_nn.rst diff --git a/docs/source/cpp_api/distance_masked_nn.rst b/docs/source/cpp_api/distance_masked_nn.rst new file mode 100644 index 0000000000..89e23ba98a --- /dev/null +++ b/docs/source/cpp_api/distance_masked_nn.rst @@ -0,0 +1,16 @@ +Masked 1-Nearest Neighbors +========================== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::distance* + +.. 
doxygengroup:: masked_nn + :project: RAFT + :members: + :content-only: + From 529764e6f925a21e59cd055c3f96ffa481e88f1c Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 11:03:28 +0100 Subject: [PATCH 20/49] Reduce iterations of deterministic test As discussed here: https://github.com/rapidsai/raft/pull/838#discussion_r1085052571 --- cpp/test/distance/fused_l2_nn.cu | 2 +- cpp/test/distance/masked_nn.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 54de12307a..f3a2442757 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -385,7 +385,7 @@ class FusedL2NNDetTest : public FusedL2NNTest { rmm::device_uvector> min1; - static const int NumRepeats = 100; + static const int NumRepeats = 3; void generateGoldenResult() override {} }; diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index 4d25d2dde2..3fd794c07b 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -450,7 +450,7 @@ class MaskedL2NNDetTest : public MaskedL2NNTest { rmm::device_uvector> min1; - static const int NumRepeats = 100; + static const int NumRepeats = 3; void generateGoldenResult() override {} }; From f4db5e52551be76be14f0c843b4f811e2ff6c036 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 21:30:55 +0100 Subject: [PATCH 21/49] Add SDDMM comparison --- cpp/include/raft/distance/masked_nn.cuh | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index 9f833e54e0..2ae454ffd1 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -99,6 +99,19 @@ struct MaskedL2NNParams { * evenly divisible by `N`, then the computation is most efficient, although for * larger group sizes this effect is minor. 
 * + + * **Comparison to SDDMM** + * + * [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense + * matrix multiplication) is a matrix-matrix multiplication where only part of + * the output is computed. Compared to maskedL2NN, there are a few differences: + * + * - The output of maskedL2NN is a single vector (of nearest neighbors) and not + * a sparse matrix. + * + * - The sampling in maskedL2NN is expressed through intermediate "groups" + * rather than a CSR format. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one From b0954e8550d7f31001000c196247bb57f302f8fa Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 21:33:27 +0100 Subject: [PATCH 22/49] Fix const and add extents checks --- .../raft/distance/detail/masked_nn.cuh | 2 +- cpp/include/raft/distance/masked_nn.cuh | 33 +++++++++++-------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 7079774226..a5985ed40f 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -184,7 +184,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, } template -void maskedL2NNImpl(raft::handle_t& handle, +void maskedL2NNImpl(const raft::handle_t& handle, OutT* min, const DataT* x, const DataT* y, diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index 2ae454ffd1..4734977943 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -143,27 +143,34 @@ struct MaskedL2NNParams { * (on device) */ template -void maskedL2NN(raft::handle_t& handle, +void maskedL2NN(const raft::handle_t& handle, raft::distance::MaskedL2NNParams params, - raft::device_matrix_view const x, - raft::device_matrix_view 
const y, - raft::device_vector_view const x_norm, - raft::device_vector_view const y_norm, - raft::device_matrix_view const adj, - raft::device_vector_view const group_idxs, + raft::device_matrix_view x, + raft::device_matrix_view y, + raft::device_vector_view x_norm, + raft::device_vector_view y_norm, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs, raft::device_vector_view out) { - // TODO: add more assertions. - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - IdxT m = x.extent(0); IdxT n = y.extent(0); IdxT k = x.extent(1); IdxT num_groups = group_idxs.extent(0); + // Match k dimension of x, y + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); + // Match x, x_norm and y, y_norm + RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`."); + RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y` "); + // Match adj to x and group_idxs + RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`."); + RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`."); + // NOTE: We do not check if all indices in group_idxs actually points *inside* y. + + // If there is no work to be done, return immediately. + if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; } + detail::maskedL2NNImpl(handle, out.data_handle(), x.data_handle(), From 2e3ca4439ea44f900c6c54d51b3af57fe28cafb4 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 21:35:42 +0100 Subject: [PATCH 23/49] Refactor maskedL2NN test - Reuse raft::distance::KVPMinReduce operator (and remove it from the test file). - Remove sqrt template parameter. - Replace test fixture with free-standing functions for greater reusability. 
- Use itertools to generate cartesian product of test cases. - Minimized the determinism test (only for float and only compares 2 runs, not 100) --- .../raft/distance/detail/fused_l2_nn.cuh | 1 + cpp/test/distance/masked_nn.cu | 614 ++++++++---------- 2 files changed, 277 insertions(+), 338 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 447359ffe6..8fbd7a9c69 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -37,6 +37,7 @@ template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } }; // KVPMinReduce diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index 3fd794c07b..461eafce3d 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -16,6 +16,8 @@ #include "../test_utils.h" #include +#include +#include #include #include #include @@ -24,34 +26,70 @@ #include #include #include +#include -#include -#include - -namespace raft::distance::masked_l2_nn { +namespace raft::distance::masked_nn { -template -struct RaftKVPMinReduce { - typedef raft::KeyValuePair KVP; +// The adjacency pattern determines what distances get computed. +enum AdjacencyPattern { + checkerboard = 0, // adjacency matrix looks like a checkerboard (half the distances are computed) + checkerboard_4 = 1, // checkerboard with tiles of size 4x4 + checkerboard_64 = 2, // checkerboard with tiles of size 64x64 + all_true = 3, // no distance computations can be skipped + all_false = 4 // all distance computations can be skipped +}; - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? 
b : a; } +// Kernels: +// - init_adj: to initialize the adjacency kernel with a specific adjacency pattern +// - referenceKernel: to produce the ground-truth output - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } +__global__ void init_adj( + int m, int n, int num_groups, AdjacencyPattern pattern, bool* adj, int* group_idxs) +{ + for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_groups; i += blockDim.y * gridDim.y) { + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < m; j += blockDim.x * gridDim.x) { + switch (pattern) { + case checkerboard: adj[i * m + j] = (i + j) % 2; break; + case checkerboard_4: adj[i * m + j] = (i + (j / 4)) % 2; break; + case checkerboard_64: adj[i * m + j] = (i + (j / 64)) % 2; break; + case all_true: adj[i * m + j] = true; break; + case all_false: adj[i * m + j] = false; break; + default: assert(false && "unknown pattern"); + } + } + } + // Each group is of size n / num_groups. + // + // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive + // scan of the group lengths) + // + // - The first group always starts at index zero, so we do not store it. + // + // - The group_idxs[num_groups - 1] should always equal n. 
-}; // KVPMinReduce + if (blockIdx.y == 0 && threadIdx.y == 0) { + const int j_stride = blockDim.x * gridDim.x; + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_groups; j += j_stride) { + group_idxs[j] = (j + 1) * (n / num_groups); + } + group_idxs[num_groups - 1] = n; + } +} -template -__global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(raft::KeyValuePair* min, - DataT* x, - DataT* y, - bool* adj, - int* group_idxs, - int m, - int n, - int k, - int num_groups, - int* workspace, - DataT maxVal) +template +__global__ __launch_bounds__(32 * NWARPS, + 2) void referenceKernel(raft::KeyValuePair* min, + DataT* x, + DataT* y, + bool* adj, + int* group_idxs, + int m, + int n, + int k, + int num_groups, + bool sqrt, + int* workspace, + DataT maxVal) { const int m_stride = blockDim.y * gridDim.y; const int m_offset = threadIdx.y + blockIdx.y * blockDim.y; @@ -78,7 +116,7 @@ __global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(raft::KeyValuePair auto diff = x[xidx] - y[yidx]; acc += diff * diff; } - if (Sqrt) { acc = raft::sqrt(acc); } + if (sqrt) { acc = raft::sqrt(acc); } ReduceOpT redOp; typedef cub::WarpReduce> WarpReduce; __shared__ typename WarpReduce::TempStorage temp[NWARPS]; @@ -86,7 +124,7 @@ __global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(raft::KeyValuePair raft::KeyValuePair tmp; tmp.key = include_dist ? nidx : -1; tmp.value = include_dist ? 
acc : maxVal; - tmp = WarpReduce(temp[warpId]).Reduce(tmp, RaftKVPMinReduce()); + tmp = WarpReduce(temp[warpId]).Reduce(tmp, raft::distance::KVPMinReduce{}); if (threadIdx.x % raft::WarpSize == 0 && midx < m) { while (atomicCAS(workspace + midx, 0, 1) == 1) ; @@ -100,216 +138,148 @@ __global__ __launch_bounds__(32 * NWARPS, 2) void naiveKernel(raft::KeyValuePair } } -template -void naive(raft::KeyValuePair* min, - DataT* x, - DataT* y, - bool* adj, - int* group_idxs, - int m, - int n, - int k, - int num_groups, - int* workspace, - cudaStream_t stream) -{ - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - auto blks = raft::ceildiv(m, 256); - MinAndDistanceReduceOp op; - raft::distance::detail::initKernel, int> - <<>>(min, m, std::numeric_limits::max(), op); - RAFT_CUDA_TRY(cudaGetLastError()); +// Structs +// - Params: holds parameters for test case +// - Inputs: holds the inputs to the functions under test (x, y, adj, group_idxs). Is generated from +// the inputs. 
+struct Params { + double tolerance; + int m, n, k, num_groups; + bool sqrt; + unsigned long long int seed; + AdjacencyPattern pattern; +}; - const int nwarps = 16; - static const dim3 TPB(32, nwarps, 1); - dim3 nblks(1, 200, 1); - naiveKernel, nwarps><<>>( - min, x, y, adj, group_idxs, m, n, k, num_groups, workspace, std::numeric_limits::max()); - RAFT_CUDA_TRY(cudaGetLastError()); +template +inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& +{ + os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", num_groups: " << p.num_groups + << ", sqrt: " << p.sqrt << ", seed: " << p.seed << ", tol: " << p.tolerance; + return os; } -enum AdjacencyPattern { - checkerboard = 0, - checkerboard_4 = 1, - checkerboard_64 = 2, - all_true = 3, - all_false = 4 -}; - template struct Inputs { - DataT tolerance; - int m, n, k, num_groups; - unsigned long long int seed; + using IdxT = int; - AdjacencyPattern pattern; + raft::device_matrix x, y; + raft::device_matrix adj; + raft::device_vector group_idxs; - friend std::ostream& operator<<(std::ostream& os, const Inputs& p) + Inputs(const raft::handle_t& handle, const Params& p) + : x{raft::make_device_matrix(handle, p.m, p.k)}, + y{raft::make_device_matrix(handle, p.n, p.k)}, + adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, + group_idxs{raft::make_device_vector(handle, p.num_groups)} { - return os << "m: " << p.m - << ", " - "n: " - << p.n - << ", " - "k: " - << p.k - << ", " - "num_groups: " - << p.num_groups - << ", " - "seed: " - << p.seed - << ", " - "tol: " - << p.tolerance; + // Initialize x, y + raft::random::RngState r(p.seed); + uniform(handle, r, x.data_handle(), p.m * p.k, DataT(-1.0), DataT(1.0)); + uniform(handle, r, y.data_handle(), p.n * p.k, DataT(-1.0), DataT(1.0)); + + // Initialize adj, group_idxs. 
+ dim3 block(32, 32); + dim3 grid(10, 10); + init_adj<<>>( + p.m, p.n, p.num_groups, p.pattern, adj.data_handle(), group_idxs.data_handle()); + RAFT_CUDA_TRY(cudaGetLastError()); } }; -__global__ void init_adj( - int m, int n, int num_groups, AdjacencyPattern pattern, bool* adj, int* group_idxs) +template > +auto reference(const raft::handle_t& handle, Inputs inp, const Params& p) + -> raft::device_vector { - for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_groups; i += blockDim.y * gridDim.y) { - for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < m; j += blockDim.x * gridDim.x) { - switch (pattern) { - case checkerboard: adj[i * m + j] = (i + j) % 2; break; - case checkerboard_4: adj[i * m + j] = (i + (j / 4)) % 2; break; - case checkerboard_64: adj[i * m + j] = (i + (j / 64)) % 2; break; - case all_true: adj[i * m + j] = true; break; - case all_false: adj[i * m + j] = false; break; - default: assert(false && "unknown pattern"); - } - } - } - // Each group is of size n / num_groups. - // - // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive - // scan of the group lengths) - // - // - The first group always starts at index zero, so we do not store it. - // - // - The group_idxs[num_groups - 1] should always equal n. 
+ int m = inp.x.extent(0); + int n = inp.y.extent(0); + int k = inp.x.extent(1); + int num_groups = inp.group_idxs.extent(0); - if (blockIdx.y == 0 && threadIdx.y == 0) { - const int j_stride = blockDim.x * gridDim.x; - for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_groups; j += j_stride) { - group_idxs[j] = (j + 1) * (n / num_groups); - } - group_idxs[num_groups - 1] = n; + if (m == 0 || n == 0 || k == 0 || num_groups == 0) { + return raft::make_device_vector(handle, 0); } -} -template -class MaskedL2NNTest : public ::testing::TestWithParam> { - public: - MaskedL2NNTest() - : params(::testing::TestWithParam>::GetParam()), - stream(handle.get_stream()), - x(params.m * params.k, stream), - y(params.n * params.k, stream), - adj(params.m * params.num_groups, stream), - group_idxs(params.num_groups, stream), - xn(params.m, stream), - yn(params.n, stream), - min(params.m, stream), - min_ref(params.m, stream), - workspace(params.m * sizeof(int), stream) - { - } + // Initialize workspace + auto stream = handle.get_stream(); + rmm::device_uvector workspace(p.m * sizeof(int), stream); + RAFT_CUDA_TRY(cudaMemsetAsync(workspace.data(), 0, sizeof(int) * m, stream)); - protected: - void SetUp() override - { - raft::random::RngState r(params.seed); - int m = params.m; - int n = params.n; - int k = params.k; - int num_groups = params.num_groups; - uniform(handle, r, x.data(), m * k, DataT(-1.0), DataT(1.0)); - uniform(handle, r, y.data(), n * k, DataT(-1.0), DataT(1.0)); - - dim3 block(32, 32); - dim3 grid(10, 10); - init_adj<<>>( - m, n, num_groups, params.pattern, adj.data(), group_idxs.data()); - RAFT_CUDA_TRY(cudaGetLastError()); + // Initialize output + auto out = raft::make_device_vector(handle, m); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + raft::distance::detail::initKernel, int> + <<>>(out.data_handle(), m, std::numeric_limits::max(), op); + RAFT_CUDA_TRY(cudaGetLastError()); - generateGoldenResult(); - 
raft::linalg::rowNorm(xn.data(), x.data(), k, m, raft::linalg::L2Norm, true, stream); - raft::linalg::rowNorm(yn.data(), y.data(), k, n, raft::linalg::L2Norm, true, stream); - handle.sync_stream(stream); - } + // Launch reference kernel + const int nwarps = 16; + static const dim3 TPB(32, nwarps, 1); + dim3 nblks(1, 200, 1); + referenceKernel + <<>>(out.data_handle(), + inp.x.data_handle(), + inp.y.data_handle(), + inp.adj.data_handle(), + inp.group_idxs.data_handle(), + m, + n, + k, + num_groups, + p.sqrt, + (int*)workspace.data(), + std::numeric_limits::max()); + RAFT_CUDA_TRY(cudaGetLastError()); - protected: - Inputs params; - rmm::device_uvector x; - rmm::device_uvector y; - rmm::device_uvector adj; - rmm::device_uvector group_idxs; - rmm::device_uvector xn; - rmm::device_uvector yn; - rmm::device_uvector> min; - rmm::device_uvector> min_ref; - rmm::device_uvector workspace; - raft::handle_t handle; - cudaStream_t stream; - - virtual void generateGoldenResult() - { - int m = params.m; - int n = params.n; - int k = params.k; - int num_groups = params.num_groups; - - naive(min_ref.data(), - x.data(), - y.data(), - adj.data(), - group_idxs.data(), - m, - n, - k, - num_groups, - (int*)workspace.data(), - stream); - } + return out; +} - void runTest(raft::KeyValuePair* out) - { - using IdxT = int; - using OutT = raft::KeyValuePair; - using RedOpT = MinAndDistanceReduceOp; - using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = MaskedL2NNParams; - - bool init_out = true; - ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, Sqrt, init_out}; - - int m = params.m; - int n = params.n; - int k = params.k; - int num_groups = params.num_groups; - - auto x_view = raft::make_device_matrix_view(x.data(), m, k); - auto y_view = raft::make_device_matrix_view(y.data(), n, k); - auto x_norm = raft::make_device_vector_view(xn.data(), m); - auto y_norm = raft::make_device_vector_view(yn.data(), n); - auto adj_view = raft::make_device_matrix_view(adj.data(), m, num_groups); 
- auto group_idxs_view = raft::make_device_vector_view(group_idxs.data(), num_groups); - auto out_view = raft::make_device_vector_view(out, m); - - maskedL2NN(handle, - masked_l2_params, - x_view, - y_view, - x_norm, - y_norm, - adj_view, - group_idxs_view, - out_view); - - handle.sync_stream(stream); - } -}; +template > +auto run_masked_nn(const raft::handle_t& handle, Inputs inp, const Params& p) + -> raft::device_vector +{ + // Compute norms: + auto x_norm = raft::make_device_vector(handle, p.m); + auto y_norm = raft::make_device_vector(handle, p.n); + + raft::linalg::norm(handle, + std::as_const(inp.x).view(), + x_norm.view(), + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + raft::linalg::norm(handle, + std::as_const(inp.y).view(), + y_norm.view(), + raft::linalg::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + + // Create parameters for maskedL2NN + using IdxT = int; + using RedOpT = MinAndDistanceReduceOp; + using PairRedOpT = raft::distance::KVPMinReduce; + using ParamT = raft::distance::MaskedL2NNParams; + + bool init_out = true; + ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, p.sqrt, init_out}; + + // Create output + auto out = raft::make_device_vector(handle, p.m); + + // Launch kernel + raft::distance::maskedL2NN(handle, + masked_l2_params, + inp.x.view(), + inp.y.view(), + x_norm.view(), + y_norm.view(), + inp.adj.view(), + inp.group_idxs.view(), + out.view()); + + handle.sync_stream(); + + return out; +} template struct CompareApproxAbsKVP { @@ -362,139 +332,107 @@ template return ::testing::AssertionSuccess(); } -const std::vector> inputsf = { - {0.001f, 32, 32, 32, 2, 1234ULL, AdjacencyPattern::all_true}, - {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_true}, - {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_false}, - {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, - {0.001f, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_4}, - {0.001f, 512, 512, 8, 32, 1234ULL, 
AdjacencyPattern::checkerboard_64}, - {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_true}, - {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_false}, - {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard}, - {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_4}, - {0.001f, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_64}, - {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::all_true}, - {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::all_false}, - {0.001f, (1 << 15) + 19, (1 << 9) + 17, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, -}; - -typedef MaskedL2NNTest MaskedL2NNTestF_Sq; -TEST_P(MaskedL2NNTestF_Sq, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestF_Sq, ::testing::ValuesIn(inputsf)); -typedef MaskedL2NNTest MaskedL2NNTestF_Sqrt; -TEST_P(MaskedL2NNTestF_Sqrt, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestF_Sqrt, ::testing::ValuesIn(inputsf)); - -const std::vector> inputsd = { - {0.00001, 32, 32, 32, 2, 1234ULL, AdjacencyPattern::all_true}, - - {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_true}, - {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::all_false}, - {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard}, - {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_4}, - {0.00001, 512, 512, 8, 32, 1234ULL, AdjacencyPattern::checkerboard_64}, - - {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_true}, - {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::all_false}, - {0.00001, 1 << 9, 1 << 16, 8, 1 
<< 9, 1234ULL, AdjacencyPattern::checkerboard}, - {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_4}, - {0.00001, 1 << 9, 1 << 16, 8, 1 << 9, 1234ULL, AdjacencyPattern::checkerboard_64}, -}; -typedef MaskedL2NNTest MaskedL2NNTestD_Sq; -TEST_P(MaskedL2NNTestD_Sq, Result) +inline auto gen_params() -> std::vector { - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); + // Regular powers of two + auto regular = raft::util::itertools::product({0.001f}, // tolerance + {32, 64, 512}, // m + {32, 64, 512}, // n + {8, 32}, // k + {2, 32}, // num_groups + {true, false}, // sqrt + {1234ULL}, // seed + {AdjacencyPattern::all_true, + AdjacencyPattern::checkerboard, + AdjacencyPattern::checkerboard_64, + AdjacencyPattern::all_false}); + + // Irregular sizes to check tiling and bounds checking + auto irregular = raft::util::itertools::product({0.001f}, // tolerance + {511, 512, 513}, // m + {127, 128, 129}, // n + {5}, // k + {3, 9}, // num_groups + {true, false}, // sqrt + {1234ULL}, // seed + {AdjacencyPattern::all_true, + AdjacencyPattern::checkerboard, + AdjacencyPattern::checkerboard_64}); + + regular.insert(regular.end(), irregular.begin(), irregular.end()); + + return regular; } -INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestD_Sq, ::testing::ValuesIn(inputsd)); -typedef MaskedL2NNTest MaskedL2NNTestD_Sqrt; -TEST_P(MaskedL2NNTestD_Sqrt, Result) -{ - runTest(min.data()); - ASSERT_TRUE(devArrMatch( - min_ref.data(), min.data(), params.m, CompareApproxAbsKVP(params.tolerance), stream)); -} -INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestD_Sqrt, ::testing::ValuesIn(inputsd)); - -/// This is to test output determinism of the prim -template -class MaskedL2NNDetTest : public MaskedL2NNTest { - public: - MaskedL2NNDetTest() : stream(handle.get_stream()), min1(0, stream) {} - - void SetUp() override - { - MaskedL2NNTest::SetUp(); - int m = 
this->params.m; - min1.resize(m, stream); - handle.sync_stream(stream); - } - - void TearDown() override { MaskedL2NNTest::TearDown(); } - - protected: - raft::handle_t handle; - cudaStream_t stream; - - rmm::device_uvector> min1; - static const int NumRepeats = 3; - - void generateGoldenResult() override {} +class MaskedL2NNTestMin : public ::testing::TestWithParam { + // Empty. }; -typedef MaskedL2NNDetTest MaskedL2NNDetTestF_Sq; -TEST_P(MaskedL2NNDetTestF_Sq, Result) -{ - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } -} -INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestF_Sq, ::testing::ValuesIn(inputsf)); -typedef MaskedL2NNDetTest MaskedL2NNDetTestF_Sqrt; -TEST_P(MaskedL2NNDetTestF_Sqrt, Result) +// +TEST_P(MaskedL2NNTestMin, Float) { - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } + using DataT = float; + + // Get parameters; create handle and input data. + Params p = GetParam(); + raft::handle_t handle{}; + Inputs inputs{handle, p}; + + // Calculate reference and test output + auto out_reference = reference(handle, inputs, p); + auto out_fast = run_masked_nn(handle, inputs, p); + + // Check for differences. + ASSERT_TRUE(devArrMatch(out_reference.data_handle(), + out_fast.data_handle(), + p.m, + CompareApproxAbsKVP(p.tolerance), + handle.get_stream())); } -INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestF_Sqrt, ::testing::ValuesIn(inputsf)); -typedef MaskedL2NNDetTest MaskedL2NNDetTestD_Sq; -TEST_P(MaskedL2NNDetTestD_Sq, Result) +// This test checks whether running the maskedL2NN twice returns the same +// output. 
+TEST_P(MaskedL2NNTestMin, Determinism) { - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } + using DataT = float; + + // Get parameters; create handle and input data. + Params p = GetParam(); + raft::handle_t handle{}; + Inputs inputs{handle, p}; + + // Calculate reference and test output + auto out1 = run_masked_nn(handle, inputs, p); + auto out2 = run_masked_nn(handle, inputs, p); + + // Check for differences. + ASSERT_TRUE(devArrMatch(out1.data_handle(), + out2.data_handle(), + p.m, + CompareApproxAbsKVP(p.tolerance), + handle.get_stream())); } -INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestD_Sq, ::testing::ValuesIn(inputsd)); -typedef MaskedL2NNDetTest MaskedL2NNDetTestD_Sqrt; -TEST_P(MaskedL2NNDetTestD_Sqrt, Result) + +TEST_P(MaskedL2NNTestMin, Double) { - runTest(min.data()); // assumed to be golden - for (int i = 0; i < NumRepeats; ++i) { - runTest(min1.data()); - ASSERT_TRUE(devArrMatch(min.data(), min1.data(), params.m, CompareExactKVP(), stream)); - } + using DataT = double; + + // Get parameters; create handle and input data. + Params p = GetParam(); + raft::handle_t handle{}; + Inputs inputs{handle, p}; + + // Calculate reference and test output + auto out_reference = reference(handle, inputs, p); + auto out_fast = run_masked_nn(handle, inputs, p); + + // Check for differences. 
+ ASSERT_TRUE(devArrMatch(out_reference.data_handle(), + out_fast.data_handle(), + p.m, + CompareApproxAbsKVP(p.tolerance), + handle.get_stream())); } -INSTANTIATE_TEST_CASE_P(MaskedL2NNDetTests, MaskedL2NNDetTestD_Sqrt, ::testing::ValuesIn(inputsd)); +INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestMin, ::testing::ValuesIn(gen_params())); -} // end namespace raft::distance::masked_l2_nn +} // end namespace raft::distance::masked_nn From db6288c86b1efd933e7f83d82432a28c199cd172 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 21:45:43 +0100 Subject: [PATCH 24/49] wording: grouped -> processed --- cpp/include/raft/distance/masked_nn.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index 4734977943..1f47c935f2 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -87,7 +87,7 @@ struct MaskedL2NNParams { * * **Performance considerations** * - * The points in `x` are grouped into tiles of `M` points (`M` is currently 64, + * The points in `x` are processed in tiles of `M` points (`M` is currently 64, * but may change in the future). As a result, the largest compute time * reduction occurs if all `M` points can skip a group. 
If only part of the `M` * points can skip a group, then at most a minor compute time reduction and a From 773ab2ff987ab6943c9e550332ae06c2e7f3702b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 22:39:17 +0100 Subject: [PATCH 25/49] Docstring changes and removal of extraneous function --- .../distance/detail/masked_distance_base.cuh | 127 +++++------------- .../raft/distance/detail/masked_nn.cuh | 1 - 2 files changed, 31 insertions(+), 97 deletions(-) diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index 661bf5a86b..eaaeab99aa 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,12 +25,12 @@ namespace distance { namespace detail { /** - * @brief Device class for L1, L2 and cosine distance metrics. + * @brief Device class for masked nearest neighbor computations. + * * @tparam useNorms whether norms are needed - * @tparam DataT input data-type (for A and B matrices) + * @tparam DataT input data-type (for x and y matrices) * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type + * @tparam IdxT index data-type * @tparam Policy struct which tunes the Contraction kernel * @tparam CoreLambda tells how to accumulate an x and y into acc. its signature: @@ -41,27 +41,39 @@ namespace detail { template void epilogue_lambda (AccT acc[][], DataT* regxn, DataT* regyn); * @tparam FinalLambda the final lambda called on final distance value + * @tparam rowEpilogueLambda epilog lambda that executes when a full row has + * been processed. 
+ * * @param[in] x input matrix * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D + * @param[in] m number of rows of x + * @param[in] n number of columns of y + * @param[in] k number of cols of x and y + * @param[in] lda leading dimension of x + * @param[in] ldb leading dimension of y + * @param[in] ldd parameter to keep Contractions_NT happy.. * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine - * @param[output] pD output matrix - * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. + * @param[in] adj A boolean adjacency matrix indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `m x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups The number of groups in group_idxs. + * @param[in] smem shared mem buffer for intermediate storage of x, y, xn & yn. * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed. 
*/ - template switch_read_buffer(); if (useNorms) { @@ -196,7 +207,7 @@ struct MaskedDistances : public BaseClass { } // tile_idx_n } // idx_g rowEpilog_op(tile_idx_m); - } // tile_idx_n + } // tile_idx_m } private: @@ -280,82 +291,6 @@ struct MaskedDistances : public BaseClass { } }; // struct MaskedDistances -/** - * @brief the distance matrix calculation kernel for L1, L2 and cosine - * @tparam useNorms whether norms are needed - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda lambda which implements accumulation operation - * @tparam EpilogueLambda lambda which implements operation for calculating - final value. - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major(default), - false for column major - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] xn row norms of input matrix A. - * @param[in] yn row norms of input matrix B. 
- * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix - * @param core_op the core lambda - * @param epilog_op the epilogue lambda - * @param fin_op the final gemm epilogue lambda - */ - -template -__global__ __launch_bounds__(Policy::Nthreads, 2) - - void maskedDistanceMatKernel(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - const bool* adj, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - CoreLambda core_op, - EpilogueLambda epilog_op, - FinalLambda fin_op) -{ - extern __shared__ char smem[]; - auto rowEpilog = [] __device__(IdxT starty) { return; }; - - MaskedDistances - obj(x, y, m, n, k, lda, ldb, ldd, _xn, _yn, smem, core_op, epilog_op, fin_op, rowEpilog); - obj.run(); -} - }; // namespace detail }; // namespace distance }; // namespace raft diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index a5985ed40f..0f8e3039aa 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -152,7 +152,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, IdxT lda = k, ldb = k, ldd = n; MaskedDistances Date: Tue, 24 Jan 2023 22:49:49 +0100 Subject: [PATCH 26/49] Move sqrt from template to run-time parameter There is barely any impact on performance (less than 0.1%) by having sqrt be a run-time parameter. It does speed up compilation time though. 
--- .../raft/distance/detail/masked_nn.cuh | 91 ++++++------------- 1 file changed, 30 insertions(+), 61 deletions(-) diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 0f8e3039aa..e7ac861e76 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -38,7 +38,6 @@ using namespace nvcuda::experimental; template ; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, maskedL2NNSqrt); - - maskedL2NNSqrt<<>>(min, - x, - y, - xn, - yn, - ws_adj64.data(), - group_idxs, - num_groups, - m, - n, - k, - maxVal, - ws_fused_nn.data(), - redOp, - pairRedOp, - core_lambda, - fin_op); - } else { - auto maskedL2NN = maskedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, maskedL2NN); - maskedL2NN<<>>(min, - x, - y, - xn, - yn, - ws_adj64.data(), - group_idxs, - num_groups, - m, - n, - k, - maxVal, - ws_fused_nn.data(), - redOp, - pairRedOp, - core_lambda, - fin_op); - } + auto maskedL2NN = maskedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, maskedL2NN); + maskedL2NN<<>>(min, + x, + y, + xn, + yn, + ws_adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + sqrt, + maxVal, + ws_fused_nn.data(), + redOp, + pairRedOp, + core_lambda, + fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } From 0629446dd348570ecced01ee790d12a11a92eb4c Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 23:14:58 +0100 Subject: [PATCH 27/49] Remove extraneous memcpy::async stuff --- cpp/include/raft/distance/detail/masked_nn.cuh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index e7ac861e76..fe4c8116d1 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -30,11 +30,6 @@ namespace raft { namespace distance { namespace detail { -#if (ENABLE_MEMCPY_ASYNC == 1) -#include -using namespace nvcuda::experimental; -#endif - template Date: Tue, 24 Jan 2023 23:15:55 +0100 Subject: [PATCH 28/49] Implement half the reviewer feedback on mnn_base --- cpp/bench/distance/masked_nn.cu | 3 ++- .../raft/distance/detail/masked_distance_base.cuh | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cpp/bench/distance/masked_nn.cu b/cpp/bench/distance/masked_nn.cu index 073b672e59..3b163e95de 100644 --- a/cpp/bench/distance/masked_nn.cu +++ b/cpp/bench/distance/masked_nn.cu @@ -193,6 +193,7 @@ const std::vector masked_l2_nn_input_vecs = { }; RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); -// Do not benchmark double. +// We don't benchmark double to keep compile times in check when not using the +// distance library. 
} // namespace raft::bench::distance::masked_nn diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index eaaeab99aa..fc1a374765 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -54,9 +54,9 @@ namespace detail { * @param[in] ldd parameter to keep Contractions_NT happy.. * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine - * @param[in] adj A boolean adjacency matrix indicating for each + * @param[in] adj An adjacency matrix encoded as a bitfield indicating for each * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. + * distance. Dim = `(m / 64) x num_groups`. * @param[in] group_idxs An array containing the *end* indices of each group * in `y`. The value of group_idxs[j] indicates the * start of group j + 1, i.e., it is the inclusive @@ -145,10 +145,6 @@ struct MaskedDistances : public BaseClass { for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { // Start loop over groups for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { - // The __syncthreads() ensures that loading the block flag occurs at - // the same time in all threads of the block. Since all threads load - // the same address, this speeds up the code. - __syncthreads(); const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); // block_adj is a bitfield that contains a 1 if a row is adjacent to the // current group. All zero means we can skip this group. 
@@ -213,7 +209,11 @@ struct MaskedDistances : public BaseClass { private: DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) { + // A single element of `adj` contains exactly enough bits to indicate which + // rows in the current tile to skip and which to compute. + static_assert(P::Mblk == 8 * sizeof(adj[0]), "maskedL2NN only supports a policy with 64 rows per block."); IdxT block_flag_idx = tile_idx_m / P::Mblk; + // Index into adj at row tile_idx_m / 64 and column idx_group. return adj[block_flag_idx * this->num_groups + idx_group]; } From 6243c8afe83f28193942bc826dd9642718b9aeb7 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 23:20:23 +0100 Subject: [PATCH 29/49] Formatting --- cpp/include/raft/distance/detail/masked_distance_base.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index fc1a374765..c6844bec67 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -211,7 +211,8 @@ struct MaskedDistances : public BaseClass { { // A single element of `adj` contains exactly enough bits to indicate which // rows in the current tile to skip and which to compute. - static_assert(P::Mblk == 8 * sizeof(adj[0]), "maskedL2NN only supports a policy with 64 rows per block."); + static_assert(P::Mblk == 8 * sizeof(adj[0]), + "maskedL2NN only supports a policy with 64 rows per block."); IdxT block_flag_idx = tile_idx_m / P::Mblk; // Index into adj at row tile_idx_m / 64 and column idx_group. return adj[block_flag_idx * this->num_groups + idx_group]; From 4dc0c370f35a3dd440659d3083f778c2a281e1f8 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 24 Jan 2023 23:23:23 +0100 Subject: [PATCH 30/49] Fix copyright years These files are released for the first time in 2023. 
Copyright header has been adjusted accordingly. --- cpp/bench/distance/masked_nn.cu | 2 +- cpp/include/raft/distance/detail/compress_to_bits.cuh | 2 +- cpp/include/raft/distance/detail/masked_nn.cuh | 2 +- cpp/include/raft/distance/masked_nn.cuh | 2 +- cpp/test/distance/masked_nn.cu | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/bench/distance/masked_nn.cu b/cpp/bench/distance/masked_nn.cu index 3b163e95de..cef91f8daf 100644 --- a/cpp/bench/distance/masked_nn.cu +++ b/cpp/bench/distance/masked_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh index 444c7b005e..aa4bc27d39 100644 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index fe4c8116d1..c139f97e59 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index 1f47c935f2..9f43870727 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index 461eafce3d..0bcd7d777c 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From b84a5b4809e852d93c1cff2633f07616dfb13228 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 09:22:58 +0100 Subject: [PATCH 31/49] Reword comment --- cpp/include/raft/distance/masked_nn.cuh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index 9f43870727..f5b1839d1b 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -81,9 +81,11 @@ struct MaskedL2NNParams { * This function enables faster computation of nearest neighbors if the * computation of distances between certain point pairs can be skipped. * - * To avoid using a full adjacency matrix between all points in `x` and `y`, the - * points in `y` are divided into groups. An adjacency matrix describes for each - * point in `x` and each group whether to compute the distance. + * We use an adjacency matrix that describes which distances to calculate. 
The + * points in `y` are divided into groups, and the adjacency matrix indicates + * whether to compute distances between points in `x` and groups in `y`. In other + * words, if `adj[i,k]` is true then distance between point `x_i`, and points in + * `group_k` will be calculated. * * **Performance considerations** * From 7fb7cc9bb925a8a485ed4f43a8a1ed5b4d43bc52 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 10:00:40 +0100 Subject: [PATCH 32/49] test: Remove redundant comparison operator --- cpp/test/distance/masked_nn.cu | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index 0bcd7d777c..ee292bf855 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -297,16 +297,6 @@ struct CompareApproxAbsKVP { T eps; }; -template -struct CompareExactKVP { - typedef typename raft::KeyValuePair KVP; - bool operator()(const KVP& a, const KVP& b) const - { - if (a.value != b.value) return false; - return true; - } -}; - template ::testing::AssertionResult devArrMatch(const raft::KeyValuePair* expected, const raft::KeyValuePair* actual, @@ -364,12 +354,12 @@ inline auto gen_params() -> std::vector return regular; } -class MaskedL2NNTestMin : public ::testing::TestWithParam { +class MaskedL2NNTest : public ::testing::TestWithParam { // Empty. }; // -TEST_P(MaskedL2NNTestMin, Float) +TEST_P(MaskedL2NNTest, ReferenceCheckFloat) { using DataT = float; @@ -392,7 +382,7 @@ TEST_P(MaskedL2NNTestMin, Float) // This test checks whether running the maskedL2NN twice returns the same // output. 
-TEST_P(MaskedL2NNTestMin, Determinism) +TEST_P(MaskedL2NNTest, DeterminismCheck) { using DataT = float; @@ -413,7 +403,7 @@ TEST_P(MaskedL2NNTestMin, Determinism) handle.get_stream())); } -TEST_P(MaskedL2NNTestMin, Double) +TEST_P(MaskedL2NNTest, ReferenceCheckDouble) { using DataT = double; @@ -433,6 +423,7 @@ TEST_P(MaskedL2NNTestMin, Double) CompareApproxAbsKVP(p.tolerance), handle.get_stream())); } -INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTestMin, ::testing::ValuesIn(gen_params())); + +INSTANTIATE_TEST_CASE_P(MaskedL2NNTests, MaskedL2NNTest, ::testing::ValuesIn(gen_params())); } // end namespace raft::distance::masked_nn From fe31a132a81afbf7a7e272094c4be47b927581ab Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 11:16:57 +0100 Subject: [PATCH 33/49] Add sphinx docsfor maskedL2NNImpl --- .../raft/distance/detail/masked_nn.cuh | 112 +++++++++++++----- cpp/include/raft/distance/masked_nn.cuh | 4 +- 2 files changed, 84 insertions(+), 32 deletions(-) diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index c139f97e59..5171ffc686 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -177,9 +177,59 @@ __global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, obj.run(); } + +/** + * @brief Wrapper for maskedL2NNkernel + * + * Responsibilities: + * - Allocate (and initialize) workspace memory for: + * - mutexes used in nearest neighbor update step + * - adjacency matrix bitfield + * - Compress adjacency matrix to bitfield + * - Initialize output buffer (conditional on `initOutBuffer`) + * - Specify core and final operations for the L2 norm + * - Determine optimal launch configuration for kernel. + * - Launch kernel and check for errors. + * + * @tparam DataT Input data-type (for x and y matrices). + * @tparam OutT Output data-type (for key-value pairs). + * @tparam IdxT Index data-type. 
+ * @tparam ReduceOpT A struct to perform the final needed reduction + * operation and also to initialize the output array + * elements with the appropriate initial value needed for + * reduction. + * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. + * + * @param handle RAFT handle for managing expensive resources + * @param[out] out Will contain reduced output (nn key-value pairs) + * @param[in] x First matrix. Row major. Dim = `m x k`. (on device) + * @param[in] y Second matrix. Row major. Dim = `n x k`. (on device) + * @param[in] xn L2 squared norm of `x`. Length = `m`. + * @param[in] yn L2 squared norm of `y`. Length = `n`. + * @param[in] adj A boolean adjacency matrix indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `m x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups Length of `group_idxs`. + * @param m Rows of `x`. + * @param n Rows of `y`. + * @param k Cols of `x` and `y`. + * @param redOp Reduction operator in the epilogue + * @param pairRedOp Reduction operation on key value pairs + * @param sqrt Whether to compute the squared or actual (i.e. sqrt) L2 norm. 
+ * @param initOutBuffer Whether to initialize the output buffer + * + * + */ template void maskedL2NNImpl(const raft::handle_t& handle, - OutT* min, + OutT* out, const DataT* x, const DataT* y, const DataT* xn, @@ -218,24 +268,23 @@ void maskedL2NNImpl(const raft::handle_t& handle, compress_to_bits_naive<<>>( adj, num_groups, m, ws_adj64.data()); - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); + // Initialize output buffer with keyvalue pairs as determined by the reduction + // operator (it will be called with maxVal). constexpr auto maxVal = std::numeric_limits::max(); - typedef raft::KeyValuePair KVPair; - - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - if (initOutBuffer) { + dim3 grid(raft::ceildiv(m, P::Nthreads)); + dim3 block(P::Nthreads); + initKernel - <<>>(min, m, maxVal, redOp); + <<>>(out, m, maxVal, redOp); RAFT_CUDA_TRY(cudaGetLastError()); } + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; auto fin_op = raft::identity_op{}; - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - auto maskedL2NN = maskedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, maskedL2NN); - maskedL2NN<<>>(min, - x, - y, - xn, - yn, - ws_adj64.data(), - group_idxs, - num_groups, - m, - n, - k, - sqrt, - maxVal, - ws_fused_nn.data(), - redOp, - pairRedOp, - core_lambda, - fin_op); + constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 block(P::Nthreads); + dim3 grid = launchConfigGenerator

(m, n, smemSize, kernel); + + kernel<<>>(out, + x, + y, + xn, + yn, + ws_adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + sqrt, + maxVal, + ws_fused_nn.data(), + redOp, + pairRedOp, + core_lambda, + fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index f5b1839d1b..17c11630ab 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -125,9 +125,9 @@ struct MaskedL2NNParams { * * @param handle RAFT handle for managing expensive resources * @param params Parameter struct specifying the reduction operations. - * @param[in] x first matrix. Row major. Dim = `m x k`. + * @param[in] x First matrix. Row major. Dim = `m x k`. * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. + * @param[in] y Second matrix. Row major. Dim = `n x k`. * (on device). * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). * @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device) From 3bf204f1acfdd04c1b70e08d4407b2c2e50df753 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 11:54:07 +0100 Subject: [PATCH 34/49] Document thread_adj --- .../distance/detail/masked_distance_base.cuh | 49 +++++++++++++++---- .../raft/distance/detail/masked_nn.cuh | 7 ++- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index c6844bec67..8ccda9d184 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -150,13 +150,27 @@ struct MaskedDistances : public BaseClass { // current group. All zero means we can skip this group. if (block_adj == 0) { continue; } - // Determine which results, that are computed by this thread, have to - // be taken into account. 
This information is stored in a bitfield, - // thread_adj. If all results computed by this thread can be ignored, - // then we can also skip some computations (thread_adj == 0). - + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). That is, + // for i = 0,.., AccRowsPerTh and j = 0,.., AccColsPerTh: + // + // ((1 << i) & thread_adj) > 0 <=> acc[i][j] must be computed. + // // We precompute this information because it is used in various - // locations to skip thread-local computations. + // locations to skip thread-local computations, specifically: + // + // 1. To skip computations if thread_adj == 0, i.e., none of the values + // of `acc` have to be computed. + // + // 2. In epilog_op, to consider only values of `acc` to be reduced that + // are not masked of. + // + // Note 1: Even when the computation can be skipped for a specific thread, + // the thread still participates in synchronization operations. + // + // Note 2: In theory, it should be possible to skip computations for + // specific rows of `acc`. In practice, however, this does not improve + // performance. int thread_adj = compute_thread_adjacency(block_adj); auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; @@ -220,12 +234,27 @@ struct MaskedDistances : public BaseClass { DI uint32_t compute_thread_adjacency(const uint64_t block_adj) { + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). It is described in + // more detail in the run() method. 
uint32_t thread_adj = 0; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - const uint64_t read_mask = 1ull << (this->accrowid + i * P::AccThRows); - const uint32_t write_mask = 1 << i; - if ((block_adj & read_mask) != 0) { thread_adj |= write_mask; } + for (int thread_row_idx = 0; thread_row_idx < P::AccRowsPerTh; ++thread_row_idx) { + // Index `thread_row_idx` refers to a row of the current threads' register + // tile `acc`, i.e., acc[i][:]. Index `block_row_idx` refers to the + // corresponding row of the current block tile in shared memory. + const int block_row_idx = this->accrowid + thread_row_idx * P::AccThRows; + + // block_row_is_adjacent is true if the current block_row_idx is adjacent + // to the current group. + const uint64_t block_mask = 1ull << block_row_idx; + const bool block_row_is_adjacent = (block_adj & block_mask) != 0; + if (block_row_is_adjacent) { + // If block row is adjacent, write a 1 bit to thread_adj at location + // `thread_row_idx`. + const uint32_t thread_mask = 1 << thread_row_idx; + thread_adj |= thread_mask; + } } return thread_adj; } diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 5171ffc686..96d61896ea 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -69,7 +69,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, // epilogue operation lambda for final value calculation auto epilog_lambda = [pairRedOp, &val, maxVal, sqrt] __device__( DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - int acc_adj, + int thread_adj, DataT* regxn, DataT* regyn, IdxT tile_idx_n, @@ -100,7 +100,10 @@ __global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - const bool ignore = (acc_adj & (1 << i)) == 0; + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc 
(the accumulator register tile). It is described in + // more detail in the maskedDistances.run() method. + const bool ignore = (thread_adj & (1 << i)) == 0; if (ignore) { continue; } #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { From 08448d8e7d23466e91e2acab2bb24b8ac759af38 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 12:07:32 +0100 Subject: [PATCH 35/49] Rename tile_end_n -> group_end_n I realized that I should also rename tile_end_n to end_n in the contractions base class. This lead to a simplification that made it possible to remove quite some duplication in ldgY. Specifically, we now call the ldgY method with `end_n` parameter from the ldgY method without `end_n` parameter, like so: ldgY(IdxT tile_idx_n, IdxT kidx) { ldgY(tile_idx_n, kidx, n); } --- .../distance/detail/masked_distance_base.cuh | 24 +++++------ .../raft/linalg/detail/contractions.cuh | 42 ++----------------- 2 files changed, 16 insertions(+), 50 deletions(-) diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index 8ccda9d184..6d4e3f40a6 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -173,12 +173,12 @@ struct MaskedDistances : public BaseClass { // performance. int thread_adj = compute_thread_adjacency(block_adj); - auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; - const auto tile_end_n = group_idxs[idx_g]; - for (; tile_idx_n < tile_end_n; tile_idx_n += P::Nblk) { - // We provide tile_end_n to limit the number of unnecessary data + auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; + const auto group_end_n = group_idxs[idx_g]; + for (; tile_idx_n < group_end_n; tile_idx_n += P::Nblk) { + // We provide group_end_n to limit the number of unnecessary data // points that are loaded from y. 
- this->ldgXY(tile_idx_m, tile_idx_n, 0, tile_end_n); + this->ldgXY(tile_idx_m, tile_idx_n, 0, group_end_n); reset_accumulator(); this->stsXY(); @@ -186,7 +186,7 @@ struct MaskedDistances : public BaseClass { this->switch_write_buffer(); for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx, tile_end_n); + this->ldgXY(tile_idx_m, tile_idx_n, kidx, group_end_n); // Process all data in shared memory (previous k-block) and // accumulate in registers. if (thread_adj != 0) { accumulate(); } @@ -205,13 +205,13 @@ struct MaskedDistances : public BaseClass { if (useNorms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; - load_norms(tile_idx_m, tile_idx_n, tile_end_n, regxn, regyn); + load_norms(tile_idx_m, tile_idx_n, group_end_n, regxn, regyn); if (thread_adj != 0) { - epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, tile_end_n); + epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, group_end_n); } } else { if (thread_adj != 0) { - epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, tile_end_n); + epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, group_end_n); } } } // tile_idx_n @@ -247,7 +247,7 @@ struct MaskedDistances : public BaseClass { // block_row_is_adjacent is true if the current block_row_idx is adjacent // to the current group. 
- const uint64_t block_mask = 1ull << block_row_idx; + const uint64_t block_mask = 1ull << block_row_idx; const bool block_row_is_adjacent = (block_adj & block_mask) != 0; if (block_row_is_adjacent) { // If block row is adjacent, write a 1 bit to thread_adj at location @@ -291,7 +291,7 @@ struct MaskedDistances : public BaseClass { DI void load_norms(IdxT tile_idx_m, IdxT tile_idx_n, - IdxT tile_end_n, + IdxT end_n, DataT (®xn)[P::AccRowsPerTh], DataT (®yn)[P::AccColsPerTh]) { @@ -306,7 +306,7 @@ struct MaskedDistances : public BaseClass { for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { auto idx = tile_idx_n + i; - syNorm[i] = idx < tile_end_n ? yn[idx] : 0; + syNorm[i] = idx < end_n ? yn[idx] : 0; } __syncthreads(); diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 4c5a43cd57..3093947b19 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -218,49 +218,15 @@ struct Contractions_NT { } } - DI void ldgY(IdxT tile_idx_n, IdxT kidx) - { - IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; - auto y = isRowMajor ? 
y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; - - if (isRowMajor) { - auto numRows = n; - auto koffset = kidx + scolid; -#pragma unroll - for (int i = 0; i < P::LdgPerThY; ++i) { - if (koffset < ldb && (yrowid + i * P::LdgRowsY) < numRows) { - ldg(ldgDataY[i], y + i * P::LdgRowsY * ldb + koffset); - } else { -#pragma unroll - for (int j = 0; j < P::Veclen; ++j) { - ldgDataY[i][j] = Zero; - } - } - } - } else { - auto numRows = k; - auto koffset = scolid; -#pragma unroll - for (int i = 0; i < P::LdgPerThY; ++i) { - if ((koffset + yrowid) < ldb && (srowid + kidx + i * P::LdgRowsY) < numRows) { - ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset); - } else { -#pragma unroll - for (int j = 0; j < P::Veclen; ++j) { - ldgDataY[i][j] = Zero; - } - } - } - } - } + DI void ldgY(IdxT tile_idx_n, IdxT kidx) { ldgY(tile_idx_n, kidx, n); } - DI void ldgY(IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) + DI void ldgY(IdxT tile_idx_n, IdxT kidx, IdxT end_n) { IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; auto y = isRowMajor ? 
y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; if (isRowMajor) { - auto numRows = tile_end_n; + auto numRows = end_n; auto koffset = kidx + scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { @@ -278,7 +244,7 @@ struct Contractions_NT { auto koffset = scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { - if ((koffset + yrowid) < tile_end_n && (srowid + kidx + i * P::LdgRowsY) < numRows) { + if ((koffset + yrowid) < end_n && (srowid + kidx + i * P::LdgRowsY) < numRows) { ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset); } else { #pragma unroll From ed178ca99852c71691b90def2d654eef975d4178 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 12:13:38 +0100 Subject: [PATCH 36/49] Formatting --- .../raft/distance/detail/masked_nn.cuh | 56 +++++++++---------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 96d61896ea..153df4cd4f 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -180,7 +180,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, obj.run(); } - /** * @brief Wrapper for maskedL2NNkernel * @@ -278,45 +277,44 @@ void maskedL2NNImpl(const raft::handle_t& handle, dim3 grid(raft::ceildiv(m, P::Nthreads)); dim3 block(P::Nthreads); - initKernel - <<>>(out, m, maxVal, redOp); + initKernel<<>>(out, m, maxVal, redOp); RAFT_CUDA_TRY(cudaGetLastError()); } // Accumulation operation lambda auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - auto fin_op = raft::identity_op{}; + auto fin_op = raft::identity_op{}; - auto kernel = maskedL2NNkernel; + auto kernel = maskedL2NNkernel; constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); dim3 block(P::Nthreads); dim3 grid = launchConfigGenerator

(m, n, smemSize, kernel); kernel<<>>(out, - x, - y, - xn, - yn, - ws_adj64.data(), - group_idxs, - num_groups, - m, - n, - k, - sqrt, - maxVal, - ws_fused_nn.data(), - redOp, - pairRedOp, - core_lambda, - fin_op); + x, + y, + xn, + yn, + ws_adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + sqrt, + maxVal, + ws_fused_nn.data(), + redOp, + pairRedOp, + core_lambda, + fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } From 80cb3b05ee436e29baa66792a03df2c67f40f368 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 14:21:28 +0100 Subject: [PATCH 37/49] Add compress_to_bits kernel wrapper --- .../raft/distance/detail/compress_to_bits.cuh | 21 ++++++++++++------- .../raft/distance/detail/masked_nn.cuh | 4 +--- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh index aa4bc27d39..da00277199 100644 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -18,12 +18,10 @@ #include #include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { template ::value>> -__global__ void compress_to_bits_naive(const bool* in, int in_rows, int in_cols, T* out) +__global__ void compress_to_bits_kernel(const bool* in, int in_rows, int in_cols, T* out) { constexpr int bits_per_element = 8 * sizeof(T); @@ -44,6 +42,15 @@ __global__ void compress_to_bits_naive(const bool* in, int in_rows, int in_cols, if (out_i < out_rows && out_j < out_cols) { atomicOr(&out[out_i * out_cols + out_j], bitfield); } } -}; // namespace detail -}; // namespace distance -}; // namespace raft +template ::value>> +void compress_to_bits( + const raft::handle_t& handle, const bool* in, int in_rows, int in_cols, T* out) +{ + auto stream = handle.get_stream(); + dim3 grid(raft::ceildiv(in_cols, 32), raft::ceildiv(in_rows, 32)); + dim3 block(32, 32); + 
compress_to_bits_kernel<<>>(in, in_rows, in_cols, out); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 153df4cd4f..427518548c 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -266,9 +266,7 @@ void maskedL2NNImpl(const raft::handle_t& handle, RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); // Compress boolean adjacency matrix to bitfield. - dim3 compress_grid(raft::ceildiv(m, 32), raft::ceildiv(num_groups, 32)); - compress_to_bits_naive<<>>( - adj, num_groups, m, ws_adj64.data()); + compress_to_bits(handle, adj, num_groups, m, ws_adj64.data()); // Initialize output buffer with keyvalue pairs as determined by the reduction // operator (it will be called with maxVal). From 546854db7a5949c991f89ee526ba41d43e01c9ed Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 16:47:46 +0100 Subject: [PATCH 38/49] Add compress_to_bits docs and test --- .../raft/distance/detail/compress_to_bits.cuh | 32 ++++ cpp/test/CMakeLists.txt | 1 + .../distance/masked_nn_compress_to_bits.cu | 174 ++++++++++++++++++ 3 files changed, 207 insertions(+) create mode 100644 cpp/test/distance/masked_nn_compress_to_bits.cu diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh index da00277199..6344ab9b8e 100644 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -15,11 +15,27 @@ */ #pragma once +#include #include #include namespace raft::distance::detail { +/** + * @brief Transpose and compress 2D boolean matrix to bitfield + * + * Utility kernel for maskedL2NN. + * + * @tparam T + * + * @parameter[in] in An `m x n` boolean matrix. Row major. 
+ * @parameter in_rows The number of rows of `in`, i.e. `m`. + * @parameter in_cols The number of cols of `in`, i.e. `n`. + * + * @parameter[out] out An `(n / bits_per_elem) x m` matrix with elements of + * type T, where T is of size `bits_per_elem` bits. + * Note: the division (`/`) is a ceilDiv. + */ template ::value>> __global__ void compress_to_bits_kernel(const bool* in, int in_rows, int in_cols, T* out) { @@ -42,6 +58,22 @@ __global__ void compress_to_bits_kernel(const bool* in, int in_rows, int in_cols if (out_i < out_rows && out_j < out_cols) { atomicOr(&out[out_i * out_cols + out_j], bitfield); } } +/** + * @brief Transpose and compress 2D boolean matrix to bitfield + * + * Utility kernel for maskedL2NN. + * + * @tparam T + * + * @parameter handle RAFT handle. + * @parameter[in] in An `m x n` boolean matrix. Row major. + * @parameter in_rows The number of rows of `in`, i.e. `m`. + * @parameter in_cols The number of cols of `in`, i.e. `n`. + * + * @parameter[out] out An `(n / bits_per_elem) x m` matrix with elements of + * type T, where T is of size `bits_per_elem` bits. + * Note: the division (`/`) is a ceilDiv. + */ template ::value>> void compress_to_bits( const raft::handle_t& handle, const bool* in, int in_rows, int in_cols, T* out) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ddce9b0ee7..0a49d9b8fc 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -122,6 +122,7 @@ if(BUILD_TESTS) test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu test/distance/masked_nn.cu + test/distance/masked_nn_compress_to_bits.cu test/distance/gram.cu OPTIONAL DIST diff --git a/cpp/test/distance/masked_nn_compress_to_bits.cu b/cpp/test/distance/masked_nn_compress_to_bits.cu new file mode 100644 index 0000000000..abb6d48f11 --- /dev/null +++ b/cpp/test/distance/masked_nn_compress_to_bits.cu @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "../test_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::distance::masked_nn::compress_to_bits { + +/** + * @brief Transpose and decompress 2D bitfield to boolean matrix + * + * Inverse operation of compress_to_bits + * + * @tparam T + * + * @parameter[in] in An `m x n` bitfield matrix. Row major. + * @parameter in_rows The number of rows of `in`, i.e. `m`. + * @parameter in_cols The number of cols of `in`, i.e. `n`. + * + * @parameter[out] out An `n x (m * bits_per_elem)` boolean matrix. 
+ */ +template ::value>> +__global__ void decompress_bits_kernel(const T* in, int in_rows, int in_cols, bool* out) +{ + constexpr int bits_per_element = 8 * sizeof(T); + + const size_t i = threadIdx.y + blockIdx.y * blockDim.y; + const size_t j = threadIdx.x + blockIdx.x * blockDim.x; + + if (in_rows <= i || in_cols <= j) { return; } + + T bitfield = in[i * in_cols + j]; + const size_t out_rows = in_cols; + const size_t out_cols = in_rows * bits_per_element; + const size_t out_i = j; + const size_t out_j = i * bits_per_element; + + for (int bitpos = 0; bitpos < bits_per_element; ++bitpos) { + bool bit = ((T(1) << bitpos) & bitfield) != 0; + if (out_i < out_rows && out_j < out_cols) { + out[out_i * out_cols + out_j + bitpos] = bit; + } + } +} + +/** + * @brief Transpose and decompress 2D bitfield to boolean matrix + * + * Inverse operation of compress_to_bits + * + * @tparam T + * + * @parameter[in] in An `m x n` bitfield matrix. Row major. + * @parameter in_rows The number of rows of `in`, i.e. `m`. + * @parameter in_cols The number of cols of `in`, i.e. `n`. + * + * @parameter[out] out An `n x (m * bits_per_elem)` boolean matrix. + */ +template ::value>> +void decompress_bits( + const raft::handle_t& handle, const T* in, int in_rows, int in_cols, bool* out) +{ + auto stream = handle.get_stream(); + dim3 grid(raft::ceildiv(in_cols, 32), raft::ceildiv(in_rows, 32)); + dim3 block(32, 32); + decompress_bits_kernel<<>>(in, in_rows, in_cols, out); + RAFT_CUDA_TRY(cudaGetLastError()); +} + + +// Params holds parameters for test case +struct Params { + int m, n; +}; + +inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& +{ + return os << "m: " << p.m << ", n: " << p.n; +} + +inline auto gen_params() -> std::vector +{ + return raft::util::itertools::product( + {1, 3, 32, 33, 63, 64, 65, 10013}, + {1, 3, 32, 33, 63, 64, 65, 13001}); +} + + +// Check that the following holds +// +// decompress(compress(x)) == x +// +// for 2D boolean matrices x. 
+template +void check_invertible(const Params& p) { + using raft::distance::detail::compress_to_bits; + constexpr int bits_per_elem = sizeof(T) * 8; + + // Make m and n that are safe to ceildiv. + int m = p.m; + int n = raft::round_up_safe(p.n, bits_per_elem); + + // Generate random input + raft::handle_t handle{}; + raft::random::RngState r(1234ULL); + auto in = raft::make_device_matrix(handle, m, n); + raft::random::bernoulli(handle, r, in.data_handle(), m * n, 0.5f); + + int tmp_m = raft::ceildiv(n, bits_per_elem); + int tmp_n = m; + + int out_m = tmp_n; + int out_n = tmp_m * bits_per_elem; + + auto tmp = raft::make_device_matrix(handle, tmp_m, tmp_n); + auto out = raft::make_device_matrix(handle, out_m, out_n); + + ASSERT_EQ(in.extent(0), out.extent(0)) << "M does not match"; + ASSERT_EQ(in.extent(1), out.extent(1)) << "N does not match"; + + compress_to_bits(handle, in.data_handle(), in.extent(0), in.extent(1), tmp.data_handle()); + decompress_bits(handle, tmp.data_handle(), tmp.extent(0), tmp.extent(1), out.data_handle()); + + // Check for differences. + ASSERT_TRUE(raft::devArrMatch(in.data_handle(), + out.data_handle(), + in.extent(0) * in.extent(1), + raft::Compare(), + handle.get_stream())); +} + +class CompressToBitsTest : public ::testing::TestWithParam { + // Empty. 
+}; + +TEST_P(CompressToBitsTest, CheckInvertible64) +{ + using T = uint64_t; + check_invertible(GetParam()); +} +TEST_P(CompressToBitsTest, CheckInvertible32) +{ + using T = uint32_t; + check_invertible(GetParam()); +} + +INSTANTIATE_TEST_CASE_P(CompressToBits, CompressToBitsTest, ::testing::ValuesIn(gen_params())); + +} // end namespace raft::distance::masked_nn From 125ae98f58610c07bad10f583c21d8efa4c6e893 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 16:48:05 +0100 Subject: [PATCH 39/49] masked_nn: Fix test param printer --- cpp/test/distance/masked_nn.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index ee292bf855..1f65e4bcf4 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -150,7 +150,6 @@ struct Params { AdjacencyPattern pattern; }; -template inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& { os << "m: " << p.m << ", n: " << p.n << ", k: " << p.k << ", num_groups: " << p.num_groups From 6e776538573618894e88e1fb54bc4271976b990d Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 20:28:44 +0100 Subject: [PATCH 40/49] Use mdspan for compress_to_bits Also(!) remove the inherent transpose in compress_to_bits. 
--- .../raft/distance/detail/compress_to_bits.cuh | 92 +++++++++----- .../raft/distance/detail/masked_nn.cuh | 5 +- cpp/test/distance/masked_nn.cu | 37 +++--- .../distance/masked_nn_compress_to_bits.cu | 114 ++++++++++++------ 4 files changed, 167 insertions(+), 81 deletions(-) diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh index 6344ab9b8e..f9c4fed5cf 100644 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -22,66 +22,100 @@ namespace raft::distance::detail { /** - * @brief Transpose and compress 2D boolean matrix to bitfield + * @brief Compress 2D boolean matrix to bitfield * * Utility kernel for maskedL2NN. * * @tparam T * * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter in_rows The number of rows of `in`, i.e. `m`. - * @parameter in_cols The number of cols of `in`, i.e. `n`. - * - * @parameter[out] out An `(n / bits_per_elem) x m` matrix with elements of + * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of * type T, where T is of size `bits_per_elem` bits. * Note: the division (`/`) is a ceilDiv. 
*/ template ::value>> -__global__ void compress_to_bits_kernel(const bool* in, int in_rows, int in_cols, T* out) +__global__ void compress_to_bits_kernel( + raft::device_matrix_view in, + raft::device_matrix_view out) { constexpr int bits_per_element = 8 * sizeof(T); + constexpr int tile_dim_m = bits_per_element; + constexpr int nthreads = 128; + constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector + + // Tile in shared memory is transposed + __shared__ bool smem[tile_dim_n][tile_dim_m]; - const size_t i = threadIdx.y + blockIdx.y * blockDim.y; - const size_t j = threadIdx.x + blockIdx.x * blockDim.x; + const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); + const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); - if (in_rows <= i || in_cols <= j) { return; } + for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { + const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); + const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); - bool bit = in[i * in_cols + j]; - int bitpos = j % bits_per_element; + if (in.extent(0) <= tile_idx_m) { break; } + // Fill shared memory tile + bool reg_buf[tile_dim_m]; +#pragma unroll + for (int i = 0; i < tile_dim_m; ++i) { + const int in_m = tile_idx_m + i; + const int in_n = tile_idx_n + threadIdx.x; + bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); + reg_buf[i] = in_bounds ? in(in_m, in_n) : false; + smem[threadIdx.x][i] = reg_buf[i]; + } + __syncthreads(); - T bitfield = bit ? T(1) << bitpos : 0; + // Drain memory tile into single output element out_elem. 
+ T out_elem{0}; +#pragma unroll + for (int j = 0; j < tile_dim_n; ++j) { + if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } + } + __syncthreads(); - const size_t out_rows = raft::ceildiv(in_cols, bits_per_element); - const size_t out_cols = in_rows; - const size_t out_j = i; - const size_t out_i = j / bits_per_element; - if (out_i < out_rows && out_j < out_cols) { atomicOr(&out[out_i * out_cols + out_j], bitfield); } + // Write output. + int out_m = tile_idx_m / bits_per_element; + int out_n = tile_idx_n + threadIdx.x; + + if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } + } } /** - * @brief Transpose and compress 2D boolean matrix to bitfield + * @brief Compress 2D boolean matrix to bitfield * * Utility kernel for maskedL2NN. * * @tparam T * - * @parameter handle RAFT handle. * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter in_rows The number of rows of `in`, i.e. `m`. - * @parameter in_cols The number of cols of `in`, i.e. `n`. - * - * @parameter[out] out An `(n / bits_per_elem) x m` matrix with elements of + * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of * type T, where T is of size `bits_per_elem` bits. * Note: the division (`/`) is a ceilDiv. 
*/ template ::value>> -void compress_to_bits( - const raft::handle_t& handle, const bool* in, int in_rows, int in_cols, T* out) +void compress_to_bits(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) { - auto stream = handle.get_stream(); - dim3 grid(raft::ceildiv(in_cols, 32), raft::ceildiv(in_rows, 32)); - dim3 block(32, 32); - compress_to_bits_kernel<<>>(in, in_rows, in_cols, out); + auto stream = handle.get_stream(); + constexpr int bits_per_element = 8 * sizeof(T); + + RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), + "Number of output rows must be ceildiv(input rows, bits_per_elem)"); + RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); + + const int num_SMs = raft::getMultiProcessorCount(); + int blocks_per_sm = 0; + constexpr int num_threads = 128; + constexpr int dyn_smem_size = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, compress_to_bits_kernel, num_threads, dyn_smem_size)); + + dim3 grid(num_SMs * blocks_per_sm); + dim3 block(128); + compress_to_bits_kernel<<>>(in, out); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 427518548c..5bbcbcf56e 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -266,7 +266,10 @@ void maskedL2NNImpl(const raft::handle_t& handle, RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); // Compress boolean adjacency matrix to bitfield. 
- compress_to_bits(handle, adj, num_groups, m, ws_adj64.data()); + auto adj_view = raft::make_device_matrix_view(adj, m, num_groups); + auto adj64_view = + raft::make_device_matrix_view(ws_adj64.data(), m_div_64, num_groups); + compress_to_bits(handle, adj_view, adj64_view); // Initialize output buffer with keyvalue pairs as determined by the reduction // operator (it will be called with maxVal). diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index 1f65e4bcf4..c80c984992 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -43,17 +43,24 @@ enum AdjacencyPattern { // - init_adj: to initialize the adjacency kernel with a specific adjacency pattern // - referenceKernel: to produce the ground-truth output -__global__ void init_adj( - int m, int n, int num_groups, AdjacencyPattern pattern, bool* adj, int* group_idxs) +__global__ void init_adj(AdjacencyPattern pattern, + int n, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs) { - for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_groups; i += blockDim.y * gridDim.y) { - for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < m; j += blockDim.x * gridDim.x) { + int m = adj.extent(0); + int num_groups = adj.extent(1); + + for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; + idx_m += blockDim.y * gridDim.y) { + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; + idx_g += blockDim.x * gridDim.x) { switch (pattern) { - case checkerboard: adj[i * m + j] = (i + j) % 2; break; - case checkerboard_4: adj[i * m + j] = (i + (j / 4)) % 2; break; - case checkerboard_64: adj[i * m + j] = (i + (j / 64)) % 2; break; - case all_true: adj[i * m + j] = true; break; - case all_false: adj[i * m + j] = false; break; + case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; + case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; + case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 
64 + idx_g) % 2; break; + case all_true: adj(idx_m, idx_g) = true; break; + case all_false: adj(idx_m, idx_g) = false; break; default: assert(false && "unknown pattern"); } } @@ -68,11 +75,11 @@ __global__ void init_adj( // - The group_idxs[num_groups - 1] should always equal n. if (blockIdx.y == 0 && threadIdx.y == 0) { - const int j_stride = blockDim.x * gridDim.x; - for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_groups; j += j_stride) { - group_idxs[j] = (j + 1) * (n / num_groups); + const int g_stride = blockDim.x * gridDim.x; + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { + group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); } - group_idxs[num_groups - 1] = n; + group_idxs(num_groups - 1) = n; } } @@ -106,7 +113,7 @@ __global__ __launch_bounds__(32 * NWARPS, for (int i = num_groups; 0 <= i; --i) { if (nidx < group_idxs[i]) { group_idx = i; } } - const bool include_dist = adj[group_idx * m + midx] && midx < m && nidx < n; + const bool include_dist = adj[midx * num_groups + group_idx] && midx < m && nidx < n; // Compute L2 metric. DataT acc = DataT(0); @@ -180,7 +187,7 @@ struct Inputs { dim3 block(32, 32); dim3 grid(10, 10); init_adj<<>>( - p.m, p.n, p.num_groups, p.pattern, adj.data_handle(), group_idxs.data_handle()); + p.pattern, p.n, adj.view(), group_idxs.view()); RAFT_CUDA_TRY(cudaGetLastError()); } }; diff --git a/cpp/test/distance/masked_nn_compress_to_bits.cu b/cpp/test/distance/masked_nn_compress_to_bits.cu index abb6d48f11..7597362274 100644 --- a/cpp/test/distance/masked_nn_compress_to_bits.cu +++ b/cpp/test/distance/masked_nn_compress_to_bits.cu @@ -14,15 +14,16 @@ * limitations under the License. 
*/ -#include "../test_utils.h" #include "../test_utils.cuh" +#include "../test_utils.h" #include #include #include -#include #include #include #include +#include +#include #include #include #include @@ -42,7 +43,7 @@ namespace raft::distance::masked_nn::compress_to_bits { * @parameter in_rows The number of rows of `in`, i.e. `m`. * @parameter in_cols The number of cols of `in`, i.e. `n`. * - * @parameter[out] out An `n x (m * bits_per_elem)` boolean matrix. + * @parameter[out] out An `(m * bits_per_elem) x n` boolean matrix. */ template ::value>> __global__ void decompress_bits_kernel(const T* in, int in_rows, int in_cols, bool* out) @@ -54,17 +55,17 @@ __global__ void decompress_bits_kernel(const T* in, int in_rows, int in_cols, bo if (in_rows <= i || in_cols <= j) { return; } - T bitfield = in[i * in_cols + j]; - const size_t out_rows = in_cols; - const size_t out_cols = in_rows * bits_per_element; - const size_t out_i = j; - const size_t out_j = i * bits_per_element; + const size_t out_rows = in_rows * bits_per_element; + const size_t out_cols = in_cols; + const size_t out_i = i * bits_per_element; + const size_t out_j = j; + if (out_rows <= out_i && out_cols <= out_j) { return; } + + T bitfield = in[i * in_cols + j]; for (int bitpos = 0; bitpos < bits_per_element; ++bitpos) { - bool bit = ((T(1) << bitpos) & bitfield) != 0; - if (out_i < out_rows && out_j < out_cols) { - out[out_i * out_cols + out_j + bitpos] = bit; - } + bool bit = ((T(1) << bitpos) & bitfield) != 0; + out[(out_i + bitpos) * out_cols + out_j] = bit; } } @@ -82,8 +83,7 @@ __global__ void decompress_bits_kernel(const T* in, int in_rows, int in_cols, bo * @parameter[out] out An `n x (m * bits_per_elem)` boolean matrix. 
*/ template ::value>> -void decompress_bits( - const raft::handle_t& handle, const T* in, int in_rows, int in_cols, bool* out) +void decompress_bits(const raft::handle_t& handle, const T* in, int in_rows, int in_cols, bool* out) { auto stream = handle.get_stream(); dim3 grid(raft::ceildiv(in_cols, 32), raft::ceildiv(in_rows, 32)); @@ -92,7 +92,6 @@ void decompress_bits( RAFT_CUDA_TRY(cudaGetLastError()); } - // Params holds parameters for test case struct Params { int m, n; @@ -103,48 +102,46 @@ inline auto operator<<(std::ostream& os, const Params& p) -> std::ostream& return os << "m: " << p.m << ", n: " << p.n; } -inline auto gen_params() -> std::vector -{ - return raft::util::itertools::product( - {1, 3, 32, 33, 63, 64, 65, 10013}, - {1, 3, 32, 33, 63, 64, 65, 13001}); -} - - // Check that the following holds // // decompress(compress(x)) == x // // for 2D boolean matrices x. template -void check_invertible(const Params& p) { +void check_invertible(const Params& p) +{ using raft::distance::detail::compress_to_bits; constexpr int bits_per_elem = sizeof(T) * 8; // Make m and n that are safe to ceildiv. 
- int m = p.m; - int n = raft::round_up_safe(p.n, bits_per_elem); + int m = raft::round_up_safe(p.m, bits_per_elem); + int n = p.n; // Generate random input raft::handle_t handle{}; - raft::random::RngState r(1234ULL); + raft::random::RngState r(1ULL); auto in = raft::make_device_matrix(handle, m, n); raft::random::bernoulli(handle, r, in.data_handle(), m * n, 0.5f); - int tmp_m = raft::ceildiv(n, bits_per_elem); - int tmp_n = m; + int tmp_m = raft::ceildiv(m, bits_per_elem); + int out_m = tmp_m * bits_per_elem; - int out_m = tmp_n; - int out_n = tmp_m * bits_per_elem; + auto tmp = raft::make_device_matrix(handle, tmp_m, n); + auto out = raft::make_device_matrix(handle, out_m, n); - auto tmp = raft::make_device_matrix(handle, tmp_m, tmp_n); - auto out = raft::make_device_matrix(handle, out_m, out_n); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); ASSERT_EQ(in.extent(0), out.extent(0)) << "M does not match"; ASSERT_EQ(in.extent(1), out.extent(1)) << "N does not match"; - compress_to_bits(handle, in.data_handle(), in.extent(0), in.extent(1), tmp.data_handle()); + compress_to_bits(handle, in.view(), tmp.view()); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); + decompress_bits(handle, tmp.data_handle(), tmp.extent(0), tmp.extent(1), out.data_handle()); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); // Check for differences. ASSERT_TRUE(raft::devArrMatch(in.data_handle(), @@ -152,23 +149,68 @@ void check_invertible(const Params& p) { in.extent(0) * in.extent(1), raft::Compare(), handle.get_stream())); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +void check_all_true(const Params& p) +{ + using raft::distance::detail::compress_to_bits; + using T = uint64_t; + constexpr int bits_per_elem = sizeof(T) * 8; + + // Make m and n that are safe to ceildiv. 
+ int m = raft::round_up_safe(p.m, bits_per_elem); + int n = p.n; + + raft::handle_t handle{}; + raft::random::RngState r(1ULL); + auto in = raft::make_device_matrix(handle, m, n); + raft::matrix::fill(handle, in.view(), true); + + int tmp_m = raft::ceildiv(m, bits_per_elem); + auto tmp = raft::make_device_matrix(handle, tmp_m, n); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); + + compress_to_bits(handle, in.view(), tmp.view()); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); + + auto expected = raft::make_device_matrix(handle, tmp_m, n); + raft::matrix::fill(handle, expected.view(), ~T(0)); + + // Check for differences. + ASSERT_TRUE(raft::devArrMatch(expected.data_handle(), + tmp.data_handle(), + tmp.extent(0) * tmp.extent(1), + raft::Compare(), + handle.get_stream())); + handle.sync_stream(); + RAFT_CUDA_TRY(cudaGetLastError()); } class CompressToBitsTest : public ::testing::TestWithParam { // Empty. }; +TEST_P(CompressToBitsTest, CheckTrue64) { check_all_true(GetParam()); } + TEST_P(CompressToBitsTest, CheckInvertible64) { using T = uint64_t; check_invertible(GetParam()); } + TEST_P(CompressToBitsTest, CheckInvertible32) { using T = uint32_t; check_invertible(GetParam()); } -INSTANTIATE_TEST_CASE_P(CompressToBits, CompressToBitsTest, ::testing::ValuesIn(gen_params())); +std::vector params = raft::util::itertools::product( + {1, 3, 32, 33, 63, 64, 65, 128, 10013}, {1, 3, 32, 33, 63, 64, 65, 13001}); + +INSTANTIATE_TEST_CASE_P(CompressToBits, CompressToBitsTest, ::testing::ValuesIn(params)); -} // end namespace raft::distance::masked_nn +} // namespace raft::distance::masked_nn::compress_to_bits From 892514672851d44581ecebb90da79933f3ad6686 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 25 Jan 2023 21:00:26 +0100 Subject: [PATCH 41/49] Make benchmark more informative Add flop/s etc. 
--- cpp/bench/distance/masked_nn.cu | 186 ++++++++++++++++++++++---------- 1 file changed, 127 insertions(+), 59 deletions(-) diff --git a/cpp/bench/distance/masked_nn.cu b/cpp/bench/distance/masked_nn.cu index cef91f8daf..3677d44864 100644 --- a/cpp/bench/distance/masked_nn.cu +++ b/cpp/bench/distance/masked_nn.cu @@ -37,7 +37,7 @@ namespace raft::bench::distance::masked_nn { // Introduce various sparsity patterns -enum SparsityPattern { +enum AdjacencyPattern { checkerboard = 0, checkerboard_4 = 1, checkerboard_64 = 2, @@ -45,22 +45,29 @@ enum SparsityPattern { all_false = 4 }; -struct masked_l2_nn_inputs { +struct Params { int m, n, k, num_groups; - SparsityPattern pattern; -}; // struct masked_l2_nn_inputs + AdjacencyPattern pattern; +}; // struct Params -__global__ void init_adj( - int m, int n, int num_groups, SparsityPattern pattern, bool* adj, int* group_idxs) +__global__ void init_adj(AdjacencyPattern pattern, + int n, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs) { - for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < num_groups; i += blockDim.y * gridDim.y) { - for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < m; j += blockDim.x * gridDim.x) { + int m = adj.extent(0); + int num_groups = adj.extent(1); + + for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; + idx_m += blockDim.y * gridDim.y) { + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; + idx_g += blockDim.x * gridDim.x) { switch (pattern) { - case checkerboard: adj[i * m + j] = (i + j) % 2; break; - case checkerboard_4: adj[i * m + j] = (i + (j / 4)) % 2; break; - case checkerboard_64: adj[i * m + j] = (i + (j / 64)) % 2; break; - case all_true: adj[i * m + j] = true; break; - case all_false: adj[i * m + j] = false; break; + case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; + case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; + case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + 
idx_g) % 2; break; + case all_true: adj(idx_m, idx_g) = true; break; + case all_false: adj(idx_m, idx_g) = false; break; default: assert(false && "unknown pattern"); } } @@ -75,11 +82,11 @@ __global__ void init_adj( // - The group_idxs[num_groups - 1] should always equal n. if (blockIdx.y == 0 && threadIdx.y == 0) { - for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_groups; - j += blockDim.x * gridDim.x) { - group_idxs[j] = (j + 1) * (n / num_groups); + const int g_stride = blockDim.x * gridDim.x; + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { + group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); } - group_idxs[num_groups - 1] = n; + group_idxs(num_groups - 1) = n; } } @@ -93,7 +100,7 @@ struct masked_l2_nn : public fixture { using ParamT = raft::distance::MaskedL2NNParams; // Parameters - masked_l2_nn_inputs params; + Params params; // Data raft::device_vector out; raft::device_matrix x, y; @@ -101,7 +108,7 @@ struct masked_l2_nn : public fixture { raft::device_matrix adj; raft::device_vector group_idxs; - masked_l2_nn(const masked_l2_nn_inputs& p) + masked_l2_nn(const Params& p) : params(p), out{raft::make_device_vector(handle, p.m)}, x{raft::make_device_matrix(handle, p.m, p.k)}, @@ -124,14 +131,13 @@ struct masked_l2_nn : public fixture { dim3 block(32, 32); dim3 grid(10, 10); - init_adj<<>>( - p.m, p.n, p.num_groups, p.pattern, adj.data_handle(), group_idxs.data_handle()); + init_adj<<>>(p.pattern, p.n, adj.view(), group_idxs.view()); RAFT_CUDA_TRY(cudaGetLastError()); } void run_benchmark(::benchmark::State& state) override { - bool init_out = false; + bool init_out = true; bool sqrt = false; ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; @@ -147,49 +153,111 @@ struct masked_l2_nn : public fixture { group_idxs.view(), out.view()); }); + + // Virtual flop count if no skipping had occurred. 
+ size_t virtual_flops = size_t(2) * size_t(params.m) * size_t(params.n) * size_t(params.k); + + int64_t read_elts = params.n * params.k + params.m * params.k; + int64_t write_elts = params.m; + + // Virtual min flops is the number of flops that would have been executed if + // the algorithm had actually skipped each computation that it could have + // skipped. + size_t virtual_min_flops = 0; + switch (params.pattern) { + case checkerboard: + case checkerboard_4: + case checkerboard_64: virtual_min_flops = virtual_flops / 2; break; + case all_true: virtual_min_flops = virtual_flops; break; + case all_false: virtual_min_flops = 0; break; + default: assert(false && "unknown pattern"); + } + + // VFLOP/s is the "virtual" flop count that would have executed if there was + // no adjacency pattern. This is useful for comparing to fusedL2NN + state.counters["VFLOP/s"] = benchmark::Counter(virtual_flops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + // Virtual min flops is the number of flops that would have been executed if + // the algorithm had actually skipped each computation that it could have + // skipped. 
+ state.counters["VminFLOP/s"] = benchmark::Counter(virtual_min_flops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["m"] = benchmark::Counter(params.m); + state.counters["n"] = benchmark::Counter(params.n); + state.counters["k"] = benchmark::Counter(params.k); + state.counters["num_groups"] = benchmark::Counter(params.num_groups); + state.counters["group size"] = benchmark::Counter(params.n / params.num_groups); + state.counters["Pat"] = benchmark::Counter(static_cast(params.pattern)); + + state.counters["SM count"] = raft::getMultiProcessorCount(); } }; // struct MaskedL2NN -// TODO: Consider thinning the list of benchmark cases.. -const std::vector masked_l2_nn_input_vecs = { +const std::vector masked_l2_nn_input_vecs = { // Very fat matrices... 
- {32, 16384, 16384, 32, SparsityPattern::checkerboard}, - {64, 16384, 16384, 32, SparsityPattern::checkerboard}, - {128, 16384, 16384, 32, SparsityPattern::checkerboard}, - {256, 16384, 16384, 32, SparsityPattern::checkerboard}, - {512, 16384, 16384, 32, SparsityPattern::checkerboard}, - {1024, 16384, 16384, 32, SparsityPattern::checkerboard}, - {16384, 32, 16384, 32, SparsityPattern::checkerboard}, - {16384, 64, 16384, 32, SparsityPattern::checkerboard}, - {16384, 128, 16384, 32, SparsityPattern::checkerboard}, - {16384, 256, 16384, 32, SparsityPattern::checkerboard}, - {16384, 512, 16384, 32, SparsityPattern::checkerboard}, - {16384, 1024, 16384, 32, SparsityPattern::checkerboard}, + {32, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {64, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {128, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {256, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {512, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {1024, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 32, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 64, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 128, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 256, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 512, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 1024, 16384, 32, AdjacencyPattern::checkerboard}, // Representative matrices... 
- {16384, 16384, 32, 32, SparsityPattern::checkerboard}, - {16384, 16384, 64, 32, SparsityPattern::checkerboard}, - {16384, 16384, 128, 32, SparsityPattern::checkerboard}, - {16384, 16384, 256, 32, SparsityPattern::checkerboard}, - {16384, 16384, 512, 32, SparsityPattern::checkerboard}, - {16384, 16384, 1024, 32, SparsityPattern::checkerboard}, - {16384, 16384, 16384, 32, SparsityPattern::checkerboard}, - - {16384, 16384, 32, 32, SparsityPattern::checkerboard_4}, - {16384, 16384, 64, 32, SparsityPattern::checkerboard_4}, - {16384, 16384, 128, 32, SparsityPattern::checkerboard_4}, - {16384, 16384, 256, 32, SparsityPattern::checkerboard_4}, - {16384, 16384, 512, 32, SparsityPattern::checkerboard_4}, - {16384, 16384, 1024, 32, SparsityPattern::checkerboard_4}, - {16384, 16384, 16384, 32, SparsityPattern::checkerboard_4}, - - {16384, 16384, 32, 32, SparsityPattern::checkerboard_64}, - {16384, 16384, 64, 32, SparsityPattern::checkerboard_64}, - {16384, 16384, 128, 32, SparsityPattern::checkerboard_64}, - {16384, 16384, 256, 32, SparsityPattern::checkerboard_64}, - {16384, 16384, 512, 32, SparsityPattern::checkerboard_64}, - {16384, 16384, 1024, 32, SparsityPattern::checkerboard_64}, - {16384, 16384, 16384, 32, SparsityPattern::checkerboard_64}, + {16384, 16384, 32, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + + {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 
1024, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_4}, + + {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_64}, + + {16384, 16384, 32, 32, AdjacencyPattern::all_true}, + {16384, 16384, 64, 32, AdjacencyPattern::all_true}, + {16384, 16384, 128, 32, AdjacencyPattern::all_true}, + {16384, 16384, 256, 32, AdjacencyPattern::all_true}, + {16384, 16384, 512, 32, AdjacencyPattern::all_true}, + {16384, 16384, 1024, 32, AdjacencyPattern::all_true}, + {16384, 16384, 16384, 32, AdjacencyPattern::all_true}, + + {16384, 16384, 32, 32, AdjacencyPattern::all_false}, + {16384, 16384, 64, 32, AdjacencyPattern::all_false}, + {16384, 16384, 128, 32, AdjacencyPattern::all_false}, + {16384, 16384, 256, 32, AdjacencyPattern::all_false}, + {16384, 16384, 512, 32, AdjacencyPattern::all_false}, + {16384, 16384, 1024, 32, AdjacencyPattern::all_false}, + {16384, 16384, 16384, 32, AdjacencyPattern::all_false}, }; RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); From e52b0f94afc1f1b6f4071eee2ac96b349030554d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 26 Jan 2023 10:10:11 -0500 Subject: [PATCH 42/49] Forcing sccache reinit. 
--- build.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/build.sh b/build.sh index b47e1ed862..f34c032204 100755 --- a/build.sh +++ b/build.sh @@ -75,6 +75,8 @@ COMPILE_DIST_LIBRARY=OFF ENABLE_NN_DEPENDENCIES=OFF INSTALL_TARGET=install +SCCACHE_RECACHE=1 + TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" ENABLE_thrust_DEPENDENCY=ON From 85c6294d6f3adc10ef623332ee4ca7d6509afa42 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 26 Jan 2023 11:19:43 -0500 Subject: [PATCH 43/49] Breaking specializations for refine into individual files --- cpp/CMakeLists.txt | 9 +++- .../raft/linalg/detail/contractions.cuh | 2 +- .../knn/detail/epsilon_neighborhood.cuh | 2 +- cpp/src/distance/neighbors/refine.cu | 52 ------------------- .../neighbors/refine_d_uint64_t_float.cu | 33 ++++++++++++ .../neighbors/refine_d_uint64_t_int8_t.cu | 33 ++++++++++++ .../neighbors/refine_d_uint64_t_uint8_t.cu | 33 ++++++++++++ .../neighbors/refine_h_uint64_t_float.cu | 33 ++++++++++++ .../neighbors/refine_h_uint64_t_int8_t.cu | 33 ++++++++++++ .../neighbors/refine_h_uint64_t_uint8_t.cu | 33 ++++++++++++ 10 files changed, 207 insertions(+), 56 deletions(-) delete mode 100644 cpp/src/distance/neighbors/refine.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 
c6850b290f..a45c5b0cc8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -284,7 +284,12 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/cluster/update_centroids_double.cu src/distance/cluster/cluster_cost_float.cu src/distance/cluster/cluster_cost_double.cu - src/distance/neighbors/refine.cu + src/distance/neighbors/refine_d_uint64_t_float.cu + src/distance/neighbors/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/refine_h_uint64_t_float.cu + src/distance/neighbors/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/refine_h_uint64_t_uint8_t.cu src/distance/neighbors/ivfpq_search.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 6d7a8e2292..f2d71117f7 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index cd0e005921..7616083796 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/distance/neighbors/refine.cu b/cpp/src/distance/neighbors/refine.cu deleted file mode 100644 index 83e3383cba..0000000000 --- a/cpp/src/distance/neighbors/refine.cu +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -namespace raft::runtime::neighbors { - -#define RAFT_INST_REFINE(IDX_T, DATA_T) \ - void refine(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric) \ - { \ - raft::neighbors::detail::refine_device( \ - handle, dataset, queries, neighbor_candidates, indices, distances, metric); \ - } \ - \ - void refine(raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric) \ - { \ - raft::neighbors::detail::refine_host( \ - dataset, queries, neighbor_candidates, indices, distances, metric); \ - } - -RAFT_INST_REFINE(uint64_t, float); -RAFT_INST_REFINE(uint64_t, uint8_t); -RAFT_INST_REFINE(uint64_t, int8_t); - -#undef RAFT_INST_REFINE - -} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu new file mode 100644 index 0000000000..32819099d9 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu new file mode 100644 index 0000000000..ff45e74ba1 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..0a1590194b --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu new file mode 100644 index 0000000000..2d734ac5bf --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu new file mode 100644 index 0000000000..9749499298 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..dbc2100635 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors From 0fad8425730ef3d4a28b4d945952833dd0732809 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 26 Jan 2023 12:45:44 -0500 Subject: [PATCH 44/49] Checking in --- cpp/bench/neighbors/refine.cu | 11 ++-- .../raft/neighbors/specializations.cuh | 4 +- .../raft/neighbors/specializations/refine.cuh | 51 +++++++++++++++++++ .../neighbors/refine_d_uint64_t_float.cu | 1 + .../neighbors/refine_d_uint64_t_int8_t.cu | 1 + .../neighbors/refine_d_uint64_t_uint8_t.cu | 1 + .../neighbors/refine_h_uint64_t_float.cu | 1 + .../neighbors/refine_h_uint64_t_int8_t.cu | 2 +- .../neighbors/refine_h_uint64_t_uint8_t.cu | 1 + .../refine_d_uint64_t_float.cu | 30 +++++++++++ .../refine_d_uint64_t_int8_t.cu | 30 +++++++++++ .../refine_d_uint64_t_uint8_t.cu | 30 +++++++++++ .../refine_h_uint64_t_float.cu | 30 +++++++++++ .../refine_h_uint64_t_int8_t.cu | 29 +++++++++++ .../refine_h_uint64_t_uint8_t.cu | 30 +++++++++++ cpp/test/neighbors/refine.cu | 12 ++--- 16 files changed, 250 insertions(+), 14 deletions(-) create mode 100644 cpp/include/raft/neighbors/specializations/refine.cuh create mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index 3349b8b6ae..16b115cab4 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -27,6 +27,7 @@ #if defined RAFT_DISTANCE_COMPILED #include +#include #endif #if defined RAFT_NN_COMPILED @@ -52,7 +53,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::os return os; } -RefineInputs p; +RefineInputs p; template class RefineAnn : public 
fixture { @@ -113,9 +114,9 @@ std::vector> getInputs() return out; } -using refine_float_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); +using refine_float_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs()); -using refine_uint8_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); +using refine_uint8_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 0511bbbf6c..77c49b70e6 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ #include #include #include +#include #include - #endif diff --git a/cpp/include/raft/neighbors/specializations/refine.cuh b/cpp/include/raft/neighbors/specializations/refine.cuh new file mode 100644 index 0000000000..71e83a26f3 --- /dev/null +++ b/cpp/include/raft/neighbors/specializations/refine.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace raft::neighbors { + +#ifdef RAFT_INST +#undef RAFT_INST +#endif + +#define RAFT_INST(T, IdxT) \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric); \ + \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + distance::DistanceType metric); + +RAFT_INST(float, uint64_t); +RAFT_INST(uint8_t, uint64_t); +RAFT_INST(int8_t, uint64_t); + +#undef RAFT_INST +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu index 32819099d9..75fe526b07 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu index ff45e74ba1..aaf05ca3cb 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu index 0a1590194b..574ed7cf29 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git 
a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu index 2d734ac5bf..d03c082329 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu index 9749499298..01982ada95 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -15,9 +15,9 @@ */ #include +#include namespace raft::runtime::neighbors { - void refine(raft::device_resources const& handle, raft::host_matrix_view dataset, raft::host_matrix_view queries, diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu index dbc2100635..08a9ff410e 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu @@ -15,6 +15,7 @@ */ #include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu new file mode 100644 index 0000000000..6bb1985d94 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu new file mode 100644 index 0000000000..7e70ee5e29 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..53de106ef9 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu new file mode 100644 index 0000000000..b473924741 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu new file mode 100644 index 0000000000..c8b0e4c1c2 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..b9e0f58ef6 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 98933046b9..6f9e8210be 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -31,7 +31,7 @@ #include -#if defined RAFT_NN_COMPILED +#if defined RAFT_DISTANCE_COMPILED #include #endif @@ -107,8 +107,8 @@ class RefineTest : public ::testing::TestWithParam> { RefineHelper data; }; -const std::vector> inputs = - raft::util::itertools::product>( +const std::vector> inputs = + raft::util::itertools::product>( {137}, {1000}, {16}, @@ -117,16 +117,16 @@ const std::vector> inputs = {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false, true}); -typedef RefineTest RefineTestF; +typedef RefineTest RefineTestF; TEST_P(RefineTestF, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_uint8; +typedef RefineTest RefineTestF_uint8; TEST_P(RefineTestF_uint8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_uint8, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_int8; +typedef RefineTest RefineTestF_int8; TEST_P(RefineTestF_int8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_int8, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors From f7788af34ee8e17efd0e0c408b42b959b024c72b Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 26 Jan 2023 15:10:53 -0500 Subject: [PATCH 45/49] Including just the refine specialization --- cpp/include/raft/neighbors/specializations.cuh | 3 +-- cpp/src/distance/neighbors/refine_d_uint64_t_float.cu | 2 +- cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu | 2 +- cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu | 2 +- cpp/src/distance/neighbors/refine_h_uint64_t_float.cu | 2 +- cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu | 2 +- cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu | 2 +- 7 files changed, 7 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 77c49b70e6..d17467c8a7 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -20,9 +20,8 @@ #pragma once #include +#include #include #include #include - -#include #endif diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu index 75fe526b07..d7b460180a 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu index aaf05ca3cb..3db07f0cdb 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu index 574ed7cf29..2ce43d5800 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu @@ -15,7 +15,7 @@ */ #include -#include 
+#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu index d03c082329..2a2dcff3bf 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu index 01982ada95..d7c60b62a5 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { void refine(raft::device_resources const& handle, diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu index 08a9ff410e..e9c4345e97 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu @@ -15,7 +15,7 @@ */ #include -#include +#include namespace raft::runtime::neighbors { From 9e7b7298cf444ccbfec4df61d7b40b796c83227b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 26 Jan 2023 15:30:46 -0500 Subject: [PATCH 46/49] Proper import of specializations --- cpp/bench/neighbors/knn.cuh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index 60eb8c257d..eec1cba99e 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -32,6 +32,9 @@ #include #if defined RAFT_DISTANCE_COMPILED #include +#include +#else +#pragma message("NN / Distance specializations are not enabled; expect very long building times.") #endif #endif From 060e62cd9363d8d8831c6bd619a8a303442bb596 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 26 Jan 2023 15:40:24 -0500 Subject: [PATCH 47/49] Remove SCCACHE_RECACHE from build.sh --- build.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/build.sh b/build.sh index f34c032204..b47e1ed862 100755 --- a/build.sh +++ b/build.sh @@ -75,8 +75,6 @@ COMPILE_DIST_LIBRARY=OFF ENABLE_NN_DEPENDENCIES=OFF INSTALL_TARGET=install -SCCACHE_RECACHE=1 - TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;NEIGHBORS_TEST;STATS_TEST;UTILS_TEST" BENCH_TARGETS="CLUSTER_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH" ENABLE_thrust_DEPENDENCY=ON From 2b0c02b832e54f4d749595e21706bdbe6c018dc3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 26 Jan 2023 17:09:10 -0500 Subject: [PATCH 48/49] Small compilation error remains --- cpp/test/neighbors/refine.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 6f9e8210be..a78f5cfe5c 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -109,11 +109,11 @@ class RefineTest : public ::testing::TestWithParam> { const std::vector> inputs = raft::util::itertools::product>( - {137}, - {1000}, - {16}, - {1, 10, 33}, - {33}, + {static_cast(137)}, + {static_cast(1000)}, + {static_cast(16)}, + {static_cast(1), static_cast(10), static_cast(33)}, + {static_cast(33)}, {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false, true}); From 1e83640a1c12ce2671dc0fc5f0f732efd0d79510 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 27 Jan 2023 12:16:14 +0100 Subject: [PATCH 49/49] Take device_resources instead of handle --- cpp/include/raft/distance/detail/compress_to_bits.cuh | 2 +- cpp/include/raft/distance/detail/masked_nn.cuh | 2 +- cpp/include/raft/distance/masked_nn.cuh | 2 +- 3 files changed, 3 
insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh index f9c4fed5cf..e36b7ce707 100644 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ b/cpp/include/raft/distance/detail/compress_to_bits.cuh @@ -95,7 +95,7 @@ __global__ void compress_to_bits_kernel( * Note: the division (`/`) is a ceilDiv. */ template ::value>> -void compress_to_bits(const raft::handle_t& handle, +void compress_to_bits(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh index 5bbcbcf56e..1c92de16fc 100644 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -230,7 +230,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, * */ template -void maskedL2NNImpl(const raft::handle_t& handle, +void maskedL2NNImpl(raft::device_resources const& handle, OutT* out, const DataT* x, const DataT* y, diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh index 17c11630ab..ea2e10a304 100644 --- a/cpp/include/raft/distance/masked_nn.cuh +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -145,7 +145,7 @@ struct MaskedL2NNParams { * (on device) */ template -void maskedL2NN(const raft::handle_t& handle, +void maskedL2NN(raft::device_resources const& handle, raft::distance::MaskedL2NNParams params, raft::device_matrix_view x, raft::device_matrix_view y,