From 8d3e8a0ce255710c4154693a859b66fa722a4d36 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 2022 22:21:22 +0200 Subject: [PATCH 01/93] contractions: Concentrate tile index calculations The calculation of the tile indices are now performed in ldgXY(). This will make it possible to remove all state related to the tile index out of the class in the next commit. Note that the calculation of the tile index can depend on which overloaded constructor is called(!) --- .../detail/pairwise_distance_base.cuh | 27 ++---- .../raft/linalg/detail/contractions.cuh | 84 +++++++++++++------ .../knn/detail/epsilon_neighborhood.cuh | 10 +-- 3 files changed, 72 insertions(+), 49 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 445b4bac52..c401b90601 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -138,27 +138,14 @@ struct PairwiseDistances : public BaseClass { DI void updateIndicesY() { const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; + this->increment_grid_idx_n(stride); } DI void updateIndicesXY() { const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; + this->increment_grid_idx_m(stride); + this->reset_grid_idx_n(); } DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) @@ -187,7 +174,7 @@ struct PairwiseDistances : public BaseClass { this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() @@ -197,15 +184,15 @@ struct PairwiseDistances : public BaseClass { accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration // This is needed for making sure next grid stride of // non-norm based metrics uses previously accumulated buffer so // it doesn't make shmem dirty until previous iteration // is complete. - this->pageRd ^= 1; + this->switch_read_buffer(); } DI void accumulate() diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index e247f39bc7..346ec34771 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -40,14 +40,15 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; - /** current thread's global mem row id for X data */ - IdxT xrowid; - /** current thread's global mem row id for Y data */ - IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x; + const DataT* x_base; /** global memory pointer to Y matrix */ - const DataT* y; + const DataT* y_base; + + /** Support variables to provide backward compatibility **/ + IdxT grid_idx_m = 0; + IdxT grid_idx_n = 0; + bool first_constructor_called; /** current thread's smem row id */ int srowid; @@ -94,10 +95,8 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), - yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), - x(_x + xrowid * lda), - y(_y + yrowid * ldb), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -105,7 +104,8 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0) + pageRd(0), + first_constructor_called(true) { } @@ -133,6 +133,8 @@ struct Contractions_NT { lda(_lda), ldb(_ldb), ldd(_ldd), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -140,19 +142,9 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0) + pageRd(0), + first_constructor_called(false) { - if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; - x = _x + xrowid * lda; - y = _y + yrowid * ldb; - } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * P::Nblk; - x = _x + xrowid + srowid * lda; - y = _y + yrowid + srowid * ldb; - } } protected: @@ -166,6 +158,12 @@ struct Contractions_NT { ldgY(kidx); } + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) + { + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx); + } + /** * @brief Store current block of X/Y from registers to smem * @param[in] kidx current start index of k to be loaded @@ -186,9 +184,35 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } + DI void increment_grid_idx_m(IdxT by) { grid_idx_m += by; } + + DI void increment_grid_idx_n(IdxT by) { grid_idx_n += by; } + + DI void reset_grid_idx_n() { grid_idx_n = 0; } + + DI void switch_read_buffer() { this->pageRd ^= 1; } + + DI void switch_write_buffer() { this->pageWr ^= 1; } + private: DI void ldgX(IdxT kidx) { + // Backward compatible way to determine the tile index. This depends on + // whether the first or the second constructor was called. The first + // constructor is called in epsilon_neighborhood.cuh and the second + // constructor is called in pairwise_distance_base.cuh. + if (first_constructor_called) { + ldgX(IdxT(blockIdx.x) * P::Mblk, kidx); + } else { + ldgX(grid_idx_m + IdxT(blockIdx.y) * P::Mblk, kidx); + } + } + + DI void ldgX(IdxT tile_idx_m, IdxT kidx) + { + IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; + auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda; + if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -222,6 +246,18 @@ struct Contractions_NT { DI void ldgY(IdxT kidx) { + if (first_constructor_called) { + ldgY(IdxT(blockIdx.y) * P::Nblk, kidx); + } else { + ldgY(grid_idx_n + IdxT(blockIdx.x) * P::Nblk, kidx); + } + } + + DI void ldgY(IdxT tile_idx_n, IdxT kidx) + { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + if (isRowMajor) { auto numRows = n; auto koffset = kidx + scolid; @@ -315,4 +351,4 @@ struct Contractions_NT { } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index e4843acee9..7616083796 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { private: DI void prolog() { - this->ldgXY(0); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() { for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx); accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration } From cb7baab5e501ae16841c9a82fc2c6af8d99521db Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 2022 22:51:49 +0200 Subject: [PATCH 02/93] pairwise_distance_base: Remove all ldgXY(0) calls This commit moves all grid and tile indexing logic into the caller. Contractions_NT is now only responsible for *intra*-tile indexing. Due to the complexity of the epilog function, the ldgNextGridStride function is not yet called from within the main loop. That is the next goal so that we have all the grid and tile indexing localized in the loop. --- .../detail/pairwise_distance_base.cuh | 121 ++++++++++-------- .../raft/linalg/detail/contractions.cuh | 45 +------ 2 files changed, 67 insertions(+), 99 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index c401b90601..15bf334ffb 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -87,6 +87,12 @@ struct PairwiseDistances : public BaseClass { FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; + + const IdxT grid_stride_m; + const IdxT grid_stride_n; + const IdxT grid_offset_m; + const IdxT grid_offset_n; + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; public: @@ -116,53 +122,63 @@ struct PairwiseDistances : public BaseClass { core_op(_core_op), epilog_op(_epilog_op), fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op) + rowEpilog_op(_rowEpilog_op), + grid_stride_m(P::Nblk * gridDim.y), + grid_stride_n(P::Mblk * gridDim.x), + grid_offset_m(P::Mblk * blockIdx.y), + grid_offset_n(P::Nblk * blockIdx.x) { } DI void run() { - for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; - gridStrideY += P::Mblk * gridDim.y) { - for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { - prolog(gridStrideX, gridStrideY); - loop(); - epilog(gridStrideX, gridStrideY); + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + this->ldgXY(tile_idx_m, grid_offset_n, 0); + for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + accumulate(); // last iteration + // This is needed for making sure next grid stride of + // non-norm based metrics uses previously accumulated buffer so + // it doesn't make shmem dirty until previous iteration + // is complete. + this->switch_read_buffer(); + + epilog(tile_idx_n, tile_idx_m); } - rowEpilog_op(gridStrideY); + rowEpilog_op(tile_idx_m); } } private: - DI void updateIndicesY() - { - const auto stride = P::Nblk * gridDim.x; - this->increment_grid_idx_n(stride); - } - - DI void updateIndicesXY() - { - const auto stride = P::Mblk * gridDim.y; - this->increment_grid_idx_m(stride); - this->reset_grid_idx_n(); - } - - DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) + DI void ldgNextGridStride(IdxT tile_idx_n, IdxT tile_idx_m) { // Fetch next grid stride ldg if within range - if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { - updateIndicesY(); - this->ldgXY(0); - } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { - updateIndicesXY(); - this->ldgXY(0); + const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; + const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; + if ((next_tile_tile_idx_n) < this->n) { + this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); + } else if ((next_tile_tile_idx_m) < this->m) { + this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); } } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) + DI void prolog(IdxT tile_idx_n, IdxT tile_idx_m) { - if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); } + if (tile_idx_n == blockIdx.x * P::Nblk) { this->ldgXY(0); } #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -177,22 +193,15 @@ struct PairwiseDistances : public BaseClass { this->switch_write_buffer(); } - DI void loop() - { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); + DI void reset_accumulator() { + // Reset accumulator registers to zero. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } } - accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. - this->switch_read_buffer(); } DI void accumulate() @@ -213,22 +222,22 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT gridStrideX, IdxT gridStrideY) + DI void epilog(IdxT tile_idx_n, IdxT tile_idx_m) { if (useNorms) { DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); DataT* syNorm = (&sxNorm[P::Mblk]); // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * P::Nblk) { + if (tile_idx_n == blockIdx.x * P::Nblk) { for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; + auto idx = tile_idx_m + i; sxNorm[i] = idx < this->m ? xn[idx] : 0; } } for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = gridStrideX + i; + auto idx = tile_idx_n + i; syNorm[i] = idx < this->n ? yn[idx] : 0; } @@ -245,17 +254,17 @@ struct PairwiseDistances : public BaseClass { } // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); + ldgNextGridStride(tile_idx_n, tile_idx_m); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + ldgNextGridStride(tile_idx_n, tile_idx_m); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { - IdxT starty = gridStrideY + this->accrowid; - IdxT startx = gridStrideX + this->acccolid; + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 346ec34771..f2d71117f7 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -45,11 +45,6 @@ struct Contractions_NT { /** global memory pointer to Y matrix */ const DataT* y_base; - /** Support variables to provide backward compatibility **/ - IdxT grid_idx_m = 0; - IdxT grid_idx_n = 0; - bool first_constructor_called; - /** current thread's smem row id */ int srowid; /** current thread's smem column id */ @@ -104,8 +99,7 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0), - first_constructor_called(true) + pageRd(0) { } @@ -142,8 +136,7 @@ struct Contractions_NT { sx((DataT*)_smem), sy(&(sx[P::SmemPageX])), pageWr(0), - pageRd(0), - first_constructor_called(false) + pageRd(0) { } @@ -152,12 +145,6 @@ struct Contractions_NT { * @brief Load current block of X/Y from global memory to registers * @param[in] kidx current start index of k to be loaded */ - DI void ldgXY(IdxT kidx) - { - ldgX(kidx); - ldgY(kidx); - } - DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) { ldgX(tile_idx_m, kidx); @@ -184,30 +171,11 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } - DI void increment_grid_idx_m(IdxT by) { grid_idx_m += by; } - - DI void increment_grid_idx_n(IdxT by) { grid_idx_n += by; } - - DI void reset_grid_idx_n() { grid_idx_n = 0; } - DI void switch_read_buffer() { this->pageRd ^= 1; } DI void switch_write_buffer() { this->pageWr ^= 1; } private: - DI void ldgX(IdxT kidx) - { - // Backward compatible way to determine the tile index. This depends on - // whether the first or the second constructor was called. The first - // constructor is called in epsilon_neighborhood.cuh and the second - // constructor is called in pairwise_distance_base.cuh. - if (first_constructor_called) { - ldgX(IdxT(blockIdx.x) * P::Mblk, kidx); - } else { - ldgX(grid_idx_m + IdxT(blockIdx.y) * P::Mblk, kidx); - } - } - DI void ldgX(IdxT tile_idx_m, IdxT kidx) { IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; @@ -244,15 +212,6 @@ struct Contractions_NT { } } - DI void ldgY(IdxT kidx) - { - if (first_constructor_called) { - ldgY(IdxT(blockIdx.y) * P::Nblk, kidx); - } else { - ldgY(grid_idx_n + IdxT(blockIdx.x) * P::Nblk, kidx); - } - } - DI void ldgY(IdxT tile_idx_n, IdxT kidx) { IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; From 066bf3b22d90412e280e77492e42160e0fa011ce Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 2 Sep 2022 23:40:32 +0200 Subject: [PATCH 03/93] pairwise_distance_base: Move all logic into run loop This commit removes the epilog function and moves its functionality into the run loop. The next step might be to see if the ldgNextGridStride() method has to be called the current location, or if performance is the same if its called at the start of the loop. --- .../detail/pairwise_distance_base.cuh | 128 ++++++++---------- 1 file changed, 57 insertions(+), 71 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 15bf334ffb..78effeca6d 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -87,7 +87,6 @@ struct PairwiseDistances : public BaseClass { FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; - const IdxT grid_stride_m; const IdxT grid_stride_n; const IdxT grid_offset_m; @@ -141,14 +140,14 @@ struct PairwiseDistances : public BaseClass { this->switch_write_buffer(); for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - accumulate(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration // This is needed for making sure next grid stride of @@ -157,14 +156,25 @@ struct PairwiseDistances : public BaseClass { // is complete. this->switch_read_buffer(); - epilog(tile_idx_n, tile_idx_m); + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, regxn, regyn); + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } + if (writeOut) { store_output(tile_idx_m, tile_idx_n); } } rowEpilog_op(tile_idx_m); } } private: - DI void ldgNextGridStride(IdxT tile_idx_n, IdxT tile_idx_m) + DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) { // Fetch next grid stride ldg if within range const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; @@ -176,24 +186,8 @@ struct PairwiseDistances : public BaseClass { } } - DI void prolog(IdxT tile_idx_n, IdxT tile_idx_m) + DI void reset_accumulator() { - if (tile_idx_n == blockIdx.x * P::Nblk) { this->ldgXY(0); } - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; - } - } - - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - } - - DI void reset_accumulator() { // Reset accumulator registers to zero. #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -222,60 +216,52 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT tile_idx_n, IdxT tile_idx_m) + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) { - if (useNorms) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (tile_idx_n == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (tile_idx_n == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; } + } - __syncthreads(); + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < this->n ? yn[idx] : 0; + } + __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_n, tile_idx_m); - epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_n, tile_idx_m); - epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + } - if (writeOut) { - IdxT starty = tile_idx_m + this->accrowid; - IdxT startx = tile_idx_n + this->acccolid; + DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) + { + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); } } } From a15d5fc1ecad74f04190ee351436b79db3127b8b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 5 Oct 2022 16:17:56 +0200 Subject: [PATCH 04/93] pairwise_distance_base: Fix typo This results in subtle issues with non-square KernelPolicy, as found in fusedL2KNN. --- cpp/include/raft/distance/detail/pairwise_distance_base.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 78effeca6d..5da3b6f8c1 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -122,8 +122,8 @@ struct PairwiseDistances : public BaseClass { epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), - grid_stride_m(P::Nblk * gridDim.y), - grid_stride_n(P::Mblk * gridDim.x), + grid_stride_m(P::Mblk * gridDim.y), + grid_stride_n(P::Nblk * gridDim.x), grid_offset_m(P::Mblk * blockIdx.y), grid_offset_n(P::Nblk * blockIdx.x) { From 71c6da65ca544fe29dc0ab7f10e90b0cc72dca09 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 11 Jan 2023 16:01:18 +0100 Subject: [PATCH 05/93] Remove deprecated header --- cpp/include/raft/distance/distance.hpp | 23 -------------- cpp/include/raft/distance/distance_type.hpp | 27 ---------------- cpp/include/raft/distance/fused_l2_nn.hpp | 31 ------------------- cpp/include/raft/distance/specializations.hpp | 31 ------------------- 4 files changed, 112 deletions(-) delete mode 100644 cpp/include/raft/distance/distance.hpp delete mode 100644 cpp/include/raft/distance/distance_type.hpp delete mode 100644 cpp/include/raft/distance/fused_l2_nn.hpp delete mode 100644 cpp/include/raft/distance/specializations.hpp diff --git a/cpp/include/raft/distance/distance.hpp b/cpp/include/raft/distance/distance.hpp deleted file mode 100644 index e5d39be86b..0000000000 --- a/cpp/include/raft/distance/distance.hpp +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * This file is deprecated and will be removed in release 22.06. - * Please use the cuh version instead. - */ - -#pragma once - -#include \ No newline at end of file diff --git a/cpp/include/raft/distance/distance_type.hpp b/cpp/include/raft/distance/distance_type.hpp deleted file mode 100644 index f6eb4614f9..0000000000 --- a/cpp/include/raft/distance/distance_type.hpp +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * This file is deprecated and will be removed at some point in a future release. - * Please use `raft/distance/distance_types.hpp` instead. - */ - -#pragma once - -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use distance_types.hpp instead.") - -#include \ No newline at end of file diff --git a/cpp/include/raft/distance/fused_l2_nn.hpp b/cpp/include/raft/distance/fused_l2_nn.hpp deleted file mode 100644 index 74ad0974f4..0000000000 --- a/cpp/include/raft/distance/fused_l2_nn.hpp +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * This file is deprecated and will be removed in release 22.06. - * Please use the cuh version instead. - */ - -/** - * DISCLAIMER: this file is deprecated: use fused_l2_nn.cuh instead - */ - -#pragma once - -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use the cuh version instead.") - -#include "fused_l2_nn.cuh" diff --git a/cpp/include/raft/distance/specializations.hpp b/cpp/include/raft/distance/specializations.hpp deleted file mode 100644 index 04afb73036..0000000000 --- a/cpp/include/raft/distance/specializations.hpp +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * This file is deprecated and will be removed in release 22.06. - * Please use the cuh version instead. - */ - -/** - * DISCLAIMER: this file is deprecated: use specializations.cuh instead - */ - -#pragma once - -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use the cuh version instead.") - -#include "specializations.cuh" From 4bbedf660d8e1f2c01ab2bd8066a96cc7bd40307 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 12 Jan 2023 13:50:51 +0100 Subject: [PATCH 06/93] Replace lambdas by raft::void_op --- cpp/include/raft/distance/detail/pairwise_distance_base.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 5da3b6f8c1..140664f394 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -328,7 +328,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) FinalLambda fin_op) { extern __shared__ char smem[]; - auto rowEpilog = [] __device__(IdxT starty) { return; }; + auto rowEpilog = raft::void_op(); PairwiseDistances Date: Thu, 12 Jan 2023 17:14:05 +0100 Subject: [PATCH 07/93] Use an operator for L1 distance --- .../distance/detail/distance_operators.cuh | 51 ++++++++ cpp/include/raft/distance/detail/l1.cuh | 48 ++----- .../distance/detail/pairwise_distance_op.cuh | 118 ++++++++++++++++++ 3 files changed, 180 insertions(+), 37 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_operators.cuh create mode 100644 cpp/include/raft/distance/detail/pairwise_distance_op.cuh diff --git a/cpp/include/raft/distance/detail/distance_operators.cuh b/cpp/include/raft/distance/detail/distance_operators.cuh new file mode 100644 index 0000000000..4abaeaaf8b --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_operators.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft::distance::detail { + + +// Describes the computation the l1 distance +struct l1_distance_op { + // Whether norms of data should be loaded. + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() { + return Policy::SmemSize; + } + + template + DI void core(AccT & acc, DataT & x, DataT & y) const { + acc += raft::abs(x - y); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) const { + return; + }; + +}; + +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index bf10651b60..8eee0ae220 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -16,11 +16,15 @@ #pragma once #include +#include +#include namespace raft { namespace distance { namespace detail { + + /** * @brief the L1 distance matrix calculation implementer * It computes the following equation: cij = op(ai-bj) @@ -69,45 +73,15 @@ static void l1Impl(const DataT* x, dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::abs(x - y); - acc += diff; - }; + l1_distance_op distance_op{}; - // epilogue operation lambda for final value calculation - auto epilog_lambda = raft::void_op(); + using PCT = params_CT; - if (isRowMajor) { - auto l1RowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, l1RowMajor); - - l1RowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto l1ColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, l1ColMajor); - l1ColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } + auto kernel = pairwiseDistanceOpKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, kernel); + + kernel<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_op.cuh b/cpp/include/raft/distance/detail/pairwise_distance_op.cuh new file mode 100644 index 0000000000..91c66a2217 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_distance_op.cuh @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::distance::detail { + + + +template +struct params_CT { + using DataT = data_type; + using AccT = accumulate_type; + using OutT = out_type; + using IdxT = index_type; + + using PolicyT = policy; + + using opT = op_type; + using FinOpT = final_op_type; + static constexpr bool is_row_major = row_major; +}; + +template +__global__ __launch_bounds__(PCT::PolicyT::Nthreads, 2) + + void pairwiseDistanceOpKernel( + const typename PCT::DataT* x, + const typename PCT::DataT* y, + const typename PCT::DataT* _xn, + const typename PCT::DataT* _yn, + typename PCT::IdxT m, + typename PCT::IdxT n, + typename PCT::IdxT k, + typename PCT::IdxT lda, + typename PCT::IdxT ldb, + typename PCT::IdxT ldd, + typename PCT::OutT* dOutput, + typename PCT::opT distance_op, + typename PCT::FinOpT fin_op) +{ + using AccT = typename PCT::AccT; + using DataT = typename PCT::DataT; + using OutT = typename PCT::OutT; + using IdxT = typename PCT::IdxT; + + using Policy = typename PCT::PolicyT; + + // Instantiate PCT to access constexpr members. + PCT compile_time_params{}; + + extern __shared__ char smem[]; + + // Wrap operator back into lambdas. This is temporary and should be removed. (TODO) + auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { + // use .template to disambiguate (See: https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template core(acc, x, y); + }; + auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); + }; + + // No support for row_epilog_op. + auto row_epilog_op = raft::void_op(); + // Always write output + constexpr bool write_out = true; + constexpr bool use_norms = distance_op.use_norms; + PairwiseDistances + obj( + x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, epilog_op, fin_op, row_epilog_op); + obj.run(); + +} +}; // namespace detail From 3e3478b05c2afb145a4dc5f22318b31d9c868848 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 12 Jan 2023 18:25:56 +0100 Subject: [PATCH 08/93] Add launch function This is more general than just for L1. Making use of it more is work in progress. --- cpp/include/raft/distance/detail/l1.cuh | 45 +++++++++++++++++++------ 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 8eee0ae220..2ad6895b27 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -24,6 +24,35 @@ namespace distance { namespace detail { +template +static void distance_matrix_launch( + typename PCT::opT distance_op, + typename PCT::FinOpT fin_op, + const typename PCT::DataT* x, + const typename PCT::DataT* y, + const typename PCT::DataT* _xn, + const typename PCT::DataT* _yn, + typename PCT::IdxT m, + typename PCT::IdxT n, + typename PCT::IdxT k, + typename PCT::IdxT lda, + typename PCT::IdxT ldb, + typename PCT::IdxT ldd, + typename PCT::OutT* dOutput, + cudaStream_t stream) +{ + using Policy = typename PCT::PolicyT; + + dim3 blk(Policy::Nthreads); + size_t smem_size = distance_op.template shared_mem_size(); + dim3 grid = launchConfigGenerator(m, n, smem_size, pairwiseDistanceOpKernel); + + pairwiseDistanceOpKernel<<>>( + x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); + + RAFT_CUDA_TRY(cudaGetLastError()); + +} /** * @brief the L1 distance matrix calculation implementer @@ -68,22 +97,18 @@ static void l1Impl(const DataT* x, { typedef typename raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type KPolicy; - dim3 blk(KPolicy::Nthreads); - l1_distance_op distance_op{}; using PCT = params_CT; - auto kernel = pairwiseDistanceOpKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, kernel); - - kernel<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); - - RAFT_CUDA_TRY(cudaGetLastError()); + distance_matrix_launch( + distance_op, fin_op, // Operations + x, y, nullptr, nullptr, // Input data + m, n, k, lda, ldb, ldd, // Dimensions + dOutput, // Output data + stream); // CUDA stream } template Date: Fri, 13 Jan 2023 11:58:43 +0100 Subject: [PATCH 09/93] l1: Replace run-time -> compile-time dispatch --- cpp/include/raft/distance/detail/l1.cuh | 246 ++++++++---------- .../distance/detail/pairwise_distance_op.cuh | 163 +++++++++--- 2 files changed, 235 insertions(+), 174 deletions(-) diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 2ad6895b27..7645421fbf 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -15,182 +15,156 @@ */ #pragma once +#include #include #include -#include namespace raft { namespace distance { namespace detail { - template -static void distance_matrix_launch( - typename PCT::opT distance_op, - typename PCT::FinOpT fin_op, - const typename PCT::DataT* x, - const typename PCT::DataT* y, - const typename PCT::DataT* _xn, - const typename PCT::DataT* _yn, - typename PCT::IdxT m, - typename PCT::IdxT n, - typename PCT::IdxT k, - typename PCT::IdxT lda, - typename PCT::IdxT ldb, - typename PCT::IdxT ldd, - typename PCT::OutT* dOutput, - cudaStream_t stream) +static void distance_matrix_launch(typename PCT::opT distance_op, + typename PCT::FinOpT fin_op, + const typename PCT::DataT* x, + const typename PCT::DataT* y, + const typename PCT::DataT* _xn, + const typename PCT::DataT* _yn, + typename PCT::IdxT m, + typename PCT::IdxT n, + typename PCT::IdxT k, + typename PCT::IdxT lda, + typename PCT::IdxT ldb, + typename PCT::IdxT ldd, + typename PCT::OutT* dOutput, + cudaStream_t stream) { using Policy = typename PCT::PolicyT; dim3 blk(Policy::Nthreads); size_t smem_size = distance_op.template shared_mem_size(); - dim3 grid = launchConfigGenerator(m, n, smem_size, pairwiseDistanceOpKernel); + dim3 grid = launchConfigGenerator(m, n, smem_size, pairwiseDistanceOpKernel); pairwiseDistanceOpKernel<<>>( x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); - } -/** - * @brief the L1 distance matrix calculation implementer - * It computes the following equation: cij = op(ai-bj) - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix - * @param fin_op the final gemm epilogue lambda - */ -template -static void l1Impl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) +// Determine the largest number of elements that can be loaded in one +// instruction without causing misalignment errors. +template +int max_aligned_load(const DataT* x, const DataT* y, int ldx, int ldy) { - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type KPolicy; + auto base_x = reinterpret_cast(x); + auto base_y = reinterpret_cast(y); + size_t stride_X = sizeof(DataT) * ldx; // stride in bytes + size_t stride_Y = sizeof(DataT) * ldy; // stride in bytes - l1_distance_op distance_op{}; + bool base_16B_aligned = base_x % 16 == 0 && base_y % 16 == 0; + bool base_8B_aligned = base_x % 8 == 0 && base_y % 8 == 0; - using PCT = params_CT; + bool stride_16B_aligned = stride_X % 16 == 0 && stride_Y % 16 == 0; + bool stride_8B_aligned = stride_X % 8 == 0 && stride_Y % 8 == 0; - distance_matrix_launch( - distance_op, fin_op, // Operations - x, y, nullptr, nullptr, // Input data - m, n, k, lda, ldb, ldd, // Dimensions - dOutput, // Output data - stream); // CUDA stream + if (16 % sizeof(DataT) == 0 && base_16B_aligned && stride_16B_aligned) { + return 16 / sizeof(DataT); + } else if (8 % sizeof(DataT) == 0 && base_8B_aligned && stride_8B_aligned) { + return 8 / sizeof(DataT); + } else { + return 1; + } } -template -void l1(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) + typename FinOpT, + typename IdxT = int> +void distance_matrix_dispatch(opT distance_op, + int m_, + int n_, + int k_, + const DataT* x_, + const DataT* y_, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) { - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - l1Impl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - l1Impl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + // Determine leading dimensions and possibly flip order of passing x and y if + // column_major. + // + // ldx, ldy, and ld_out are the leading dimensions of x, y, and out + const DataT* x; + const DataT* y; + int ldx, ldy, ld_out; + int m, n, k; + if (is_row_major) { + // Pass x, y, m, n, k in order + x = x_, y = y_; + m = m_, n = n_, k = k_; + ldx = k_, ldy = k_, ld_out = n_; } else { - l1Impl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + // Flip x, y, and m, n, k. + x = y_, y = x_; + m = n_, n = m_, k = k_; + ldx = n_, ldy = m_, ld_out = m_; } + + int vectorized_load_num_elem = max_aligned_load(x, y, ldx, ldy); + + // We dispatch based on + // - vectorized_load_num_elem + // - is_row_major + + // Create run-time parameter struct that does the dispatching + using PRT = params_RT; + PRT run_time_params{vectorized_load_num_elem, is_row_major}; + + // Turn run-time parameters into compile-time parameters. + bool dispatch_success = run_time_params.dispatch_with_compile_time_params( + // We pass a lambda that receives the compile-time parameters and can use these + // to call the correct kernel. + [&](auto compile_time_params) { + // compile_time_params is an empty struct that we can convert back to a type + // using decltype. + return distance_matrix_launch( + distance_op, + fin_op, + x, + y, + nullptr, + nullptr, // TODO: use _xn, _yn for non-l1 distances + m, + n, + k, + ldx, + ldy, + ld_out, + out, + stream); + }); } -/** - * @brief the L1 distance matrix calculation - * It computes the following equation: cij = op(ai-bj) - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template +template void l1Impl(int m, int n, int k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, cudaStream_t stream, - bool isRowMajor) + bool is_row_major) { - typedef std::is_same is_bool; - typedef typename std::conditional::type L1OutType; - Index_ lda, ldb, ldd; - L1OutType* pDcast = reinterpret_cast(pD); - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - l1( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + l1_distance_op distance_op{}; - } else { - lda = n, ldb = m, ldd = m; - l1( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); - } + distance_matrix_dispatch( + distance_op, m, n, k, x, y, out, fin_op, stream, is_row_major); } + } // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/pairwise_distance_op.cuh b/cpp/include/raft/distance/detail/pairwise_distance_op.cuh index 91c66a2217..2b776a378f 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_op.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,8 +25,6 @@ namespace raft::distance::detail { - - template struct params_CT { using DataT = data_type; - using AccT = accumulate_type; - using OutT = out_type; - using IdxT = index_type; - + using AccT = accumulate_type; + using OutT = out_type; + using IdxT = index_type; using PolicyT = policy; - - using opT = op_type; - using FinOpT = final_op_type; + using opT = op_type; + using FinOpT = final_op_type; static constexpr bool is_row_major = row_major; }; +template +struct params_RT { + int vectorized_load_num_elem = 1; + bool row_major = true; + + // Turn run-time parameters into compile-time parameters. + // Call the provided function f with these compile-time parameters. + // Returns false if dispatch fails, i.e., if there is no implementation + // for the given runtime parameters. + template + bool dispatch_with_compile_time_params(F&& f) const + { + return convert_vectorized_load_num_elem(f); + } + + // Step 1: convert alignment into a compile time constant + template + bool convert_vectorized_load_num_elem(F&& f) const + { + bool fail = false; + switch (vectorized_load_num_elem) { + case 1: return layout<1>(f); + case 2: return layout<2>(f); + case 4: + // We need "if constexpr" here, to prevent the if else to be delegated + // to run time (in which case a kernel that loads 4 doubles is + // generated). This is especially important, because that leads to + // compilation errors (which we want to avoid). + if constexpr (sizeof(DataT) < 8) { + return layout<4>(f); + } else { + // For doubles, load at most 2 elements in one instruction. + return layout<2>(f); + } + default: return fail; + }; + } + + // Step 2: convert layout into a compile time constant + template + bool layout(F&& f) const + { + if (row_major) { + return to_compile_time_params(f); + } else { + return to_compile_time_params(f); + } + } + + // Step 3: convert compile-time constant into compile-time parameter struct and invoke + // function f with these compile time parameters. + template + bool to_compile_time_params(F&& f) const + { + // Determine kernel policy using vec_len and layout + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename std::conditional::type Policy; + + // Create compile-time parameter type and instantiate a struct; + using PCT = params_CT; + PCT compile_time_params{}; + + // Dispatch to f + f(compile_time_params); + + bool dispatch_success = true; + return dispatch_success; + } +}; + template __global__ __launch_bounds__(PCT::PolicyT::Nthreads, 2) - void pairwiseDistanceOpKernel( - const typename PCT::DataT* x, - const typename PCT::DataT* y, - const typename PCT::DataT* _xn, - const typename PCT::DataT* _yn, - typename PCT::IdxT m, - typename PCT::IdxT n, - typename PCT::IdxT k, - typename PCT::IdxT lda, - typename PCT::IdxT ldb, - typename PCT::IdxT ldd, - typename PCT::OutT* dOutput, - typename PCT::opT distance_op, - typename PCT::FinOpT fin_op) + void pairwiseDistanceOpKernel(const typename PCT::DataT* x, + const typename PCT::DataT* y, + const typename PCT::DataT* _xn, + const typename PCT::DataT* _yn, + typename PCT::IdxT m, + typename PCT::IdxT n, + typename PCT::IdxT k, + typename PCT::IdxT lda, + typename PCT::IdxT ldb, + typename PCT::IdxT ldd, + typename PCT::OutT* dOutput, + typename PCT::opT distance_op, + typename PCT::FinOpT fin_op) { - using AccT = typename PCT::AccT; + using AccT = typename PCT::AccT; using DataT = typename PCT::DataT; - using OutT = typename PCT::OutT; - using IdxT = typename PCT::IdxT; + using OutT = typename PCT::OutT; + using IdxT = typename PCT::IdxT; using Policy = typename PCT::PolicyT; - // Instantiate PCT to access constexpr members. + // Instantiate compile time parameters to access constexpr members. PCT compile_time_params{}; extern __shared__ char smem[]; // Wrap operator back into lambdas. This is temporary and should be removed. (TODO) auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - // use .template to disambiguate (See: https://en.cppreference.com/w/cpp/language/dependent_name) + // use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) distance_op.template core(acc, x, y); }; auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + distance_op.template epilog( + acc, regxn, regyn, gridStrideX, gridStrideY); }; // No support for row_epilog_op. @@ -110,9 +183,23 @@ __global__ __launch_bounds__(PCT::PolicyT::Nthreads, 2) decltype(row_epilog_op), compile_time_params.is_row_major, write_out> - obj( - x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, epilog_op, fin_op, row_epilog_op); + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + _xn, + _yn, + dOutput, + smem, + core_op, + epilog_op, + fin_op, + row_epilog_op); obj.run(); - } -}; // namespace detail + +}; // namespace raft::distance::detail From b23205707497252512d8007d6c434c096977b89e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 13:18:34 +0100 Subject: [PATCH 10/93] pairwise matrix: move files into subdirectories --- .../l1.cuh} | 24 +-- cpp/include/raft/distance/detail/l1.cuh | 134 +------------ .../dispatch.cuh} | 186 ++++++++++-------- .../detail/pairwise_matrix/kernel_sm60.cuh | 134 +++++++++++++ 4 files changed, 251 insertions(+), 227 deletions(-) rename cpp/include/raft/distance/detail/{distance_operators.cuh => distance_ops/l1.cuh} (73%) rename cpp/include/raft/distance/detail/{pairwise_distance_op.cuh => pairwise_matrix/dispatch.cuh} (52%) create mode 100644 cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh diff --git a/cpp/include/raft/distance/detail/distance_operators.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh similarity index 73% rename from cpp/include/raft/distance/detail/distance_operators.cuh rename to cpp/include/raft/distance/detail/distance_ops/l1.cuh index 4abaeaaf8b..08ca313fe2 100644 --- a/cpp/include/raft/distance/detail/distance_operators.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,37 +15,37 @@ */ #pragma once -#include - -namespace raft::distance::detail { +namespace raft::distance::detail::ops { // Describes the computation the l1 distance struct l1_distance_op { - // Whether norms of data should be loaded. + // Do not load norms of data, the computation of L1 distance does not use them. static constexpr bool use_norms = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() { + constexpr size_t shared_mem_size() + { return Policy::SmemSize; } template - DI void core(AccT & acc, DataT & x, DataT & y) const { + DI void core(AccT& acc, DataT& x, DataT& y) const + { acc += raft::abs(x - y); }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, + DataT* regxn, + DataT* regyn, IdxT gridStrideX, - IdxT gridStrideY) const { + IdxT gridStrideY) const + { return; }; - }; -} // namespace raft::distance::detail +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 7645421fbf..a5f279d9a4 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -15,139 +15,13 @@ */ #pragma once -#include -#include -#include +#include "distance_ops/l1.cuh" +#include "pairwise_matrix/dispatch.cuh" namespace raft { namespace distance { namespace detail { -template -static void distance_matrix_launch(typename PCT::opT distance_op, - typename PCT::FinOpT fin_op, - const typename PCT::DataT* x, - const typename PCT::DataT* y, - const typename PCT::DataT* _xn, - const typename PCT::DataT* _yn, - typename PCT::IdxT m, - typename PCT::IdxT n, - typename PCT::IdxT k, - typename PCT::IdxT lda, - typename PCT::IdxT ldb, - typename PCT::IdxT ldd, - typename PCT::OutT* dOutput, - cudaStream_t stream) -{ - using Policy = typename PCT::PolicyT; - - dim3 blk(Policy::Nthreads); - size_t smem_size = distance_op.template shared_mem_size(); - dim3 grid = launchConfigGenerator(m, n, smem_size, pairwiseDistanceOpKernel); - - pairwiseDistanceOpKernel<<>>( - x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -// Determine the largest number of elements that can be loaded in one -// instruction without causing misalignment errors. -template -int max_aligned_load(const DataT* x, const DataT* y, int ldx, int ldy) -{ - auto base_x = reinterpret_cast(x); - auto base_y = reinterpret_cast(y); - size_t stride_X = sizeof(DataT) * ldx; // stride in bytes - size_t stride_Y = sizeof(DataT) * ldy; // stride in bytes - - bool base_16B_aligned = base_x % 16 == 0 && base_y % 16 == 0; - bool base_8B_aligned = base_x % 8 == 0 && base_y % 8 == 0; - - bool stride_16B_aligned = stride_X % 16 == 0 && stride_Y % 16 == 0; - bool stride_8B_aligned = stride_X % 8 == 0 && stride_Y % 8 == 0; - - if (16 % sizeof(DataT) == 0 && base_16B_aligned && stride_16B_aligned) { - return 16 / sizeof(DataT); - } else if (8 % sizeof(DataT) == 0 && base_8B_aligned && stride_8B_aligned) { - return 8 / sizeof(DataT); - } else { - return 1; - } -} - -template -void distance_matrix_dispatch(opT distance_op, - int m_, - int n_, - int k_, - const DataT* x_, - const DataT* y_, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Determine leading dimensions and possibly flip order of passing x and y if - // column_major. - // - // ldx, ldy, and ld_out are the leading dimensions of x, y, and out - const DataT* x; - const DataT* y; - int ldx, ldy, ld_out; - int m, n, k; - if (is_row_major) { - // Pass x, y, m, n, k in order - x = x_, y = y_; - m = m_, n = n_, k = k_; - ldx = k_, ldy = k_, ld_out = n_; - } else { - // Flip x, y, and m, n, k. - x = y_, y = x_; - m = n_, n = m_, k = k_; - ldx = n_, ldy = m_, ld_out = m_; - } - - int vectorized_load_num_elem = max_aligned_load(x, y, ldx, ldy); - - // We dispatch based on - // - vectorized_load_num_elem - // - is_row_major - - // Create run-time parameter struct that does the dispatching - using PRT = params_RT; - PRT run_time_params{vectorized_load_num_elem, is_row_major}; - - // Turn run-time parameters into compile-time parameters. - bool dispatch_success = run_time_params.dispatch_with_compile_time_params( - // We pass a lambda that receives the compile-time parameters and can use these - // to call the correct kernel. - [&](auto compile_time_params) { - // compile_time_params is an empty struct that we can convert back to a type - // using decltype. - return distance_matrix_launch( - distance_op, - fin_op, - x, - y, - nullptr, - nullptr, // TODO: use _xn, _yn for non-l1 distances - m, - n, - k, - ldx, - ldy, - ld_out, - out, - stream); - }); -} - template void l1Impl(int m, int n, @@ -159,9 +33,9 @@ void l1Impl(int m, cudaStream_t stream, bool is_row_major) { - l1_distance_op distance_op{}; + ops::l1_distance_op distance_op{}; - distance_matrix_dispatch( + distance_matrix_dispatch( distance_op, m, n, k, x, y, out, fin_op, stream, is_row_major); } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_op.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh similarity index 52% rename from cpp/include/raft/distance/detail/pairwise_distance_op.cuh rename to cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 2b776a378f..d2c8dfe660 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_op.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -14,14 +14,9 @@ * limitations under the License. */ #pragma once -#include -#include -#include -#include -#include -#include -#include +#include +#include "kernel_sm60.cuh" namespace raft::distance::detail { @@ -36,11 +31,11 @@ template struct params_CT { - using DataT = data_type; - using AccT = accumulate_type; - using OutT = out_type; - using IdxT = index_type; - using PolicyT = policy; + using DataT = data_type; + using AccT = accumulate_type; + using OutT = out_type; + using IdxT = index_type; + using PolicyT = policy; using opT = op_type; using FinOpT = final_op_type; static constexpr bool is_row_major = row_major; @@ -122,84 +117,105 @@ struct params_RT { } }; -template -__global__ __launch_bounds__(PCT::PolicyT::Nthreads, 2) - - void pairwiseDistanceOpKernel(const typename PCT::DataT* x, - const typename PCT::DataT* y, - const typename PCT::DataT* _xn, - const typename PCT::DataT* _yn, - typename PCT::IdxT m, - typename PCT::IdxT n, - typename PCT::IdxT k, - typename PCT::IdxT lda, - typename PCT::IdxT ldb, - typename PCT::IdxT ldd, - typename PCT::OutT* dOutput, - typename PCT::opT distance_op, - typename PCT::FinOpT fin_op) +// Determine the largest number of elements that can be loaded in one +// instruction without causing misalignment errors. +template +int max_aligned_load(const DataT* x, const DataT* y, int ldx, int ldy) { - using AccT = typename PCT::AccT; - using DataT = typename PCT::DataT; - using OutT = typename PCT::OutT; - using IdxT = typename PCT::IdxT; - - using Policy = typename PCT::PolicyT; - - // Instantiate compile time parameters to access constexpr members. - PCT compile_time_params{}; - - extern __shared__ char smem[]; - - // Wrap operator back into lambdas. This is temporary and should be removed. (TODO) - auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - // use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template core(acc, x, y); - }; - auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - distance_op.template epilog( - acc, regxn, regyn, gridStrideX, gridStrideY); - }; - - // No support for row_epilog_op. - auto row_epilog_op = raft::void_op(); - // Always write output - constexpr bool write_out = true; - constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances - obj(x, + auto base_x = reinterpret_cast(x); + auto base_y = reinterpret_cast(y); + size_t stride_X = sizeof(DataT) * ldx; // stride in bytes + size_t stride_Y = sizeof(DataT) * ldy; // stride in bytes + + bool base_16B_aligned = base_x % 16 == 0 && base_y % 16 == 0; + bool base_8B_aligned = base_x % 8 == 0 && base_y % 8 == 0; + + bool stride_16B_aligned = stride_X % 16 == 0 && stride_Y % 16 == 0; + bool stride_8B_aligned = stride_X % 8 == 0 && stride_Y % 8 == 0; + + if (16 % sizeof(DataT) == 0 && base_16B_aligned && stride_16B_aligned) { + return 16 / sizeof(DataT); + } else if (8 % sizeof(DataT) == 0 && base_8B_aligned && stride_8B_aligned) { + return 8 / sizeof(DataT); + } else { + return 1; + } +} + +template +void distance_matrix_dispatch(opT distance_op, + int m_, + int n_, + int k_, + const DataT* x_, + const DataT* y_, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Determine leading dimensions and possibly flip order of passing x and y if + // column_major. + // + // ldx, ldy, and ld_out are the leading dimensions of x, y, and out + const DataT* x; + const DataT* y; + int ldx, ldy, ld_out; + int m, n, k; + if (is_row_major) { + // Pass x, y, m, n, k in order + x = x_, y = y_; + m = m_, n = n_, k = k_; + ldx = k_, ldy = k_, ld_out = n_; + } else { + // Flip x, y, and m, n, k. + x = y_, y = x_; + m = n_, n = m_, k = k_; + ldx = n_, ldy = m_, ld_out = m_; + } + + int vectorized_load_num_elem = max_aligned_load(x, y, ldx, ldy); + + // We dispatch based on + // - vectorized_load_num_elem + // - is_row_major + + // Create run-time parameter struct that does the dispatching + using PRT = params_RT; + PRT run_time_params{vectorized_load_num_elem, is_row_major}; + + // Turn run-time parameters into compile-time parameters. + bool dispatch_success = run_time_params.dispatch_with_compile_time_params( + // We pass a lambda that receives the compile-time parameters and can use these + // to call the correct kernel. + [&](auto compile_time_params) { + // compile_time_params is an empty struct that we can convert back to a type + // using decltype. + return pairwise_matrix( + distance_op, + fin_op, + x, y, + nullptr, + nullptr, // TODO: use _xn, _yn for non-l1 distances m, n, k, - lda, - ldb, - ldd, - _xn, - _yn, - dOutput, - smem, - core_op, - epilog_op, - fin_op, - row_epilog_op); - obj.run(); + ldx, + ldy, + ld_out, + out, + stream); + }); + + if (!dispatch_success) { + // TODO + } } }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh new file mode 100644 index 0000000000..ec50f6cbbf --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include + +namespace raft::distance::detail { + +template +__global__ __launch_bounds__(PCT::PolicyT::Nthreads, 2) + + void pairwise_matrix_kernel(const typename PCT::DataT* x, + const typename PCT::DataT* y, + const typename PCT::DataT* _xn, + const typename PCT::DataT* _yn, + typename PCT::IdxT m, + typename PCT::IdxT n, + typename PCT::IdxT k, + typename PCT::IdxT lda, + typename PCT::IdxT ldb, + typename PCT::IdxT ldd, + typename PCT::OutT* dOutput, + typename PCT::opT distance_op, + typename PCT::FinOpT fin_op) +{ + using AccT = typename PCT::AccT; + using DataT = typename PCT::DataT; + using OutT = typename PCT::OutT; + using IdxT = typename PCT::IdxT; + + using Policy = typename PCT::PolicyT; + + // Instantiate compile time parameters to access constexpr members. + PCT compile_time_params{}; + + extern __shared__ char smem[]; + + // Wrap operator back into lambdas. This is temporary and should be removed. (TODO) + auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { + // use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template core(acc, x, y); + }; + auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + distance_op.template epilog( + acc, regxn, regyn, gridStrideX, gridStrideY); + }; + + // No support for row_epilog_op. + auto row_epilog_op = raft::void_op(); + // Always write output + constexpr bool write_out = true; + constexpr bool use_norms = distance_op.use_norms; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + _xn, + _yn, + dOutput, + smem, + core_op, + epilog_op, + fin_op, + row_epilog_op); + obj.run(); +} + +template +static void pairwise_matrix(typename PCT::opT distance_op, + typename PCT::FinOpT fin_op, + const typename PCT::DataT* x, + const typename PCT::DataT* y, + const typename PCT::DataT* _xn, + const typename PCT::DataT* _yn, + typename PCT::IdxT m, + typename PCT::IdxT n, + typename PCT::IdxT k, + typename PCT::IdxT lda, + typename PCT::IdxT ldb, + typename PCT::IdxT ldd, + typename PCT::OutT* dOutput, + cudaStream_t stream) +{ + using Policy = typename PCT::PolicyT; + + dim3 blk(Policy::Nthreads); + size_t smem_size = distance_op.template shared_mem_size(); + dim3 grid = launchConfigGenerator(m, n, smem_size, pairwise_matrix_kernel); + + pairwise_matrix_kernel<<>>( + x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +}; // namespace raft::distance::detail From 06f6ffa26613e492ac996be7a176eaec35c16fd0 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 13:56:02 +0100 Subject: [PATCH 11/93] pairwise matrix: Untangle dispatching and kernel template parameters By adding yet another struct ^^ --- .../detail/pairwise_matrix/dispatch.cuh | 95 +++++++--------- .../detail/pairwise_matrix/kernel_sm60.cuh | 101 +++++++++++------- 2 files changed, 99 insertions(+), 97 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index d2c8dfe660..0e056405a1 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -20,37 +20,17 @@ namespace raft::distance::detail { -template -struct params_CT { - using DataT = data_type; - using AccT = accumulate_type; - using OutT = out_type; - using IdxT = index_type; - using PolicyT = policy; - using opT = op_type; - using FinOpT = final_op_type; - static constexpr bool is_row_major = row_major; -}; - -template -struct params_RT { +template +struct params_dispatch { int vectorized_load_num_elem = 1; bool row_major = true; + template + struct params_constexpr { + static constexpr int vec_len = vl; + static constexpr bool is_row_major = rm; + }; + // Turn run-time parameters into compile-time parameters. // Call the provided function f with these compile-time parameters. // Returns false if dispatch fails, i.e., if there is no implementation @@ -69,17 +49,7 @@ struct params_RT { switch (vectorized_load_num_elem) { case 1: return layout<1>(f); case 2: return layout<2>(f); - case 4: - // We need "if constexpr" here, to prevent the if else to be delegated - // to run time (in which case a kernel that loads 4 doubles is - // generated). This is especially important, because that leads to - // compilation errors (which we want to avoid). - if constexpr (sizeof(DataT) < 8) { - return layout<4>(f); - } else { - // For doubles, load at most 2 elements in one instruction. - return layout<2>(f); - } + case 4: return layout<4>(f); default: return fail; }; } @@ -100,14 +70,9 @@ struct params_RT { template bool to_compile_time_params(F&& f) const { - // Determine kernel policy using vec_len and layout - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; - // Create compile-time parameter type and instantiate a struct; - using PCT = params_CT; - PCT compile_time_params{}; + using ct_params_T = params_constexpr; + ct_params_T compile_time_params{}; // Dispatch to f f(compile_time_params); @@ -181,22 +146,38 @@ void distance_matrix_dispatch(opT distance_op, int vectorized_load_num_elem = max_aligned_load(x, y, ldx, ldy); - // We dispatch based on - // - vectorized_load_num_elem - // - is_row_major - - // Create run-time parameter struct that does the dispatching - using PRT = params_RT; - PRT run_time_params{vectorized_load_num_elem, is_row_major}; + // Create run-time parameter struct that does the dispatching. + // + // In addition to the template parameters of this function (IdxT, DataT, + // etc..), we explicitly dispatch based on: + params_dispatch run_time_params{ + vectorized_load_num_elem, // 1. num array elements per load instruction + is_row_major // 2. the layout x, y, and out + }; // Turn run-time parameters into compile-time parameters. bool dispatch_success = run_time_params.dispatch_with_compile_time_params( // We pass a lambda that receives the compile-time parameters and can use these // to call the correct kernel. - [&](auto compile_time_params) { - // compile_time_params is an empty struct that we can convert back to a type - // using decltype. - return pairwise_matrix( + [&](auto p) { + // p has two constexpr members: + // - vec_len + // - is_row_major + + // There is no instruction to load 4 doubles, so we catch this situation + // and load 2 doubles. + constexpr bool load_4_doubles = sizeof(DataT) > 4 && p.vec_len == 4; + constexpr int vec_len = (load_4_doubles) ? 2 : p.vec_len; + + // Determine kernel policy using vec_len and layout + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename std::conditional::type Policy; + + // Create compile-time template parameter + using KP_T = kernel_params_T; + + return pairwise_matrix( distance_op, fin_op, x, diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index ec50f6cbbf..fa30ff2c3e 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -17,38 +17,59 @@ #include #include -#include +#include // TODO: remove #include namespace raft::distance::detail { -template -__global__ __launch_bounds__(PCT::PolicyT::Nthreads, 2) - - void pairwise_matrix_kernel(const typename PCT::DataT* x, - const typename PCT::DataT* y, - const typename PCT::DataT* _xn, - const typename PCT::DataT* _yn, - typename PCT::IdxT m, - typename PCT::IdxT n, - typename PCT::IdxT k, - typename PCT::IdxT lda, - typename PCT::IdxT ldb, - typename PCT::IdxT ldd, - typename PCT::OutT* dOutput, - typename PCT::opT distance_op, - typename PCT::FinOpT fin_op) +template +struct kernel_params_T { + using DataT = data_type; + using AccT = accumulate_type; + using OutT = out_type; + using IdxT = index_type; + using PolicyT = policy; + using opT = op_type; + using FinOpT = final_op_type; + static constexpr bool is_row_major = row_major; +}; + +template +__global__ __launch_bounds__(KP_T::PolicyT::Nthreads, 2) + + void pairwise_matrix_kernel(const typename KP_T::DataT* x, + const typename KP_T::DataT* y, + const typename KP_T::DataT* _xn, + const typename KP_T::DataT* _yn, + typename KP_T::IdxT m, + typename KP_T::IdxT n, + typename KP_T::IdxT k, + typename KP_T::IdxT lda, + typename KP_T::IdxT ldb, + typename KP_T::IdxT ldd, + typename KP_T::OutT* dOutput, + typename KP_T::opT distance_op, + typename KP_T::FinOpT fin_op) { - using AccT = typename PCT::AccT; - using DataT = typename PCT::DataT; - using OutT = typename PCT::OutT; - using IdxT = typename PCT::IdxT; + using AccT = typename KP_T::AccT; + using DataT = typename KP_T::DataT; + using OutT = typename KP_T::OutT; + using IdxT = typename KP_T::IdxT; - using Policy = typename PCT::PolicyT; + using Policy = typename KP_T::PolicyT; // Instantiate compile time parameters to access constexpr members. - PCT compile_time_params{}; + KP_T compile_time_params{}; extern __shared__ char smem[]; @@ -103,29 +124,29 @@ __global__ __launch_bounds__(PCT::PolicyT::Nthreads, 2) obj.run(); } -template -static void pairwise_matrix(typename PCT::opT distance_op, - typename PCT::FinOpT fin_op, - const typename PCT::DataT* x, - const typename PCT::DataT* y, - const typename PCT::DataT* _xn, - const typename PCT::DataT* _yn, - typename PCT::IdxT m, - typename PCT::IdxT n, - typename PCT::IdxT k, - typename PCT::IdxT lda, - typename PCT::IdxT ldb, - typename PCT::IdxT ldd, - typename PCT::OutT* dOutput, +template +static void pairwise_matrix(typename KP_T::opT distance_op, + typename KP_T::FinOpT fin_op, + const typename KP_T::DataT* x, + const typename KP_T::DataT* y, + const typename KP_T::DataT* _xn, + const typename KP_T::DataT* _yn, + typename KP_T::IdxT m, + typename KP_T::IdxT n, + typename KP_T::IdxT k, + typename KP_T::IdxT lda, + typename KP_T::IdxT ldb, + typename KP_T::IdxT ldd, + typename KP_T::OutT* dOutput, cudaStream_t stream) { - using Policy = typename PCT::PolicyT; + using Policy = typename KP_T::PolicyT; dim3 blk(Policy::Nthreads); size_t smem_size = distance_op.template shared_mem_size(); - dim3 grid = launchConfigGenerator(m, n, smem_size, pairwise_matrix_kernel); + dim3 grid = launchConfigGenerator(m, n, smem_size, pairwise_matrix_kernel); - pairwise_matrix_kernel<<>>( + pairwise_matrix_kernel<<>>( x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); From 2f41faa419e4bf08bb0ff68587bcaf3bca385c20 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 15:00:06 +0100 Subject: [PATCH 12/93] l2 unexp: Use pairwise matrix dispatch --- .../distance/detail/distance_ops/l2_unexp.cuh | 68 +++++++ .../raft/distance/detail/euclidean.cuh | 186 +++--------------- 2 files changed, 91 insertions(+), 163 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh new file mode 100644 index 0000000000..99fda59f03 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::distance::detail::ops { + +// Describes the computation the l2 unexpanded distance + +template +struct l2_unexp_generic_distance_op { + // Do not load norms of data, the computation of L1 distance does not use them. + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const auto diff = x - y; + acc += diff * diff; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + if constexpr (sqrt) { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(acc[i][j]); + } + } + } + }; +}; + + +// Define distance ops with and without square root computation. +using l2_unexp_distance_op = l2_unexp_generic_distance_op; +using l2_unexp_sqrt_distance_op = l2_unexp_generic_distance_op; + + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 1a2db63f5c..8ed1e9d615 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -21,6 +21,10 @@ #include #include +#include "distance_ops/l2_unexp.cuh" +#include "pairwise_matrix/dispatch.cuh" + + namespace raft { namespace distance { namespace detail { @@ -285,145 +289,6 @@ void euclideanAlgo1(Index_ m, } } -/** - * @brief the unexpanded euclidean distance matrix calculation - * It computes the following equation: cij = op((ai-bj)^2) - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam FinalLambda final lambda called on final distance value - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[in] sqrt if the square root is computed or not - * @param[output] pD output matrix - * @param fin_op the final gemm epilogue lambda - */ -template -void euclideanUnExpImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - bool sqrt, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = x - y; - acc += diff * diff; - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [sqrt] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - if (sqrt) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - }; - - if (isRowMajor) { - auto euclideanUnExpRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, euclideanUnExpRowMajor); - - euclideanUnExpRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - - } else { - auto euclideanUnExpColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, euclideanUnExpColMajor); - - euclideanUnExpColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void euclideanUnExp(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - bool sqrt, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - euclideanUnExpImpl( - x, y, m, n, k, lda, ldb, ldd, sqrt, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - euclideanUnExpImpl( - x, y, m, n, k, lda, ldb, ldd, sqrt, dOutput, fin_op, stream); - } else { - euclideanUnExpImpl( - x, y, m, n, k, lda, ldb, ldd, sqrt, dOutput, fin_op, stream); - } -} /** * @brief the unexpanded euclidean distance matrix calculation @@ -444,35 +309,30 @@ void euclideanUnExp(IdxT m, * @param stream cuda stream where to launch work * @param isRowMajor whether the input and output matrices are row major */ -template -void euclideanAlgo2(Index_ m, - Index_ n, - Index_ k, - const InType* pA, - const InType* pB, - OutType* pD, +template +void euclideanAlgo2(IdxT m, + IdxT n, + IdxT k, + const DataT* pA, + const DataT* pB, + OutT* pD, bool enable_sqrt, - FinalLambda fin_op, + FinOpT fin_op, cudaStream_t stream, bool isRowMajor) { - typedef std::is_same is_bool; - typedef typename std::conditional::type UnExpOutType; - UnExpOutType* pDcast = reinterpret_cast(pD); - Index_ lda, ldb, ldd; - - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - euclideanUnExp( - m, n, k, lda, ldb, ldd, pA, pB, enable_sqrt, pDcast, fin_op, stream); + if (enable_sqrt) { + ops::l2_unexp_sqrt_distance_op l2_sqrt_op{}; + distance_matrix_dispatch( + l2_sqrt_op, m, n, k, pA, pB, pD, fin_op, stream, isRowMajor); } else { - lda = n, ldb = m, ldd = m; - euclideanUnExp( - n, m, k, lda, ldb, ldd, pB, pA, enable_sqrt, pDcast, fin_op, stream); + ops::l2_unexp_distance_op l2_op{}; + distance_matrix_dispatch( + l2_op, m, n, k, pA, pB, pD, fin_op, stream, isRowMajor); } } From 7938614f0c74c7405c66b8074150c9e642952130 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 16:34:41 +0100 Subject: [PATCH 13/93] l2 exp: Use pairwise matrix dispatch This did remove support for the CUTLASS kernels. Has to be put back. --- .../raft/distance/detail/distance_ops/l1.cuh | 2 +- .../distance/detail/distance_ops/l2_exp.cuh | 72 +++++ .../distance/detail/distance_ops/l2_unexp.cuh | 16 +- .../raft/distance/detail/euclidean.cuh | 304 ++++-------------- cpp/include/raft/distance/detail/l1.cuh | 5 +- .../detail/pairwise_matrix/dispatch.cuh | 13 +- .../detail/pairwise_matrix/kernel_sm60.cuh | 3 +- 7 files changed, 166 insertions(+), 249 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 08ca313fe2..9d31b24851 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -25,7 +25,7 @@ struct l1_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh new file mode 100644 index 0000000000..c15b43a74e --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::distance::detail::ops { + +// Describes the computation the l2 expanded distance +// +// TODO: more explanation. +struct l2_exp_distance_op { + bool sqrt; + + l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} + + // Load norms of input data + static constexpr bool use_norms = true; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += x * y; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } + } + if (sqrt) { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(acc[i][j]); + } + } + } + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index 99fda59f03..03bbd936c6 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -19,15 +19,17 @@ namespace raft::distance::detail::ops { // Describes the computation the l2 unexpanded distance +struct l2_unexp_distance_op { + bool sqrt; + + l2_unexp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} -template -struct l2_unexp_generic_distance_op { // Do not load norms of data, the computation of L1 distance does not use them. static constexpr bool use_norms = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; @@ -47,7 +49,7 @@ struct l2_unexp_generic_distance_op { IdxT gridStrideX, IdxT gridStrideY) const { - if constexpr (sqrt) { + if (sqrt) { #pragma unroll for (int i = 0; i < Policy::AccRowsPerTh; ++i) { #pragma unroll @@ -59,10 +61,4 @@ struct l2_unexp_generic_distance_op { }; }; - -// Define distance ops with and without square root computation. -using l2_unexp_distance_op = l2_unexp_generic_distance_op; -using l2_unexp_sqrt_distance_op = l2_unexp_generic_distance_op; - - } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 8ed1e9d615..29088257e2 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -16,13 +16,11 @@ #pragma once -#include -#include #include -#include -#include "distance_ops/l2_unexp.cuh" #include "pairwise_matrix/dispatch.cuh" +#include "distance_ops/l2_exp.cuh" +#include "distance_ops/l2_unexp.cuh" namespace raft { @@ -44,249 +42,88 @@ struct L2ExpandedOp { __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; -/** - * @brief the expanded euclidean distance matrix calculation implementer - * It computes the following equation: C = op(A^2 + B^2 - 2AB) - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread for every LDG call - * it makes. check contractions.cuh for details. - * @tparam FinalLambda the final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] xn row norms of input matrix A. - * @param[in] yn row norms of input matrix B. - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[in] sqrt if the square root is computed or not - * @param[output] pD output matrix - * @param fin_op the final gemm epilogue lambda -* @param stream cuda stream to launch cuda operations. - */ -template -void euclideanExpImpl(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - bool sqrt, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ -#if (__CUDACC_VER_MAJOR__ < 12) - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { - using L2Op = L2ExpandedOp; - L2Op L2_dist_op(sqrt); - - cutlassDistanceKernel( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, L2_dist_op, stream); - - } else -#endif - { - - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [sqrt] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (sqrt) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - }; - - constexpr size_t shmemSize = - KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); - if (isRowMajor) { - auto euclideanExpRowMajor = pairwiseDistanceMatKernelPriorToAmpere; - dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); - - euclideanExpRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto euclideanExpColMajor = pairwiseDistanceMatKernelPriorToAmpere; - dim3 grid = launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); - euclideanExpColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} +// /** +// * @brief the expanded euclidean distance matrix calculation +// * It computes the following equation: C = op(A^2 + B^2 - 2AB) +// * @tparam InType input data-type (for A and B matrices) +// * @tparam AccType accumulation data-type +// * @tparam OutType output data-type (for C and D matrices) +// * @tparam FinalLambda the final lambda called by FragmentMultiplyAdd_ +// * @tparam Index_ index type +// * @param m number of rows of A and C/D +// * @param n number of columns of B and C/D +// * @param k number of cols of A and rows of B +// * @param pA input matrix +// * @param pB input matrix +// * @param pD output matrix +// * @param enable_sqrt if the square root is computed or not +// * @param workspace temporary workspace needed for computations +// * @param worksize number of bytes of the workspace +// * @param fin_op the final gemm epilogue lambda +// * @param stream cuda stream where to launch work +// * @param isRowMajor whether the input and output matrices are row major +// */ template -void euclideanExp(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - bool sqrt, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - euclideanExpImpl( - x, y, xn, yn, m, n, k, lda, ldb, ldd, sqrt, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - euclideanExpImpl( - x, y, xn, yn, m, n, k, lda, ldb, ldd, sqrt, dOutput, fin_op, stream); - } else { - euclideanExpImpl( - x, y, xn, yn, m, n, k, lda, ldb, ldd, sqrt, dOutput, fin_op, stream); - } -} - -/** - * @brief the expanded euclidean distance matrix calculation - * It computes the following equation: C = op(A^2 + B^2 - 2AB) - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda the final lambda called by FragmentMultiplyAdd_ - * @tparam Index_ index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param enable_sqrt if the square root is computed or not - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void euclideanAlgo1(Index_ m, - Index_ n, - Index_ k, - const InType* pA, - const InType* pB, - OutType* pD, + typename FinOpT, + typename IdxT = int> +void euclideanAlgo1(IdxT m, + IdxT n, + IdxT k, + const DataT* pA, + const DataT* pB, + OutT* pD, bool enable_sqrt, - AccType* workspace, + AccT* workspace, size_t& worksize, - FinalLambda fin_op, + FinOpT fin_op, cudaStream_t stream, bool isRowMajor) { + // TODO: handle cutlass kernels + // constexpr bool CUDA_11_or_below = __CUDACC_VER_MAJOR__ < 12; + + // if constexpr(CUDA_11_or_below) { + // const auto deviceVersion = getComputeCapability(); + // if (deviceVersion.first >= 8) { + // using L2Op = L2ExpandedOp; + // L2Op L2_dist_op(sqrt); + + // cutlassDistanceKernel( + // x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, L2_dist_op, stream); + // } + // } + + // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), - "OutType can be uint8_t, float, double," - "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); - typedef typename std::conditional::type ExpOutType; - ExpOutType* pDcast = reinterpret_cast(pD); + static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), + "OutT can be uint8_t, float, double," + "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); ASSERT( - !(((pA != pB) && (worksize < (m + n) * sizeof(AccType))) || (worksize < m * sizeof(AccType))), + !(((pA != pB) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); - Index_ lda, ldb, ldd; - InType* col_vec = workspace; - InType* row_vec = workspace; + DataT* norm_A = workspace; + DataT* norm_B = workspace; if (pA != pB) { - row_vec += m; + norm_B += m; raft::linalg::rowNorm( - col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); raft::linalg::rowNorm( - row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + norm_B, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } else { raft::linalg::rowNorm( - col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - euclideanExp( - m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, enable_sqrt, pDcast, fin_op, stream); - } else { - lda = n, ldb = m, ldd = m; - euclideanExp( - n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, enable_sqrt, pDcast, fin_op, stream); - } + ops::l2_exp_distance_op l2_op(enable_sqrt); + + distance_matrix_dispatch( + l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); } @@ -325,15 +162,14 @@ void euclideanAlgo2(IdxT m, cudaStream_t stream, bool isRowMajor) { - if (enable_sqrt) { - ops::l2_unexp_sqrt_distance_op l2_sqrt_op{}; - distance_matrix_dispatch( - l2_sqrt_op, m, n, k, pA, pB, pD, fin_op, stream, isRowMajor); - } else { - ops::l2_unexp_distance_op l2_op{}; - distance_matrix_dispatch( - l2_op, m, n, k, pA, pB, pD, fin_op, stream, isRowMajor); - } + ops::l2_unexp_distance_op l2_op(enable_sqrt); + + // The unexpanded L2 does not require the norms of a and b to be calculated. + const DataT* norm_A = nullptr; + const DataT* norm_B = nullptr; + + distance_matrix_dispatch( + l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); } }; // end namespace detail diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index a5f279d9a4..49402a9101 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -35,8 +35,11 @@ void l1Impl(int m, { ops::l1_distance_op distance_op{}; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + distance_matrix_dispatch( - distance_op, m, n, k, x, y, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } } // namespace detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 0e056405a1..4a8fb82861 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include "kernel_sm60.cuh" @@ -119,6 +120,8 @@ void distance_matrix_dispatch(opT distance_op, int k_, const DataT* x_, const DataT* y_, + const DataT* x_norm_, + const DataT* y_norm_, OutT* out, FinOpT fin_op, cudaStream_t stream, @@ -129,17 +132,22 @@ void distance_matrix_dispatch(opT distance_op, // // ldx, ldy, and ld_out are the leading dimensions of x, y, and out const DataT* x; + const DataT* x_norm; const DataT* y; + const DataT* y_norm; + int ldx, ldy, ld_out; int m, n, k; if (is_row_major) { // Pass x, y, m, n, k in order x = x_, y = y_; + x_norm = x_norm_, y_norm = y_norm_; m = m_, n = n_, k = k_; ldx = k_, ldy = k_, ld_out = n_; } else { // Flip x, y, and m, n, k. x = y_, y = x_; + x_norm = y_norm_, y_norm = x_norm_; m = n_, n = m_, k = k_; ldx = n_, ldy = m_, ld_out = m_; } @@ -182,8 +190,8 @@ void distance_matrix_dispatch(opT distance_op, fin_op, x, y, - nullptr, - nullptr, // TODO: use _xn, _yn for non-l1 distances + x_norm, + y_norm, m, n, k, @@ -195,6 +203,7 @@ void distance_matrix_dispatch(opT distance_op, }); if (!dispatch_success) { + std::printf("Dispatch error(!)\n"); // TODO } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index fa30ff2c3e..68026414c0 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -141,9 +141,10 @@ static void pairwise_matrix(typename KP_T::opT distance_op, cudaStream_t stream) { using Policy = typename KP_T::PolicyT; + using DataT = typename KP_T::DataT; dim3 blk(Policy::Nthreads); - size_t smem_size = distance_op.template shared_mem_size(); + size_t smem_size = distance_op.template shared_mem_size(); dim3 grid = launchConfigGenerator(m, n, smem_size, pairwise_matrix_kernel); pairwise_matrix_kernel<<>>( From 7afe6cc8219c4225442dbcbce4e8a28f22dbcb2f Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 16:35:29 +0100 Subject: [PATCH 14/93] Add template for distance operator I wasted a lot of time because I had not replaced the op::core() method of the l2_exp_distance_op after I copied it from l2_unexp_distance_op... If I copy something from the template and forget to fill it in, I get a compile error. --- .../distance/detail/distance_ops/template.cuh | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 cpp/include/raft/distance/detail/distance_ops/template.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh new file mode 100644 index 0000000000..cfd12b8bc1 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::distance::detail::ops { + +// Describes the computation the template distance +// +// Fill in the TODO items. + +struct template_op { + // Load norms of input data + static constexpr bool use_norms = TODO; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize + TODO; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + TODO; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + TODO; + }; + +} // namespace raft::distance::detail::ops From 5fe3292eba31fff96587c6e7c01757939166cb76 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 17:32:34 +0100 Subject: [PATCH 15/93] Reenable cutlass-based kernels for CUDA 12.0 --- .../raft/distance/detail/pairwise_distance_cutlass_base.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index f39d880da4..efd44ea4dc 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -19,8 +19,6 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" -#if (__CUDACC_VER_MAJOR__ < 12) - // We define CUTLASS_NAMESPACE in case // RAFT cmake is not used #ifndef CUTLASS_NAMESPACE @@ -174,5 +172,5 @@ void cutlassDistanceKernel(const DataT* x, }; // namespace detail }; // namespace distance }; // namespace raft -#endif // (__CUDACC_VER_MAJOR__ < 12) + #pragma GCC diagnostic pop From c623332ac89ab28a4a9c5ab2d23aa2193b07d256 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 18:26:02 +0100 Subject: [PATCH 16/93] pairwise matrix l2: Add support for CUTLASS kernels I am testing on CUDA 12, where it does not seem to work. Prior to my commits, the CUTLASS kernels were also not working. So not sure what's up. In any case: consider this untested. --- .../distance/detail/distance_ops/l2_exp.cuh | 16 +++ .../raft/distance/detail/euclidean.cuh | 63 +++++------ .../detail/pairwise_matrix/dispatch.cuh | 104 ++++++++++++------ 3 files changed, 116 insertions(+), 67 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index c15b43a74e..4dfb26a826 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -69,4 +69,20 @@ struct l2_exp_distance_op { } }; +// Epilogue operator for CUTLASS based kernel +template +struct l2_exp_cutlass_op { + bool sqrt; + + __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} + __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + return sqrt ? raft::sqrt(outVal) : outVal; + } + + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 29088257e2..51e2ff224f 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -27,22 +27,6 @@ namespace raft { namespace distance { namespace detail { -template -struct L2ExpandedOp { - bool sqrt; - - __device__ L2ExpandedOp() noexcept : sqrt(false) {} - __device__ L2ExpandedOp(bool isSqrt) noexcept : sqrt(isSqrt) {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - return sqrt ? raft::sqrt(outVal) : outVal; - } - - __device__ AccT operator()(DataT aData) const noexcept { return aData; } -}; - - // /** // * @brief the expanded euclidean distance matrix calculation // * It computes the following equation: C = op(A^2 + B^2 - 2AB) @@ -82,21 +66,6 @@ void euclideanAlgo1(IdxT m, cudaStream_t stream, bool isRowMajor) { - // TODO: handle cutlass kernels - // constexpr bool CUDA_11_or_below = __CUDACC_VER_MAJOR__ < 12; - - // if constexpr(CUDA_11_or_below) { - // const auto deviceVersion = getComputeCapability(); - // if (deviceVersion.first >= 8) { - // using L2Op = L2ExpandedOp; - // L2Op L2_dist_op(sqrt); - - // cutlassDistanceKernel( - // x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, L2_dist_op, stream); - // } - // } - - // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), "OutT can be uint8_t, float, double," @@ -120,10 +89,34 @@ void euclideanAlgo1(IdxT m, norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); } - ops::l2_exp_distance_op l2_op(enable_sqrt); - - distance_matrix_dispatch( - l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel otherwise. + + if constexpr (__CUDACC_VER_MAJOR__ == 12) { + // Always execute legacy kernels on CUDA 12 + ops::l2_exp_distance_op l2_op(enable_sqrt); + distance_matrix_dispatch( + l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); + } else { + const auto deviceVersion = getComputeCapability(); + if (deviceVersion.first >= 8) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using L2Op = ops::l2_exp_cutlass_op; + L2Op l2_op(enable_sqrt); + + distance_matrix_cutlass_dispatch( + l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); + } else { + // Else use "legacy" L2 + ops::l2_exp_distance_op l2_op(enable_sqrt); + distance_matrix_dispatch( + l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); + } + } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 4a8fb82861..650c8fa805 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -16,7 +16,9 @@ #pragma once #include +#include #include +#include #include "kernel_sm60.cuh" namespace raft::distance::detail { @@ -85,8 +87,8 @@ struct params_dispatch { // Determine the largest number of elements that can be loaded in one // instruction without causing misalignment errors. -template -int max_aligned_load(const DataT* x, const DataT* y, int ldx, int ldy) +template +int vectorized_load_num_elem(const DataT* x, const DataT* y, IdxT ldx, IdxT ldy) { auto base_x = reinterpret_cast(x); auto base_y = reinterpret_cast(y); @@ -115,13 +117,13 @@ template void distance_matrix_dispatch(opT distance_op, - int m_, - int n_, - int k_, - const DataT* x_, - const DataT* y_, - const DataT* x_norm_, - const DataT* y_norm_, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, OutT* out, FinOpT fin_op, cudaStream_t stream, @@ -129,38 +131,24 @@ void distance_matrix_dispatch(opT distance_op, { // Determine leading dimensions and possibly flip order of passing x and y if // column_major. - // - // ldx, ldy, and ld_out are the leading dimensions of x, y, and out - const DataT* x; - const DataT* x_norm; - const DataT* y; - const DataT* y_norm; - - int ldx, ldy, ld_out; - int m, n, k; + IdxT ldx, ldy, ld_out; if (is_row_major) { - // Pass x, y, m, n, k in order - x = x_, y = y_; - x_norm = x_norm_, y_norm = y_norm_; - m = m_, n = n_, k = k_; - ldx = k_, ldy = k_, ld_out = n_; + ldx = k, ldy = k, ld_out = n; } else { - // Flip x, y, and m, n, k. - x = y_, y = x_; - x_norm = y_norm_, y_norm = x_norm_; - m = n_, n = m_, k = k_; - ldx = n_, ldy = m_, ld_out = m_; + // Flip x, y, and m, n. + std::swap(x, y); + std::swap(x_norm, y_norm); + std::swap(m, n); + ldx = m, ldy = n, ld_out = n; } - int vectorized_load_num_elem = max_aligned_load(x, y, ldx, ldy); - // Create run-time parameter struct that does the dispatching. // // In addition to the template parameters of this function (IdxT, DataT, // etc..), we explicitly dispatch based on: params_dispatch run_time_params{ - vectorized_load_num_elem, // 1. num array elements per load instruction - is_row_major // 2. the layout x, y, and out + vectorized_load_num_elem(x, y, ldx, ldy), // 1. num array elements per load instruction + is_row_major // 2. the layout of x, y, and out }; // Turn run-time parameters into compile-time parameters. @@ -208,4 +196,56 @@ void distance_matrix_dispatch(opT distance_op, } } +template +void distance_matrix_cutlass_dispatch(opT cutlass_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Determine leading dimensions and possibly flip order of passing x and y if + // column_major. + IdxT ldx, ldy, ld_out; + if (is_row_major) { + ldx = k, ldy = k, ld_out = n; + } else { + std::swap(x, y); + std::swap(x_norm, y_norm); + std::swap(m, n); + ldx = m, ldy = n, ld_out = n; + } + + params_dispatch run_time_params{ + vectorized_load_num_elem(x, y, ldx, ldy), + is_row_major + }; + + bool dispatch_success = run_time_params.dispatch_with_compile_time_params( + [&](auto p) { + // Prevent loading 4 doubles in one instruction. + constexpr bool load_4_doubles = sizeof(DataT) > 4 && p.vec_len == 4; + constexpr int vec_len = (load_4_doubles) ? 2 : p.vec_len; + + cutlassDistanceKernel( + x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); + }); + + if (!dispatch_success) { + std::printf("Dispatch error(!)\n"); + // TODO + } +} + }; // namespace raft::distance::detail From 27511fc65f9c39a371a0efd02ffeb7bfcbbcd736 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 18:42:44 +0100 Subject: [PATCH 17/93] Canberra: use dispatching mechanism --- cpp/include/raft/distance/detail/README.org | 13 ++ cpp/include/raft/distance/detail/canberra.cuh | 181 +++--------------- .../distance/detail/distance_ops/canberra.cuh | 58 ++++++ .../distance/detail/distance_ops/template.cuh | 3 +- cpp/include/raft/distance/detail/l1.cuh | 6 +- 5 files changed, 103 insertions(+), 158 deletions(-) create mode 100644 cpp/include/raft/distance/detail/README.org create mode 100644 cpp/include/raft/distance/detail/distance_ops/canberra.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org new file mode 100644 index 0000000000..dcb9b1d1e2 --- /dev/null +++ b/cpp/include/raft/distance/detail/README.org @@ -0,0 +1,13 @@ +#+title: Readme + + +- [X] canberra.cuh +- [ ] chebyshev.cuh +- [ ] correlation.cuh +- [ ] cosine.cuh +- [ ] hamming.cuh +- [ ] hellinger.cuh +- [ ] jensen_shannon.cuh +- [ ] kl_divergence.cuh +- [ ] minkowski.cuh +- [ ] russell_rao.cuh diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index f17a26dc4b..3f0c2fa268 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -15,148 +15,23 @@ */ #pragma once -#include + +#include "distance_ops/canberra.cuh" +#include "pairwise_matrix/dispatch.cuh" namespace raft { namespace distance { namespace detail { -/** - * @brief the canberra distance matrix calculation implementer - * It computes the following equation: cij = max(cij, op(ai-bj)) - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and cols of C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] dOutput output matrix - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream to launch work - */ -template -static void canberraImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::abs(x - y); - const auto add = raft::abs(x) + raft::abs(y); - // deal with potential for 0 in denominator by - // forcing 1/0 instead - acc += ((add != 0) * diff / (add + (add == 0))); - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = raft::void_op(); - - if (isRowMajor) { - auto canberraRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, canberraRowMajor); - - canberraRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto canberraColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, canberraColMajor); - canberraColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void canberra(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - canberraImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - canberraImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - canberraImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} /** * @brief the canberra distance matrix calculation * It computes the following equation: cij = max(cij, op(ai-bj)) - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam FinOpT user-defined epilogue lamba + * @tparam IdxT Index type * @param[in] m number of rows of A and C/D * @param[in] n number of rows of B and cols of C/D * @param[in] k number of cols of A and B @@ -167,34 +42,28 @@ void canberra(IdxT m, * @param[in] stream cuda stream to launch work * @param[in] isRowMajor whether the input and output matrices are row major */ -template +template void canberraImpl(int m, int n, int k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, cudaStream_t stream, - bool isRowMajor) + bool is_row_major) { - typedef std::is_same is_bool; - typedef typename std::conditional::type canberraOutType; - Index_ lda, ldb, ldd; - canberraOutType* pDcast = reinterpret_cast(pD); - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - canberra( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); - } else { - lda = n, ldb = m, ldd = m; - canberra( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); - } + ops::canberra_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } } // namespace detail diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh new file mode 100644 index 0000000000..4fda825286 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance::detail::ops { + +// Describes the computation the canberra distance + +struct canberra_distance_op { + // Load norms of input data + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const auto diff = raft::abs(x - y); + const auto add = raft::abs(x) + raft::abs(y); + // deal with potential for 0 in denominator by + // forcing 1/0 instead + acc += ((add != 0) * diff / (add + (add == 0))); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + return; + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index cfd12b8bc1..4c624c5593 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -48,6 +48,7 @@ struct template_op { IdxT gridStrideY) const { TODO; - }; + } +}; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 49402a9101..cceb432c7d 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -22,7 +22,11 @@ namespace raft { namespace distance { namespace detail { -template +template void l1Impl(int m, int n, int k, From 58ce6f8fb3da0d6916a0421fcdb82669ea6e28ee Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 18:50:02 +0100 Subject: [PATCH 18/93] Chebyshev: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 2 +- .../raft/distance/detail/chebyshev.cuh | 174 +++--------------- .../detail/distance_ops/chebyshev.cuh | 55 ++++++ 3 files changed, 78 insertions(+), 153 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index dcb9b1d1e2..f84c2a0f2c 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -2,7 +2,7 @@ - [X] canberra.cuh -- [ ] chebyshev.cuh +- [X] chebyshev.cuh - [ ] correlation.cuh - [ ] cosine.cuh - [ ] hamming.cuh diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index 43b36e7921..9f49660301 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -15,136 +15,12 @@ */ #pragma once -#include -#include +#include "distance_ops/chebyshev.cuh" +#include "pairwise_matrix/dispatch.cuh" namespace raft { namespace distance { namespace detail { -/** - * @brief the Chebyshev distance matrix calculation implementer - * It computes the following equation: cij = max(cij, op(ai-bj)) - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and cols of C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[out] dOutput output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - */ -template -static void chebyshevImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::abs(x - y); - acc = raft::max(acc, diff); - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = raft::void_op(); - - if (isRowMajor) { - auto chebyshevRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, chebyshevRowMajor); - - chebyshevRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto chebyshevColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, chebyshevColMajor); - chebyshevColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void chebyshev(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - chebyshevImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - chebyshevImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - chebyshevImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} /** * @brief the chebyshev distance matrix calculation @@ -164,34 +40,28 @@ void chebyshev(IdxT m, * @param[in] stream cuda stream to launch work * @param[in] isRowMajor whether the input and output matrices are row major */ -template +template void chebyshevImpl(int m, - int n, - int k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor) + int n, + int k, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) { - typedef std::is_same is_bool; - typedef typename std::conditional::type chebyshevOutType; - Index_ lda, ldb, ldd; - chebyshevOutType* pDcast = reinterpret_cast(pD); - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - chebyshev( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); - } else { - lda = n, ldb = m, ldd = m; - chebyshev( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); - } + ops::chebyshev_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } } // namespace detail } // namespace distance diff --git a/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh b/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh new file mode 100644 index 0000000000..ced9fcf6f7 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance::detail::ops { + +// Describes the computation the chebyshev distance + +struct chebyshev_distance_op { + // Load norms of input data + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const auto diff = raft::abs(x - y); + acc = raft::max(acc, diff); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + return; + } +}; + +} // namespace raft::distance::detail::ops From d397c170e62ec5ebf02d61510830a37381b8c429 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 19:27:50 +0100 Subject: [PATCH 19/93] Correlation: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 2 +- .../raft/distance/detail/correlation.cuh | 228 +----------------- .../detail/distance_ops/correlation.cuh | 127 ++++++++++ .../distance/detail/distance_ops/template.cuh | 2 +- 4 files changed, 137 insertions(+), 222 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/correlation.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index f84c2a0f2c..dc66a55f60 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -3,7 +3,7 @@ - [X] canberra.cuh - [X] chebyshev.cuh -- [ ] correlation.cuh +- [X] correlation.cuh - [ ] cosine.cuh - [ ] hamming.cuh - [ ] hellinger.cuh diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index f7fe3678e6..89828c9ba2 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -15,192 +15,16 @@ */ #pragma once -#include + #include -#include + +#include "pairwise_matrix/dispatch.cuh" +#include "distance_ops/correlation.cuh" namespace raft { namespace distance { namespace detail { -/** - * @brief the Correlation distance matrix: - * - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] dOutput output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - */ -template -static void correlationImpl(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const DataT* x2n, - const DataT* y2n, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [x2n, y2n, m, n, k] __device__( - AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - DataT regx2n[KPolicy::AccRowsPerTh], regy2n[KPolicy::AccColsPerTh]; - - extern __shared__ char smem[]; - DataT* sx2Norm = - (DataT*)(&smem[KPolicy::SmemSize + (KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)]); - DataT* sy2Norm = (&sx2Norm[KPolicy::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * KPolicy::Nblk) { - for (int i = threadIdx.x; i < KPolicy::Mblk; i += KPolicy::Nthreads) { - auto idx = gridStrideY + i; - sx2Norm[i] = idx < m ? x2n[idx] : 0; - } - } - - for (int i = threadIdx.x; i < KPolicy::Nblk; i += KPolicy::Nthreads) { - auto idx = gridStrideX + i; - sy2Norm[i] = idx < n ? y2n[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { - regx2n[i] = sx2Norm[i * KPolicy::AccThRows + (threadIdx.x / KPolicy::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < KPolicy::AccColsPerTh; ++i) { - regy2n[i] = sy2Norm[i * KPolicy::AccThCols + (threadIdx.x % KPolicy::AccThCols)]; - } - -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - auto numer = k * acc[i][j] - (regxn[i] * regyn[j]); - auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); - auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); - - acc[i][j] = 1 - (numer / raft::sqrt(Q_denom * R_denom)); - } - } - }; - - constexpr size_t shmemSize = - KPolicy::SmemSize + (2 * (KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); - if (isRowMajor) { - constexpr auto correlationRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, correlationRowMajor); - correlationRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - constexpr auto correlationColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, correlationColMajor); - correlationColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void correlation(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const DataT* x2n, - const DataT* y2n, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - correlationImpl( - x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - correlationImpl( - x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - correlationImpl( - x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} - /** * @brief the Correlation distance matrix calculation * @@ -236,11 +60,6 @@ void correlationImpl(int m, cudaStream_t stream, bool isRowMajor) { - typedef std::is_same is_bool; - typedef typename std::conditional::type correlationOutType; - Index_ lda, ldb, ldd; - correlationOutType* pDcast = reinterpret_cast(pD); - ASSERT(!(((pA != pB) && (worksize < 2 * (m + n) * sizeof(AccType))) || (worksize < 2 * m * sizeof(AccType))), "workspace size error"); @@ -297,41 +116,10 @@ void correlationImpl(int m, raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream); } - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - correlation(m, - n, - k, - lda, - ldb, - ldd, - pA, - pB, - norm_col_vec, - norm_row_vec, - sq_norm_col_vec, - sq_norm_row_vec, - pDcast, - fin_op, - stream); - } else { - lda = n, ldb = m, ldd = m; - correlation(n, - m, - k, - lda, - ldb, - ldd, - pB, - pA, - norm_row_vec, - norm_col_vec, - sq_norm_row_vec, - sq_norm_col_vec, - pDcast, - fin_op, - stream); - } + using CorrOp = ops::correlation_distance_op; + CorrOp corr_op(isRowMajor, sq_norm_col_vec, sq_norm_row_vec, m, n, k); + distance_matrix_dispatch( + corr_op, m, n, k, pA, pB, norm_col_vec, norm_row_vec, pD, fin_op, stream, isRowMajor); } } // namespace detail diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh new file mode 100644 index 0000000000..98d90ea0a5 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::distance::detail::ops { + +// Describes the computation the correlation distance + + +template +struct correlation_distance_op { + const DataT_struct* x2n; + const DataT_struct* y2n; + IdxT_struct m; + IdxT_struct n; + IdxT_struct k; + + correlation_distance_op( + bool is_row_major, + const DataT_struct* x2n_, + const DataT_struct* y2n_, + IdxT_struct m_, + IdxT_struct n_, + IdxT_struct k_ + ) noexcept + : x2n(x2n_), + y2n(y2n_), + m(m_), + n(n_), + k(k_) + { + // The distance op is typically created before the row-major/col-major + // swapping has been done. So we do it here. + if (!is_row_major) { + std::swap(x2n, y2n); + std::swap(m, n); + } + } + + + // Load norms of input data + static constexpr bool use_norms = true; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += x * y; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + // Note how we can sneakily get a pointer to shared memory here, to store + // more data. If the implementation of PairwiseDistanceMatKernel ever + // changes, this will be where we find the bugs. + extern __shared__ char smem[]; + + DataT regx2n[Policy::AccRowsPerTh], regy2n[Policy::AccColsPerTh]; + + DataT* sx2Norm = + (DataT*)(&smem[Policy::SmemSize + (Policy::Mblk + Policy::Nblk) * sizeof(DataT)]); + DataT* sy2Norm = (&sx2Norm[Policy::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (gridStrideX == blockIdx.x * Policy::Nblk) { + for (int i = threadIdx.x; i < Policy::Mblk; i += Policy::Nthreads) { + auto idx = gridStrideY + i; + sx2Norm[i] = idx < m ? x2n[idx] : 0; + } + } + + for (int i = threadIdx.x; i < Policy::Nblk; i += Policy::Nthreads) { + auto idx = gridStrideX + i; + sy2Norm[i] = idx < n ? y2n[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + regx2n[i] = sx2Norm[i * Policy::AccThRows + (threadIdx.x / Policy::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < Policy::AccColsPerTh; ++i) { + regy2n[i] = sy2Norm[i * Policy::AccThCols + (threadIdx.x % Policy::AccThCols)]; + } + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + auto numer = k * acc[i][j] - (regxn[i] * regyn[j]); + auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); + auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); + + acc[i][j] = 1 - (numer / raft::sqrt(Q_denom * R_denom)); + } + } + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index 4c624c5593..98c35c6295 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -22,7 +22,7 @@ namespace raft::distance::detail::ops { // // Fill in the TODO items. -struct template_op { +struct template_distance_op { // Load norms of input data static constexpr bool use_norms = TODO; From 7005a4f2361d77ad7e5001d98f10fd4477f8c669 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 19:40:50 +0100 Subject: [PATCH 20/93] Hamming: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 2 +- .../distance/detail/distance_ops/hamming.cuh | 64 +++++++++++++++++++ cpp/include/raft/distance/detail/hamming.cuh | 43 ++++++------- 3 files changed, 83 insertions(+), 26 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/hamming.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index dc66a55f60..223a50bee1 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -5,7 +5,7 @@ - [X] chebyshev.cuh - [X] correlation.cuh - [ ] cosine.cuh -- [ ] hamming.cuh +- [X] hamming.cuh - [ ] hellinger.cuh - [ ] jensen_shannon.cuh - [ ] kl_divergence.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh new file mode 100644 index 0000000000..1f88424d70 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::distance::detail::ops { + +// Describes the computation the hamming distance + +template +struct hamming_distance_op { + IdxT_struct k; + + hamming_distance_op(IdxT_struct k_) noexcept : k(k_) { } + + // Load norms of input data + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += (x != y); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + const DataT one_over_k = DataT(1.0) / k; +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] *= one_over_k; + } + } + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh index bed9d09e3e..7d283def21 100644 --- a/cpp/include/raft/distance/detail/hamming.cuh +++ b/cpp/include/raft/distance/detail/hamming.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include "distance_ops/hamming.cuh" +#include "pairwise_matrix/dispatch.cuh" namespace raft { namespace distance { @@ -178,36 +179,28 @@ void hammingUnexpanded(IdxT m, * @param stream cuda stream where to launch work * @param isRowMajor whether the input and output matrices are row major */ -template +template void hammingUnexpandedImpl(int m, int n, int k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, cudaStream_t stream, - bool isRowMajor) + bool is_row_major) { - typedef std::is_same is_bool; - typedef - typename std::conditional::type hammingUnexpandedOutType; - Index_ lda, ldb, ldd; - hammingUnexpandedOutType* pDcast = reinterpret_cast(pD); - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - hammingUnexpanded( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + ops::hamming_distance_op distance_op{k}; - } else { - lda = n, ldb = m, ldd = m; - hammingUnexpanded( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); - } + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } } // namespace detail From 7831debb75eea89406ade619aab81890917be9a0 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 20:07:46 +0100 Subject: [PATCH 21/93] Hellinger: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 2 +- .../detail/distance_ops/hellinger.cuh | 66 ++++++ .../distance/detail/distance_ops/template.cuh | 4 + .../raft/distance/detail/hellinger.cuh | 219 +++--------------- 4 files changed, 107 insertions(+), 184 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/hellinger.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index 223a50bee1..47239d3f69 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -6,7 +6,7 @@ - [X] correlation.cuh - [ ] cosine.cuh - [X] hamming.cuh -- [ ] hellinger.cuh +- [X] hellinger.cuh - [ ] jensen_shannon.cuh - [ ] kl_divergence.cuh - [ ] minkowski.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh new file mode 100644 index 0000000000..b01f118923 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft::distance::detail::ops { + +// Describes the computation the hellinger distance +// +// Fill in the TODO items. + +struct hellinger_distance_op { + // Load norms of input data + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + // This is sqrt(x) * sqrt(y). + const auto product = x * y; + acc += product; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + const auto finalVal = (1 - acc[i][j]); + const auto rectifier = (!signbit(finalVal)); + acc[i][j] = raft::sqrt(rectifier * finalVal); + } + } + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index 98c35c6295..c770a575a0 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -23,6 +23,10 @@ namespace raft::distance::detail::ops { // Fill in the TODO items. struct template_distance_op { + TODO member; + + template_distance_op(TODO member_) noexcept : member(member_) { } + // Load norms of input data static constexpr bool use_norms = TODO; diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 13507fe84f..306977f266 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -15,173 +15,16 @@ */ #pragma once -#include +#include #include -#include + +#include "pairwise_matrix/dispatch.cuh" +#include "distance_ops/hellinger.cuh" namespace raft { namespace distance { namespace detail { -/** - * @brief the Hellinger distance matrix using the expanded form: - * It computes the following equation: - cij = sqrt(1 - sum(sqrt(x_k * y_k))) - * This distance computation modifies A and B by computing a sqrt - * and then performing a `pow(x, 2)` to convert it back. Because of this, - * it is possible that the values in A and B might differ slightly - * after this is invoked. - * - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] dOutput output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - */ -template -static void hellingerImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // First sqrt x and y - raft::linalg::unaryOp((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { - raft::linalg::unaryOp((DataT*)y, y, n * k, raft::sqrt_op{}, stream); - } - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - // This is sqrt(x) * sqrt(y). - const auto product = x * y; - acc += product; - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - const auto finalVal = (1 - acc[i][j]); - const auto rectifier = (!signbit(finalVal)); - acc[i][j] = raft::sqrt(rectifier * finalVal); - } - } - }; - - if (isRowMajor) { - auto hellingerRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, hellingerRowMajor); - - hellingerRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto hellingerColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, hellingerColMajor); - hellingerColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - // Revert sqrt of x and y - raft::linalg::unaryOp((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { - raft::linalg::unaryOp((DataT*)y, y, n * k, raft::sqrt_op{}, stream); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void hellinger(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - hellingerImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - hellingerImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - hellingerImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} - /** * @brief the Hellinger distance matrix calculation * It computes the following equation: @@ -206,35 +49,45 @@ void hellinger(IdxT m, * @param stream cuda stream where to launch work * @param isRowMajor whether the input and output matrices are row major */ -template +template void hellingerImpl(int m, int n, int k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, cudaStream_t stream, - bool isRowMajor) + bool is_row_major) { - typedef std::is_same is_bool; - typedef typename std::conditional::type hellingerOutType; - Index_ lda, ldb, ldd; - hellingerOutType* pDcast = reinterpret_cast(pD); - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - hellinger( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + // First sqrt x and y + const auto raft_sqrt = raft::linalg::unaryOp; + + raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); + if (x != y) { + raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); + } - } else { - lda = n, ldb = m, ldd = m; - hellinger( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + // Then calculate Hellinger distance + ops::hellinger_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + + // Finally revert sqrt of x and y + raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); + if (x != y) { + raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } + + RAFT_CUDA_TRY(cudaGetLastError()); } } // namespace detail } // namespace distance From 4dc72ce00b25e077e0654ae4f05f8b8c3b3e7789 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 20:14:28 +0100 Subject: [PATCH 22/93] Jensen-Shannon: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 2 +- .../detail/distance_ops/jensen_shannon.cuh | 65 ++++++ .../raft/distance/detail/jensen_shannon.cuh | 188 ++---------------- 3 files changed, 85 insertions(+), 170 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index 47239d3f69..4f18391fce 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -7,7 +7,7 @@ - [ ] cosine.cuh - [X] hamming.cuh - [X] hellinger.cuh -- [ ] jensen_shannon.cuh +- [X] jensen_shannon.cuh - [ ] kl_divergence.cuh - [ ] minkowski.cuh - [ ] russell_rao.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh new file mode 100644 index 0000000000..116af61964 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft::distance::detail::ops { + +// Describes the computation the jensen_shannon distance + +struct jensen_shannon_distance_op { + // Load norms of input data + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const DataT m = 0.5f * (x + y); + const bool m_zero = (m == 0); + const auto logM = (!m_zero) * raft::log(m + m_zero); + + const bool x_zero = (x == 0); + const bool y_zero = (y == 0); + acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(0.5 * acc[i][j]); + } + } + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh index f96da01b87..71339e0c1a 100644 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -15,157 +15,14 @@ */ #pragma once -#include +#include "distance_ops/jensen_shannon.cuh" +#include "pairwise_matrix/dispatch.cuh" + namespace raft { namespace distance { namespace detail { -/** - * @brief the Jensen Shannon distance matrix: - * It computes the following equation: - Cij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) - + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) - * - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] dOutput output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - */ -template -static void jensenShannonImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const DataT m = 0.5f * (x + y); - const bool m_zero = (m == 0); - const auto logM = (!m_zero) * raft::log(m + m_zero); - - const bool x_zero = (x == 0); - const bool y_zero = (y == 0); - acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(0.5 * acc[i][j]); - } - } - }; - - if (isRowMajor) { - auto jensenShannonRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, jensenShannonRowMajor); - - jensenShannonRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto jensenShannonColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, jensenShannonColMajor); - jensenShannonColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void jensenShannon(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - jensenShannonImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - jensenShannonImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - jensenShannonImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} - /** * @brief the Jensen Shannon distance matrix calculation * It computes the following equation: @@ -187,35 +44,28 @@ void jensenShannon(IdxT m, * @param stream cuda stream where to launch work * @param isRowMajor whether the input and output matrices are row major */ -template +template void jensenShannonImpl(int m, int n, int k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, cudaStream_t stream, - bool isRowMajor) + bool is_row_major) { - typedef std::is_same is_bool; - typedef typename std::conditional::type jensenShannonOutType; - Index_ lda, ldb, ldd; - jensenShannonOutType* pDcast = reinterpret_cast(pD); - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - jensenShannon( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + ops::jensen_shannon_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; - } else { - lda = n, ldb = m, ldd = m; - jensenShannon( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); - } + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } } // namespace detail } // namespace distance From b0d36c1cc4d2f34bdc8ff3a6281ad0b7145b2b32 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 20:53:29 +0100 Subject: [PATCH 23/93] remove old hamming code --- cpp/include/raft/distance/detail/hamming.cuh | 137 ------------------- 1 file changed, 137 deletions(-) diff --git a/cpp/include/raft/distance/detail/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh index 7d283def21..9935c96a40 100644 --- a/cpp/include/raft/distance/detail/hamming.cuh +++ b/cpp/include/raft/distance/detail/hamming.cuh @@ -22,143 +22,6 @@ namespace raft { namespace distance { namespace detail { -/** - * @brief the Hamming distance matrix using the unexpanded form: - * It computes the following equation: - Cij = sum(x_i != y_i) / k - * - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] dOutput output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - */ -template -static void hammingUnexpandedImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += (x != y); }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [k] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - const DataT one_over_k = DataT(1.0) / k; -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] *= one_over_k; - } - } - }; - - if (isRowMajor) { - auto hammingUnexpandedRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, hammingUnexpandedRowMajor); - - hammingUnexpandedRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto hammingUnexpandedColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, hammingUnexpandedColMajor); - hammingUnexpandedColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void hammingUnexpanded(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - hammingUnexpandedImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - hammingUnexpandedImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - hammingUnexpandedImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} - /** * @brief the Hamming Unexpanded distance matrix calculation * It computes the following equation: From e95a65bbee829ede8cfc3645f41bc512d38afdab Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 20:55:35 +0100 Subject: [PATCH 24/93] KL divergence: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 6 +- .../detail/distance_ops/kl_divergence.cuh | 88 +++++ .../raft/distance/detail/kl_divergence.cuh | 329 +++--------------- 3 files changed, 135 insertions(+), 288 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index 4f18391fce..f5d3b6b0a6 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -8,6 +8,10 @@ - [X] hamming.cuh - [X] hellinger.cuh - [X] jensen_shannon.cuh -- [ ] kl_divergence.cuh +- [X] kl_divergence.cuh + - *Notes*: the isRowMajor and x_equal_y boolean parameters where previously + template / constexpr parameters. Now they are passed by value. This greatly + reduces the number of kernels, but may have negative consequences for run + time. - [ ] minkowski.cuh - [ ] russell_rao.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh new file mode 100644 index 0000000000..a1f438e0d4 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft::distance::detail::ops { + +// Describes the computation of the kl_divergence +struct kl_divergence_op { + const bool is_row_major; + const bool x_equal_y; + + kl_divergence_op( + bool row_major_, + bool x_equal_y_=false + ) noexcept + : is_row_major(row_major_), + x_equal_y(x_equal_y_) + { } + + // Load norms of input data + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + // TODO: make sure that these branches get hoisted out of main loop.. Could + // be quite expensive otherwise. + if (x_equal_y) { + if (is_row_major) { + const bool x_zero = (x == 0); + const bool y_zero = (y == 0); + acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); + } else { + const bool y_zero = (y == 0); + const bool x_zero = (x == 0); + acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); + } + } else { + if (is_row_major) { + const bool x_zero = (x == 0); + acc += x * (raft::log(x + x_zero) - y); + } else { + const bool y_zero = (y == 0); + acc += y * (raft::log(y + y_zero) - x); + } + } + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = (0.5f * acc[i][j]); + } + } + } +}; +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index 7ebeaf4de9..e76cd5a3b9 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -15,276 +15,16 @@ */ #pragma once -#include +#include +#include + +#include "distance_ops/kl_divergence.cuh" +#include "pairwise_matrix/dispatch.cuh" namespace raft { namespace distance { namespace detail { -/** - * @brief the KL Divergence distance matrix: - * It computes the following equation: - Cij = 0.5 * sum(x * log (x / y)); - * This distance computation modifies A or B by computing a log(x) - * and then performing a `pow(e, log(x))` to convert it back. Because of this, - * it is possible that the values in A or B might differ slightly - * after this is invoked. - * - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] dOutput output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - */ -template -static void klDivergenceImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - if (isRowMajor) { - const bool x_zero = (x == 0); - acc += x * (raft::log(x + x_zero) - y); - } else { - const bool y_zero = (y == 0); - acc += y * (raft::log(y + y_zero) - x); - } - }; - - auto core_lambda_x_equal_y = [] __device__(AccT & acc, DataT & x, DataT & y) { - if (isRowMajor) { - const bool x_zero = (x == 0); - const bool y_zero = (y == 0); - acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); - } else { - const bool y_zero = (y == 0); - const bool x_zero = (x == 0); - acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); - } - }; - - auto unaryOp_lambda = [] __device__(DataT input) { - const bool x_zero = (input == 0); - return (!x_zero) * raft::log(input + x_zero); - }; - - auto unaryOp_lambda_reverse = [] __device__(DataT input) { - // reverse previous log (x) back to x using (e ^ log(x)) - const bool x_zero = (input == 0); - return (!x_zero) * raft::exp(input); - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = (0.5f * acc[i][j]); - } - } - }; - - if (isRowMajor) { - constexpr auto klDivergenceRowMajor = pairwiseDistanceMatKernel; - constexpr auto klDivergenceRowMajorXequalY = - pairwiseDistanceMatKernel; - if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, klDivergenceRowMajor); - klDivergenceRowMajor<<>>(x, - y, - nullptr, - nullptr, - m, - n, - k, - lda, - ldb, - ldd, - dOutput, - core_lambda, - epilog_lambda, - fin_op); - // Now reverse previous log (x) back to x using (e ^ log(x)) - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda_reverse, stream); - } else { - dim3 grid = - launchConfigGenerator(m, n, KPolicy::SmemSize, klDivergenceRowMajorXequalY); - klDivergenceRowMajorXequalY<<>>(x, - y, - nullptr, - nullptr, - m, - n, - k, - lda, - ldb, - ldd, - dOutput, - core_lambda_x_equal_y, - epilog_lambda, - fin_op); - } - } else { - constexpr auto klDivergenceColMajor = pairwiseDistanceMatKernel; - constexpr auto klDivergenceColMajorXequalY = - pairwiseDistanceMatKernel; - if (x != y) { - raft::linalg::unaryOp( - (DataT*)x, x, m * k, unaryOp_lambda, stream); - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, klDivergenceColMajor); - klDivergenceColMajor<<>>(x, - y, - nullptr, - nullptr, - m, - n, - k, - lda, - ldb, - ldd, - dOutput, - core_lambda, - epilog_lambda, - fin_op); - // Now reverse previous log (x) back to x using (e ^ log(x)) - raft::linalg::unaryOp( - (DataT*)x, x, m * k, unaryOp_lambda_reverse, stream); - } else { - dim3 grid = - launchConfigGenerator(m, n, KPolicy::SmemSize, klDivergenceColMajorXequalY); - klDivergenceColMajorXequalY<<>>(x, - y, - nullptr, - nullptr, - m, - n, - k, - lda, - ldb, - ldd, - dOutput, - core_lambda_x_equal_y, - epilog_lambda, - fin_op); - } - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void klDivergence(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - klDivergenceImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - klDivergenceImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - klDivergenceImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} - /** * @brief the KL Divergence distance matrix calculation * It computes the following equation: @@ -308,34 +48,49 @@ void klDivergence(IdxT m, * @param stream cuda stream where to launch work * @param isRowMajor whether the input and output matrices are row major */ -template +template void klDivergenceImpl(int m, int n, int k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, cudaStream_t stream, - bool isRowMajor) + bool is_row_major) { - typedef std::is_same is_bool; - typedef typename std::conditional::type klDivergenceOutType; - Index_ lda, ldb, ldd; - klDivergenceOutType* pDcast = reinterpret_cast(pD); - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - klDivergence( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + auto unaryOp_lambda = [] __device__(DataT input) { + const bool x_zero = (input == 0); + return (!x_zero) * raft::myLog(input + x_zero); }; + + auto unaryOp_lambda_reverse = [] __device__(DataT input) { + // reverse previous log (x) back to x using (e ^ log(x)) + const bool x_zero = (input == 0); + return (!x_zero) * raft::myExp(input); }; + + // This op takes some shortcuts when x equals y. So it behavior changes based + // on this. + ops::kl_divergence_op kl_divergence{is_row_major, x == y}; + + if (x != y) { + raft::linalg::unaryOp( + (DataT*)y, y, n * k, unaryOp_lambda, stream); + } + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); - } else { - lda = n, ldb = m, ldd = m; - klDivergence( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + if (x != y) { + // Now reverse previous log (x) back to x using (e ^ log(x)) + raft::linalg::unaryOp( + (DataT*)y, y, n * k, unaryOp_lambda_reverse, stream); } } } // namespace detail From f1c105bd070483d032edbdc8e0b356879f9a89b4 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 21:09:39 +0100 Subject: [PATCH 25/93] Minkowski: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 6 +- .../detail/distance_ops/minkowski.cuh | 66 ++++++ .../raft/distance/detail/minkowski.cuh | 192 ++---------------- 3 files changed, 92 insertions(+), 172 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/minkowski.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index f5d3b6b0a6..a82cc9a0e3 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -1,6 +1,8 @@ #+title: Readme - +- [X] Euclidean + - *Notes*: isRowMajor is now a runtime parameter. Was it a compile time + parameter before? - [X] canberra.cuh - [X] chebyshev.cuh - [X] correlation.cuh @@ -13,5 +15,5 @@ template / constexpr parameters. Now they are passed by value. This greatly reduces the number of kernels, but may have negative consequences for run time. -- [ ] minkowski.cuh +- [X] minkowski.cuh - [ ] russell_rao.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh b/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh new file mode 100644 index 0000000000..11be4e6ae0 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft::distance::detail::ops { + +// Describes the computation the minkowski distance + +template +struct minkowski_distance_op { + DataT_struct p; + + minkowski_distance_op(DataT_struct p_) noexcept : p(p_) { } + + // Load norms of input data + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const auto diff = raft::abs(x - y); + acc += raft::pow(diff, p); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + const auto one_over_p = 1.0f / p; +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = raft::pow(acc[i][j], one_over_p); + } + } + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index 42af8cd281..778ceb45cf 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -15,154 +15,13 @@ */ #pragma once -#include +#include "pairwise_matrix/dispatch.cuh" +#include "distance_ops/minkowski.cuh" namespace raft { namespace distance { namespace detail { -/** - * @brief the unexpanded Minkowski distance matrix calculation - * It computes the following equation: cij = sum(|x - y|^p)^(1/p) - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and cols of C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - * @param[in] the value of `p` for Minkowski (l-p) distances. - */ -template -void minkowskiUnExpImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream, - DataT p) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [p] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::abs(x - y); - acc += raft::pow(diff, p); - }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [p] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - const auto one_over_p = 1.0f / p; -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::pow(acc[i][j], one_over_p); - } - } - }; - - if (isRowMajor) { - auto minkowskiUnExpRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, minkowskiUnExpRowMajor); - - minkowskiUnExpRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - - } else { - auto minkowskiUnExpColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, minkowskiUnExpColMajor); - - minkowskiUnExpColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void minkowskiUnExp(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream, - DataT metric_arg) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - minkowskiUnExpImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream, metric_arg); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - minkowskiUnExpImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream, metric_arg); - } else { - minkowskiUnExpImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream, metric_arg); - } -} - /** * @brief the unexpanded minkowski distance matrix calculation * It computes the following equation: cij = sum(|x - y|^p)^(1/p) @@ -182,36 +41,29 @@ void minkowskiUnExp(IdxT m, * @param[in] isRowMajor whether the input and output matrices are row major * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. */ -template -void minkowskiImpl(Index_ m, - Index_ n, - Index_ k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, +template +void minkowskiImpl(IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, cudaStream_t stream, - bool isRowMajor, - InType metric_arg) + bool is_row_major, + DataT metric_arg) { - typedef std::is_same is_bool; - typedef typename std::conditional::type LpUnexpOutType; - LpUnexpOutType* pDcast = reinterpret_cast(pD); - Index_ lda, ldb, ldd; + ops::minkowski_distance_op distance_op{metric_arg}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - minkowskiUnExp( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream, metric_arg); - } else { - lda = n, ldb = m, ldd = m; - minkowskiUnExp( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream, metric_arg); - } + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } }; // end namespace detail }; // end namespace distance From ac66e3f9bc4109eef8be103c408ff637c5f708d5 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 21:19:38 +0100 Subject: [PATCH 26/93] Russel-Rao: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 2 +- .../detail/distance_ops/russel_rao.cuh | 67 +++++++ .../raft/distance/detail/russell_rao.cuh | 180 ++---------------- 3 files changed, 86 insertions(+), 163 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index a82cc9a0e3..4c2005381c 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -16,4 +16,4 @@ reduces the number of kernels, but may have negative consequences for run time. - [X] minkowski.cuh -- [ ] russell_rao.cuh +- [X] russell_rao.cuh diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh new file mode 100644 index 0000000000..d4d1044b6e --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::distance::detail::ops { + +// Describes the computation the russel_rao distance + +template +struct russel_rao_distance_op { + IdxT_struct k; + const float one_over_k; + + russel_rao_distance_op(IdxT_struct k_) noexcept + : k(k_), + one_over_k(1.0f / k_) + { } + + // Load norms of input data + static constexpr bool use_norms = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += x * y; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = (k - acc[i][j]) * one_over_k; + } + } + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/russell_rao.cuh b/cpp/include/raft/distance/detail/russell_rao.cuh index 5d516e7830..5e8da08b1d 100644 --- a/cpp/include/raft/distance/detail/russell_rao.cuh +++ b/cpp/include/raft/distance/detail/russell_rao.cuh @@ -15,150 +15,13 @@ */ #pragma once -#include +#include "distance_ops/russel_rao.cuh" +#include "pairwise_matrix/dispatch.cuh" namespace raft { namespace distance { namespace detail { -/** - * @brief the Russell Rao distance matrix: - * It computes the following equation: - Cij = (k - sum(x_i * y_i)) / k - * - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread - for every LDG call. details in contractions.cuh - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and C/D - * @param[in] k number of cols of A and B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] dOutput output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - */ -template -static void russellRaoImpl(const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; - - const float one_over_k = 1.0 / k; - // epilogue operation lambda for final value calculation - auto epilog_lambda = [k, one_over_k] __device__( - AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = (k - acc[i][j]) * one_over_k; - } - } - }; - - if (isRowMajor) { - constexpr auto russellRaoRowMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, russellRaoRowMajor); - - russellRaoRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - constexpr auto russellRaoColMajor = pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, russellRaoColMajor); - russellRaoColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void russellRao(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - russellRaoImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - russellRaoImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - russellRaoImpl( - x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} - /** * @brief the Russell Rao distance matrix calculation * It computes the following equation: @@ -179,35 +42,28 @@ void russellRao(IdxT m, * @param stream cuda stream where to launch work * @param isRowMajor whether the input and output matrices are row major */ -template +template void russellRaoImpl(int m, int n, int k, - const InType* pA, - const InType* pB, - OutType* pD, - FinalLambda fin_op, + const DataT* x, + const DataT* y, + OutT* out, + FinOpT fin_op, cudaStream_t stream, - bool isRowMajor) + bool is_row_major) { - typedef std::is_same is_bool; - typedef typename std::conditional::type russellRaoOutType; - Index_ lda, ldb, ldd; - russellRaoOutType* pDcast = reinterpret_cast(pD); - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - russellRao( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + ops::russel_rao_distance_op distance_op{k}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; - } else { - lda = n, ldb = m, ldd = m; - russellRao( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); - } + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } } // namespace detail } // namespace distance From a89896a456d72592a8490a057d325b4d45fea930 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 21:45:38 +0100 Subject: [PATCH 27/93] Cosine: use pairwise matrix dispatch --- cpp/include/raft/distance/detail/README.org | 12 +- cpp/include/raft/distance/detail/cosine.cuh | 261 ++++-------------- .../distance/detail/distance_ops/cosine.cuh | 70 +++++ .../raft/distance/detail/euclidean.cuh | 1 - 4 files changed, 135 insertions(+), 209 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/cosine.cuh diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index 4c2005381c..03d540cb84 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -1,12 +1,18 @@ #+title: Readme - [X] Euclidean - - *Notes*: isRowMajor is now a runtime parameter. Was it a compile time - parameter before? + - *Notes*: + - enable_sqrt is now a runtime parameter. Was it a compile time + parameter before? + - CUTLASS fails on CUDA 12 (but prior to refactoring CUDA 12 did not work + either). I have not yet tested if everything works correctly on CUDA 11. - [X] canberra.cuh - [X] chebyshev.cuh - [X] correlation.cuh -- [ ] cosine.cuh +- [X] cosine.cuh + - *Notes*: cutlass fails on CUDA 12 (but prior to refactoring CUDA 12 did not + work either). I have not yet tested if everything works correctly on + CUDA 11. - [X] hamming.cuh - [X] hellinger.cuh - [X] jensen_shannon.cuh diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index 46a694aa51..ea1dd64933 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -15,181 +15,15 @@ */ #pragma once - -#include -#include #include -#include + +#include "pairwise_matrix/dispatch.cuh" +#include "distance_ops/cosine.cuh" namespace raft { namespace distance { namespace detail { -template -struct CosineOp { - __device__ CosineOp() noexcept {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); - } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } -}; - -/** - * @brief the cosine distance matrix calculation implementer - * It computes the following equation: - * C = 1 - op(A * B / sqrt(A^2) * sqrt(B^2))) - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Veclen number of k-elements loaded by each thread for every LDG call - * it makes. check contractions.cuh for details. - * @tparam FinalLambda the final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major, - false for column major - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] xn row norms of input matrix A. - * @param[in] yn row norms of input matrix B. - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix - * @param fin_op the final gemm epilogue lambda -* @param stream cuda stream to launch cuda operations. - */ -template -void cosineImpl(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ -#if (__CUDACC_VER_MAJOR__ < 12) - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { - using CosineOp_ = CosineOp; - CosineOp_ cosine_dist_op; - - cutlassDistanceKernel( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, cosine_dist_op, stream); - - } else -#endif - { - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - - typedef typename std::conditional::type KPolicy; - - dim3 blk(KPolicy::Nthreads); - - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { -#pragma unroll - for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j])); - } - } - }; - - constexpr size_t shmemSize = - KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); - if (isRowMajor) { - auto cosineRowMajor = pairwiseDistanceMatKernelPriorToAmpere; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); - cosineRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } else { - auto cosineColMajor = pairwiseDistanceMatKernelPriorToAmpere; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); - cosineColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); - } - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void cosine(IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - OutT* dOutput, - FinalLambda fin_op, - cudaStream_t stream) -{ - size_t bytesA = sizeof(DataT) * lda; - size_t bytesB = sizeof(DataT) * ldb; - if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - cosineImpl( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - cosineImpl( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } else { - cosineImpl( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); - } -} - /** * @brief the expanded cosine distance matrix calculation * It computes the following equation: @@ -213,57 +47,74 @@ void cosine(IdxT m, * @param stream cuda stream where to launch work * @param isRowMajor whether the input and output matrices are row major */ -template -void cosineAlgo1(Index_ m, - Index_ n, - Index_ k, - const InType* pA, - const InType* pB, - OutType* pD, - AccType* workspace, - size_t worksize, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor) +template +void cosineAlgo1(IdxT m, + IdxT n, + IdxT k, + const DataT* pA, + const DataT* pB, + OutT* pD, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool isRowMajor) { // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), - "OutType can be uint8_t, float, double," - "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); - typedef typename std::conditional::type CosOutType; - CosOutType* pDcast = reinterpret_cast(pD); + static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), + "OutT can be uint8_t, float, double," + "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); ASSERT( - !(((pA != pB) && (worksize < (m + n) * sizeof(AccType))) || (worksize < m * sizeof(AccType))), + !(((pA != pB) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); - Index_ lda, ldb, ldd; - InType* col_vec = workspace; - InType* row_vec = workspace; + + DataT* norm_A = workspace; + DataT* norm_B = workspace; if (pA != pB) { - row_vec += m; + norm_B += m; raft::linalg::rowNorm( - col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); + norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); raft::linalg::rowNorm( - row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); + norm_B, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); } else { raft::linalg::rowNorm( - col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); + norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); } - if (isRowMajor) { - lda = k, ldb = k, ldd = n; - cosine( - m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, fin_op, stream); + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel otherwise. + + if constexpr (__CUDACC_VER_MAJOR__ == 12) { + // Always execute legacy kernels on CUDA 12 + ops::cosine_distance_op distance_op{}; + distance_matrix_dispatch( + distance_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); } else { - lda = n, ldb = m, ldd = m; - cosine( - n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast, fin_op, stream); + const auto deviceVersion = getComputeCapability(); + if (deviceVersion.first >= 8) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using Op = ops::cosine_cutlass_op; + Op distance_op{}; + + distance_matrix_cutlass_dispatch( + distance_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); + } else { + // Else use "legacy" L2 + ops::cosine_distance_op distance_op{}; + distance_matrix_dispatch( + distance_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); + } } } diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh new file mode 100644 index 0000000000..c2679d5380 --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::distance::detail::ops { + +// Describes the computation the cosine distance + +struct cosine_distance_op { + // Load norms of input data + static constexpr bool use_norms = true; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + constexpr size_t shared_mem_size() + { + return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + } + + template + DI void core(AccT& acc, DataT& x, DataT& y) const + { + acc += x * y; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j])); + } + } + } +}; + + +template +struct cosine_cutlass_op { + __device__ cosine_cutlass_op() noexcept {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); + } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 51e2ff224f..3cdc5489a6 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -22,7 +22,6 @@ #include "distance_ops/l2_exp.cuh" #include "distance_ops/l2_unexp.cuh" - namespace raft { namespace distance { namespace detail { From 16b2acdc55da1f1877b086233399361418f539df Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 13 Jan 2023 21:49:39 +0100 Subject: [PATCH 28/93] Fix include for l1 op --- cpp/include/raft/distance/detail/distance_ops/l1.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 9d31b24851..7153154588 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include namespace raft::distance::detail::ops { From 1326e3408254b8ad9562541c0d9d35dc00241cb2 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 10 Feb 2023 11:21:35 +0100 Subject: [PATCH 29/93] kl_divergence: Use raft::log instead of raft::myLog --- cpp/include/raft/distance/detail/kl_divergence.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index e76cd5a3b9..e2f7bf2beb 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -65,14 +65,14 @@ void klDivergenceImpl(int m, { auto unaryOp_lambda = [] __device__(DataT input) { const bool x_zero = (input == 0); - return (!x_zero) * raft::myLog(input + x_zero); }; + return (!x_zero) * raft::log(input + x_zero); }; auto unaryOp_lambda_reverse = [] __device__(DataT input) { // reverse previous log (x) back to x using (e ^ log(x)) const bool x_zero = (input == 0); - return (!x_zero) * raft::myExp(input); }; + return (!x_zero) * raft::exp(input); }; - // This op takes some shortcuts when x equals y. So it behavior changes based + // This op takes some shortcuts when x equals y. So its behavior changes based // on this. ops::kl_divergence_op kl_divergence{is_row_major, x == y}; From 0169b26f55b5a4d3ed7db3b05c72fb9a6a7f50e4 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 10 Feb 2023 11:23:53 +0100 Subject: [PATCH 30/93] distance_op: Add expensive_inner_loop marker This indicates that the operator uses expensive operations (pow, exp, log) in the inner loop. Therefore, unrolling and or veclen parameters should be adjusted --- cpp/include/raft/distance/detail/distance_ops/canberra.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/correlation.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/cosine.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/hamming.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/hellinger.cuh | 3 +++ .../raft/distance/detail/distance_ops/jensen_shannon.cuh | 3 +++ .../raft/distance/detail/distance_ops/kl_divergence.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/l1.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/minkowski.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh | 3 +++ cpp/include/raft/distance/detail/distance_ops/template.cuh | 3 +++ 14 files changed, 42 insertions(+) diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 4fda825286..e9c16d6d6d 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -25,6 +25,9 @@ namespace raft::distance::detail::ops { struct canberra_distance_op { // Load norms of input data static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = true; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh b/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh index ced9fcf6f7..a68d9fc21c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh @@ -25,6 +25,9 @@ namespace raft::distance::detail::ops { struct chebyshev_distance_op { // Load norms of input data static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 98d90ea0a5..eb18355ca9 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -54,6 +54,9 @@ struct correlation_distance_op { // Load norms of input data static constexpr bool use_norms = true; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index c2679d5380..bbc1ffcba2 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -23,6 +23,9 @@ namespace raft::distance::detail::ops { struct cosine_distance_op { // Load norms of input data static constexpr bool use_norms = true; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 1f88424d70..c8b3b7658e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -28,6 +28,9 @@ struct hamming_distance_op { // Load norms of input data static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index b01f118923..b0fae700b5 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -26,6 +26,9 @@ namespace raft::distance::detail::ops { struct hellinger_distance_op { // Load norms of input data static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index 116af61964..124010e96d 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -24,6 +24,9 @@ namespace raft::distance::detail::ops { struct jensen_shannon_distance_op { // Load norms of input data static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = true; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index a1f438e0d4..a97582aa5a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -34,6 +34,9 @@ struct kl_divergence_op { // Load norms of input data static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = true; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 7153154588..4bb4a8796c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -23,6 +23,9 @@ namespace raft::distance::detail::ops { struct l1_distance_op { // Do not load norms of data, the computation of L1 distance does not use them. static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 4dfb26a826..13a41190c1 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -28,6 +28,9 @@ struct l2_exp_distance_op { // Load norms of input data static constexpr bool use_norms = true; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index 03bbd936c6..31fbd11667 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -26,6 +26,9 @@ struct l2_unexp_distance_op { // Do not load norms of data, the computation of L1 distance does not use them. static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh b/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh index 11be4e6ae0..8deb42d1fe 100644 --- a/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh @@ -29,6 +29,9 @@ struct minkowski_distance_op { // Load norms of input data static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = true; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index d4d1044b6e..f46a1a5e67 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -32,6 +32,9 @@ struct russel_rao_distance_op { // Load norms of input data static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index c770a575a0..d7bbfc7fca 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -29,6 +29,9 @@ struct template_distance_op { // Load norms of input data static constexpr bool use_norms = TODO; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. From 52e95e1f255dbddbb5dda6513ce4feff6386bc5e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 10 Feb 2023 11:26:55 +0100 Subject: [PATCH 31/93] Update copyright notices --- cpp/include/raft/distance/detail/cosine.cuh | 2 +- cpp/include/raft/distance/detail/hamming.cuh | 2 +- .../raft/distance/detail/pairwise_distance_cutlass_base.cuh | 2 +- cpp/include/raft/distance/detail/russell_rao.cuh | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index ea1dd64933..4ae0c285f5 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh index 9935c96a40..824e930023 100644 --- a/cpp/include/raft/distance/detail/hamming.cuh +++ b/cpp/include/raft/distance/detail/hamming.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index efd44ea4dc..0d26d940b3 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/russell_rao.cuh b/cpp/include/raft/distance/detail/russell_rao.cuh index 5e8da08b1d..6bf5ae04bb 100644 --- a/cpp/include/raft/distance/detail/russell_rao.cuh +++ b/cpp/include/raft/distance/detail/russell_rao.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 28cd57bffefdf8037ff8e3ea1618efffd76e27ca Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 10 Feb 2023 11:52:05 +0100 Subject: [PATCH 32/93] Reusable dispatch mechanism --- .../detail/pairwise_matrix/dispatch.cuh | 241 ++++++++---------- 1 file changed, 104 insertions(+), 137 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 650c8fa805..1eb2a65d5a 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,99 +15,101 @@ */ #pragma once +#include "kernel_sm60.cuh" #include -#include -#include #include -#include "kernel_sm60.cuh" +#include +#include namespace raft::distance::detail { +/** + * @brief: Computes minimal alignment of row starting elements in 2D array + * + * The 2D matrix x is assumed to be row-major. This function computes the + * minimal alignment in bytes of the first elements of each row. + * Output can be 16, 8, 4, 2, 1. + * + * @param x Base pointer of row-major input matrix + * @param stride Stride in number of element between consecutive rows. + */ template -struct params_dispatch { - int vectorized_load_num_elem = 1; - bool row_major = true; - - template - struct params_constexpr { - static constexpr int vec_len = vl; - static constexpr bool is_row_major = rm; - }; +size_t alignment_of_2d_array(const DataT* x, size_t stride) +{ + auto base = reinterpret_cast(x); + size_t stride_bytes = sizeof(DataT) * stride; - // Turn run-time parameters into compile-time parameters. - // Call the provided function f with these compile-time parameters. - // Returns false if dispatch fails, i.e., if there is no implementation - // for the given runtime parameters. - template - bool dispatch_with_compile_time_params(F&& f) const - { - return convert_vectorized_load_num_elem(f); + for (int align = 16; align >= 0; align /= 2) { + bool base_aligned = base % align == 0; + bool stride_aligned = stride_bytes % align == 0; + if (base_aligned && stride_aligned) { return align; } } + return 1; +} - // Step 1: convert alignment into a compile time constant - template - bool convert_vectorized_load_num_elem(F&& f) const +template +struct alignment_tag { + static constexpr int value = n; +}; + +struct alignment_dispatch { + size_t byte_alignment = 0; + + template + alignment_dispatch(const DataT* x, const DataT* y, size_t ldx, size_t ldy) { - bool fail = false; - switch (vectorized_load_num_elem) { - case 1: return layout<1>(f); - case 2: return layout<2>(f); - case 4: return layout<4>(f); - default: return fail; - }; + size_t align_x = alignment_of_2d_array(x, ldx); + size_t align_y = alignment_of_2d_array(y, ldy); + + byte_alignment = min(align_x, align_y); } - // Step 2: convert layout into a compile time constant - template - bool layout(F&& f) const + template + auto operator()(F&& f) const { - if (row_major) { - return to_compile_time_params(f); - } else { - return to_compile_time_params(f); + switch (byte_alignment) { + case 16: f(alignment_tag<16>()); break; + case 8: f(alignment_tag<8>()); break; + case 4: f(alignment_tag<4>()); break; + case 2: f(alignment_tag<2>()); break; + default: f(alignment_tag<1>()); break; } } +}; - // Step 3: convert compile-time constant into compile-time parameter struct and invoke - // function f with these compile time parameters. - template - bool to_compile_time_params(F&& f) const - { - // Create compile-time parameter type and instantiate a struct; - using ct_params_T = params_constexpr; - ct_params_T compile_time_params{}; +template +struct row_major_tag { + static constexpr int value = rm; +}; - // Dispatch to f - f(compile_time_params); +struct row_major_dispatch { + bool is_row_major_; + row_major_dispatch(bool row_major) : is_row_major_(row_major) {} - bool dispatch_success = true; - return dispatch_success; + template + auto operator()(F&& f) const + { + if (is_row_major_) { + f(row_major_tag()); + } else { + f(row_major_tag()); + } } }; -// Determine the largest number of elements that can be loaded in one -// instruction without causing misalignment errors. -template -int vectorized_load_num_elem(const DataT* x, const DataT* y, IdxT ldx, IdxT ldy) +template +auto join_dispatch(F1&& f1, F2&& f2) { - auto base_x = reinterpret_cast(x); - auto base_y = reinterpret_cast(y); - size_t stride_X = sizeof(DataT) * ldx; // stride in bytes - size_t stride_Y = sizeof(DataT) * ldy; // stride in bytes - - bool base_16B_aligned = base_x % 16 == 0 && base_y % 16 == 0; - bool base_8B_aligned = base_x % 8 == 0 && base_y % 8 == 0; - - bool stride_16B_aligned = stride_X % 16 == 0 && stride_Y % 16 == 0; - bool stride_8B_aligned = stride_X % 8 == 0 && stride_Y % 8 == 0; + const auto lam = [f1, f2](auto f) { + f1([f, f2](auto... args1) { f2([f, args1...](auto... args2) { f(args1..., args2...); }); }); + }; + return lam; +} - if (16 % sizeof(DataT) == 0 && base_16B_aligned && stride_16B_aligned) { - return 16 / sizeof(DataT); - } else if (8 % sizeof(DataT) == 0 && base_8B_aligned && stride_8B_aligned) { - return 8 / sizeof(DataT); - } else { - return 1; - } +template +auto join_dispatch(F1 f1, F2 f2, Fs... fs) +{ + return join_dispatch(join_dispatch(f1, f2), std::forward(fs)...); } template run_time_params{ - vectorized_load_num_elem(x, y, ldx, ldy), // 1. num array elements per load instruction - is_row_major // 2. the layout of x, y, and out - }; + alignment_dispatch d_align(x, y, ldx, ldy); + row_major_dispatch d_row_major(is_row_major); + auto dispatch = join_dispatch(d_align, d_row_major); - // Turn run-time parameters into compile-time parameters. - bool dispatch_success = run_time_params.dispatch_with_compile_time_params( - // We pass a lambda that receives the compile-time parameters and can use these - // to call the correct kernel. - [&](auto p) { - // p has two constexpr members: - // - vec_len - // - is_row_major - - // There is no instruction to load 4 doubles, so we catch this situation - // and load 2 doubles. - constexpr bool load_4_doubles = sizeof(DataT) > 4 && p.vec_len == 4; - constexpr int vec_len = (load_4_doubles) ? 2 : p.vec_len; - - // Determine kernel policy using vec_len and layout - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; - - // Create compile-time template parameter - using KP_T = kernel_params_T; - - return pairwise_matrix( - distance_op, - fin_op, - x, - y, - x_norm, - y_norm, - m, - n, - k, - ldx, - ldy, - ld_out, - out, - stream); - }); - - if (!dispatch_success) { - std::printf("Dispatch error(!)\n"); - // TODO - } + dispatch([&](auto alignment_tag, auto row_major_tag) { + // Compute number of elements that can be loaded in one instruction + // without causing misalignent errors. + constexpr int vec_len_ideal = + (alignment_tag.value % sizeof(DataT) == 0) ? alignment_tag.value / sizeof(DataT) : 1; + + // To keep compile times in check, we only specialize on veclen > 1 when + // the inner loop is relatively cheap (< 5 flops). + constexpr int vec_len = distance_op.expensive_inner_loop ? 1 : vec_len_ideal; + + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename std::conditional::type Policy; + + // Create compile-time template parameter + using KP_T = kernel_params_T; + + return pairwise_matrix( + distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); + }); } template run_time_params{ - vectorized_load_num_elem(x, y, ldx, ldy), - is_row_major - }; + alignment_dispatch d_align(x, y, ldx, ldy); + row_major_dispatch d_row_major(is_row_major); - bool dispatch_success = run_time_params.dispatch_with_compile_time_params( - [&](auto p) { - // Prevent loading 4 doubles in one instruction. - constexpr bool load_4_doubles = sizeof(DataT) > 4 && p.vec_len == 4; - constexpr int vec_len = (load_4_doubles) ? 2 : p.vec_len; + auto dispatch = join_dispatch(d_align, d_row_major); - cutlassDistanceKernel( - x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); - }); + dispatch([&](auto alignment_tag, auto row_major_tag) { + constexpr int vec_len = + (alignment_tag.value % sizeof(DataT) == 0) ? alignment_tag.value / sizeof(DataT) : 1; - if (!dispatch_success) { - std::printf("Dispatch error(!)\n"); - // TODO - } + cutlassDistanceKernel( + x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); + }); } }; // namespace raft::distance::detail From c44aecef0ba07bea2e91f7690d3552b03886ccfa Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 10 Feb 2023 14:36:49 +0100 Subject: [PATCH 33/93] Dispatch mechanism using switch statement I fear the other way was getting too complicated and possibilities for reuse were scarce anyway. --- .../detail/pairwise_matrix/dispatch.cuh | 152 ++++++++---------- 1 file changed, 69 insertions(+), 83 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 1eb2a65d5a..cf95b10960 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -48,68 +48,27 @@ size_t alignment_of_2d_array(const DataT* x, size_t stride) } template -struct alignment_tag { - static constexpr int value = n; -}; +using align_constant = std::integral_constant; -struct alignment_dispatch { - size_t byte_alignment = 0; - - template - alignment_dispatch(const DataT* x, const DataT* y, size_t ldx, size_t ldy) - { - size_t align_x = alignment_of_2d_array(x, ldx); - size_t align_y = alignment_of_2d_array(y, ldy); - - byte_alignment = min(align_x, align_y); - } - - template - auto operator()(F&& f) const - { +template +inline void dispatch(bool row_major, size_t byte_alignment, F&& f) { + if (row_major) { switch (byte_alignment) { - case 16: f(alignment_tag<16>()); break; - case 8: f(alignment_tag<8>()); break; - case 4: f(alignment_tag<4>()); break; - case 2: f(alignment_tag<2>()); break; - default: f(alignment_tag<1>()); break; + case 16: f(std::bool_constant(), align_constant<16>()); break; + case 8: f(std::bool_constant(), align_constant<8>()); break; + case 4: f(std::bool_constant(), align_constant<4>()); break; + case 2: f(std::bool_constant(), align_constant<2>()); break; + default: f(std::bool_constant(), align_constant<1>()); break; } - } -}; - -template -struct row_major_tag { - static constexpr int value = rm; -}; - -struct row_major_dispatch { - bool is_row_major_; - row_major_dispatch(bool row_major) : is_row_major_(row_major) {} - - template - auto operator()(F&& f) const - { - if (is_row_major_) { - f(row_major_tag()); - } else { - f(row_major_tag()); + } else { + switch (byte_alignment) { + case 16: f(std::bool_constant(), align_constant<16>()); break; + case 8: f(std::bool_constant(), align_constant<8>()); break; + case 4: f(std::bool_constant(), align_constant<4>()); break; + case 2: f(std::bool_constant(), align_constant<2>()); break; + default: f(std::bool_constant(), align_constant<1>()); break; } } -}; - -template -auto join_dispatch(F1&& f1, F2&& f2) -{ - const auto lam = [f1, f2](auto f) { - f1([f, f2](auto... args1) { f2([f, args1...](auto... args2) { f(args1..., args2...); }); }); - }; - return lam; -} - -template -auto join_dispatch(F1 f1, F2 f2, Fs... fs) -{ - return join_dispatch(join_dispatch(f1, f2), std::forward(fs)...); } template 1 when - // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len = distance_op.expensive_inner_loop ? 1 : vec_len_ideal; - - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; + size_t align_x = alignment_of_2d_array(x, ldx); + size_t align_y = alignment_of_2d_array(y, ldy); + size_t byte_alignment = min(align_x, align_y); + + dispatch( + is_row_major, + byte_alignment, + [&](auto row_major, auto alignment) { + // row_major and alignment are std::integral_constants of type bool and + // size_t respectively. + + // Since alignment is in bytes, it could be smaller than sizeof(DataT). + // Handle this (unlikely) case here. + if constexpr (alignment() < sizeof(DataT)) { + RAFT_EXPECTS(sizeof(DataT) <= alignment(), "Input matrix must be aligned to size of elements."); + return; + } + + // Compute number of elements that can be loaded in one instruction + // without causing misalignent errors. + constexpr int vec_len_aligned = + (alignment() % sizeof(DataT) == 0) ? alignment() / sizeof(DataT) : 1; + + // To keep compile times in check, we only specialize on veclen > 1 when + // the inner loop is relatively cheap (< 5 flops). + constexpr int vec_len = distance_op.expensive_inner_loop ? 1 : vec_len_aligned; + + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename std::conditional::type Policy; // Create compile-time template parameter - using KP_T = kernel_params_T; + using KP_T = kernel_params_T; return pairwise_matrix( distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); @@ -201,17 +173,31 @@ void distance_matrix_cutlass_dispatch(opT cutlass_op, ldx = m, ldy = n, ld_out = n; } - alignment_dispatch d_align(x, y, ldx, ldy); - row_major_dispatch d_row_major(is_row_major); + size_t align_x = alignment_of_2d_array(x, ldx); + size_t align_y = alignment_of_2d_array(y, ldy); + size_t byte_alignment = min(align_x, align_y); + + + dispatch( + is_row_major, + byte_alignment, + [&](auto row_major, auto alignment) { + // row_major and alignment are std::integral_constants of type bool and + // size_t respectively. - auto dispatch = join_dispatch(d_align, d_row_major); + // Since alignment is in bytes, it could be smaller than sizeof(DataT). + // Handle this (unlikely) case here. + if constexpr (alignment() < sizeof(DataT)) { + RAFT_EXPECTS(sizeof(DataT) <= alignment(), "Input matrix must be aligned to size of elements."); + return; + } - dispatch([&](auto alignment_tag, auto row_major_tag) { - constexpr int vec_len = - (alignment_tag.value % sizeof(DataT) == 0) ? alignment_tag.value / sizeof(DataT) : 1; + // Compute number of elements that can be loaded in one instruction + // without causing misalignent errors. + constexpr int vec_len = (alignment() % sizeof(DataT) == 0) ? alignment() / sizeof(DataT) : 1; - cutlassDistanceKernel( - x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); + cutlassDistanceKernel( + x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); }); } From 7c3bd763bcb9761bb6022641519d404b96b992f7 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 10 Feb 2023 16:42:43 +0100 Subject: [PATCH 34/93] Remove one ".template" from kernel_sm60 --- .../raft/distance/detail/pairwise_matrix/kernel_sm60.cuh | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 68026414c0..eed50c36f7 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -17,8 +17,6 @@ #include #include -#include // TODO: remove - #include namespace raft::distance::detail { @@ -75,9 +73,7 @@ __global__ __launch_bounds__(KP_T::PolicyT::Nthreads, 2) // Wrap operator back into lambdas. This is temporary and should be removed. (TODO) auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - // use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template core(acc, x, y); + distance_op.core(acc, x, y); }; auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT * regxn, @@ -90,6 +86,7 @@ __global__ __launch_bounds__(KP_T::PolicyT::Nthreads, 2) // No support for row_epilog_op. auto row_epilog_op = raft::void_op(); + // Always write output constexpr bool write_out = true; constexpr bool use_norms = distance_op.use_norms; From d62eeb79e75824cd2cee6b2c155fbf7c823dae7d Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 10 Feb 2023 16:43:12 +0100 Subject: [PATCH 35/93] Dispatch on veclen instead of byte_alignment To reduce compile times. --- .../detail/pairwise_matrix/dispatch.cuh | 96 +++++++++---------- 1 file changed, 45 insertions(+), 51 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index cf95b10960..75e557a420 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -48,25 +48,21 @@ size_t alignment_of_2d_array(const DataT* x, size_t stride) } template -using align_constant = std::integral_constant; +using vec_len_constant = std::integral_constant; template -inline void dispatch(bool row_major, size_t byte_alignment, F&& f) { +inline void dispatch(bool row_major, int vec_len, F&& f) { if (row_major) { - switch (byte_alignment) { - case 16: f(std::bool_constant(), align_constant<16>()); break; - case 8: f(std::bool_constant(), align_constant<8>()); break; - case 4: f(std::bool_constant(), align_constant<4>()); break; - case 2: f(std::bool_constant(), align_constant<2>()); break; - default: f(std::bool_constant(), align_constant<1>()); break; + switch (vec_len) { + case 4: f(std::bool_constant(), vec_len_constant<4>()); break; + case 2: f(std::bool_constant(), vec_len_constant<2>()); break; + default: f(std::bool_constant(), vec_len_constant<1>()); break; } } else { - switch (byte_alignment) { - case 16: f(std::bool_constant(), align_constant<16>()); break; - case 8: f(std::bool_constant(), align_constant<8>()); break; - case 4: f(std::bool_constant(), align_constant<4>()); break; - case 2: f(std::bool_constant(), align_constant<2>()); break; - default: f(std::bool_constant(), align_constant<1>()); break; + switch (vec_len) { + case 4: f(std::bool_constant(), vec_len_constant<4>()); break; + case 2: f(std::bool_constant(), vec_len_constant<2>()); break; + default: f(std::bool_constant(), vec_len_constant<1>()); break; } } } @@ -107,39 +103,38 @@ void distance_matrix_dispatch(opT distance_op, size_t align_y = alignment_of_2d_array(y, ldy); size_t byte_alignment = min(align_x, align_y); + // Since alignment is in bytes, it could be smaller than sizeof(DataT). + // Handle this (unlikely) case here. + RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, "Input matrix must be aligned to size of elements."); + + // Compute number of elements that can be loaded in one instruction + // without causing misalignent errors. + int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; + dispatch( is_row_major, - byte_alignment, - [&](auto row_major, auto alignment) { - // row_major and alignment are std::integral_constants of type bool and - // size_t respectively. - - // Since alignment is in bytes, it could be smaller than sizeof(DataT). - // Handle this (unlikely) case here. - if constexpr (alignment() < sizeof(DataT)) { - RAFT_EXPECTS(sizeof(DataT) <= alignment(), "Input matrix must be aligned to size of elements."); - return; - } - - // Compute number of elements that can be loaded in one instruction - // without causing misalignent errors. - constexpr int vec_len_aligned = - (alignment() % sizeof(DataT) == 0) ? alignment() / sizeof(DataT) : 1; + vec_len_aligned, + [&](auto row_major, auto vec_len_aligned) { + // row_major and vec_len are std::integral_constants of type bool and int + // respectively. // To keep compile times in check, we only specialize on veclen > 1 when // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len = distance_op.expensive_inner_loop ? 1 : vec_len_aligned; + constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); + + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); typedef typename raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; typedef typename std::conditional::type Policy; - // Create compile-time template parameter - using KP_T = kernel_params_T; + // Create compile-time template parameter + using KP_T = kernel_params_T; - return pairwise_matrix( - distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); - }); + return pairwise_matrix( + distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); + }); } template (16 / sizeof(DataT))); cutlassDistanceKernel( x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); From 5c3dcafea941d1e877c8b1c714022ba254363601 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 20 Feb 2023 15:02:53 +0100 Subject: [PATCH 36/93] Use many template parameters again --- .../detail/pairwise_matrix/dispatch.cuh | 96 +++++++------- .../detail/pairwise_matrix/kernel_sm60.cuh | 121 ++++++++---------- 2 files changed, 105 insertions(+), 112 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 75e557a420..b3362e7647 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -24,9 +24,9 @@ namespace raft::distance::detail { /** - * @brief: Computes minimal alignment of row starting elements in 2D array + * @brief: Computes minimal common alignment of the rows in a 2D array in bytes * - * The 2D matrix x is assumed to be row-major. This function computes the + * The 2D matrix `x` is assumed to be row-major. This function computes the * minimal alignment in bytes of the first elements of each row. * Output can be 16, 8, 4, 2, 1. * @@ -50,8 +50,25 @@ size_t alignment_of_2d_array(const DataT* x, size_t stride) template using vec_len_constant = std::integral_constant; +/** + * @brief: Converts run-time arguments to compile-time arguments + * + * Converts run-time arguments row_major and vec_len to compile-time arguments + * and dispatches a lambda f with these compile-time arguments. + * + * This is equivalent to copying and pasting the lambda function `f` in each of + * the switch case statements. + * + * @tparam F Type of lambda f. + * @param row_major Boolean indicating whether input arrays have row-major layout. + * @param vec_len Integer value 1, 2, or 4 specifying the Veclen template parameter of + * the KernelPolicy. + * @param f Lambda that takes two std::integral_constant parameters representing + * row_major and vec_len. + */ template -inline void dispatch(bool row_major, int vec_len, F&& f) { +void dispatch(bool row_major, int vec_len, F&& f) +{ if (row_major) { switch (vec_len) { case 4: f(std::bool_constant(), vec_len_constant<4>()); break; @@ -67,13 +84,13 @@ inline void dispatch(bool row_major, int vec_len, F&& f) { } } -template -void distance_matrix_dispatch(opT distance_op, +void distance_matrix_dispatch(OpT distance_op, IdxT m, IdxT n, IdxT k, @@ -86,8 +103,8 @@ void distance_matrix_dispatch(opT distance_op, cudaStream_t stream, bool is_row_major) { - // Determine leading dimensions and possibly flip order of passing x and y if - // column_major. + // Determine leading dimensions and, if column-major, flip order of passing x + // and y. IdxT ldx, ldy, ld_out; if (is_row_major) { ldx = k, ldy = k, ld_out = n; @@ -99,42 +116,37 @@ void distance_matrix_dispatch(opT distance_op, ldx = m, ldy = n, ld_out = n; } - size_t align_x = alignment_of_2d_array(x, ldx); - size_t align_y = alignment_of_2d_array(y, ldy); + size_t align_x = alignment_of_2d_array(x, ldx); + size_t align_y = alignment_of_2d_array(y, ldy); size_t byte_alignment = min(align_x, align_y); // Since alignment is in bytes, it could be smaller than sizeof(DataT). // Handle this (unlikely) case here. - RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, "Input matrix must be aligned to size of elements."); + RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, + "Input matrix must be aligned to size of elements."); // Compute number of elements that can be loaded in one instruction // without causing misalignent errors. int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; - dispatch( - is_row_major, - vec_len_aligned, - [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. + dispatch(is_row_major, vec_len_aligned, [&](auto row_major, auto vec_len_aligned) { + // row_major and vec_len are std::integral_constants of type bool and int + // respectively. - // To keep compile times in check, we only specialize on veclen > 1 when - // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); + // To keep compile times in check, we only specialize on veclen > 1 when + // the inner loop is relatively cheap (< 5 flops). + constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename std::conditional::type Policy; - // Create compile-time template parameter - using KP_T = kernel_params_T; - - return pairwise_matrix( - distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); - }); + return pairwise_matrix( + distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); + }); } template (16 / sizeof(DataT))); + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_aligned, static_cast(16 / sizeof(DataT))); - cutlassDistanceKernel( - x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); + cutlassDistanceKernel( + x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); }); } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index eed50c36f7..1e450f9289 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -21,54 +21,28 @@ namespace raft::distance::detail { -template -struct kernel_params_T { - using DataT = data_type; - using AccT = accumulate_type; - using OutT = out_type; - using IdxT = index_type; - using PolicyT = policy; - using opT = op_type; - using FinOpT = final_op_type; - static constexpr bool is_row_major = row_major; -}; - -template -__global__ __launch_bounds__(KP_T::PolicyT::Nthreads, 2) - - void pairwise_matrix_kernel(const typename KP_T::DataT* x, - const typename KP_T::DataT* y, - const typename KP_T::DataT* _xn, - const typename KP_T::DataT* _yn, - typename KP_T::IdxT m, - typename KP_T::IdxT n, - typename KP_T::IdxT k, - typename KP_T::IdxT lda, - typename KP_T::IdxT ldb, - typename KP_T::IdxT ldd, - typename KP_T::OutT* dOutput, - typename KP_T::opT distance_op, - typename KP_T::FinOpT fin_op) +template +__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + opT distance_op, + FinOpT fin_op) { - using AccT = typename KP_T::AccT; - using DataT = typename KP_T::DataT; - using OutT = typename KP_T::OutT; - using IdxT = typename KP_T::IdxT; - - using Policy = typename KP_T::PolicyT; - - // Instantiate compile time parameters to access constexpr members. - KP_T compile_time_params{}; - extern __shared__ char smem[]; // Wrap operator back into lambdas. This is temporary and should be removed. (TODO) @@ -80,6 +54,8 @@ __global__ __launch_bounds__(KP_T::PolicyT::Nthreads, 2) DataT * regyn, IdxT gridStrideX, IdxT gridStrideY) { + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) distance_op.template epilog( acc, regxn, regyn, gridStrideX, gridStrideY); }; @@ -100,7 +76,7 @@ __global__ __launch_bounds__(KP_T::PolicyT::Nthreads, 2) decltype(epilog_op), decltype(fin_op), decltype(row_epilog_op), - compile_time_params.is_row_major, + row_major, write_out> obj(x, y, @@ -121,32 +97,39 @@ __global__ __launch_bounds__(KP_T::PolicyT::Nthreads, 2) obj.run(); } -template -static void pairwise_matrix(typename KP_T::opT distance_op, - typename KP_T::FinOpT fin_op, - const typename KP_T::DataT* x, - const typename KP_T::DataT* y, - const typename KP_T::DataT* _xn, - const typename KP_T::DataT* _yn, - typename KP_T::IdxT m, - typename KP_T::IdxT n, - typename KP_T::IdxT k, - typename KP_T::IdxT lda, - typename KP_T::IdxT ldb, - typename KP_T::IdxT ldd, - typename KP_T::OutT* dOutput, - cudaStream_t stream) +template +void pairwise_matrix(OpT distance_op, + FinOpT fin_op, + const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + cudaStream_t stream) { - using Policy = typename KP_T::PolicyT; - using DataT = typename KP_T::DataT; - dim3 blk(Policy::Nthreads); + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) size_t smem_size = distance_op.template shared_mem_size(); - dim3 grid = launchConfigGenerator(m, n, smem_size, pairwise_matrix_kernel); + // Obtain function pointer to kernel + auto kernel = pairwise_matrix_kernel; + dim3 grid = launchConfigGenerator(m, n, smem_size, kernel); - pairwise_matrix_kernel<<>>( + kernel<<>>( x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); - RAFT_CUDA_TRY(cudaGetLastError()); } From 2613e8a72d69278d7e2b50e6e5404e9c457dd685 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 20 Feb 2023 16:03:23 +0100 Subject: [PATCH 37/93] Remove duplicate DistanceType enum definition --- cpp/include/raft/distance/detail/distance.cuh | 48 +------------------ 1 file changed, 1 insertion(+), 47 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index b459c73bee..5887155401 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -37,53 +38,6 @@ namespace raft { namespace distance { namespace detail { -/** enum to tell how to compute distance */ -enum DistanceType : unsigned short { - - /** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */ - L2Expanded = 0, - /** same as above, but inside the epilogue, perform square root operation */ - L2SqrtExpanded = 1, - /** cosine distance */ - CosineExpanded = 2, - /** L1 distance */ - L1 = 3, - /** evaluate as dist_ij += (x_ik - y-jk)^2 */ - L2Unexpanded = 4, - /** same as above, but inside the epilogue, perform square root operation */ - L2SqrtUnexpanded = 5, - /** basic inner product **/ - InnerProduct = 6, - /** Chebyshev (Linf) distance **/ - Linf = 7, - /** Canberra distance **/ - Canberra = 8, - /** Generalized Minkowski distance **/ - LpUnexpanded = 9, - /** Correlation distance **/ - CorrelationExpanded = 10, - /** Jaccard distance **/ - JaccardExpanded = 11, - /** Hellinger distance **/ - HellingerExpanded = 12, - /** Haversine distance **/ - Haversine = 13, - /** Bray-Curtis distance **/ - BrayCurtis = 14, - /** Jensen-Shannon distance**/ - JensenShannon = 15, - /** Hamming distance **/ - HammingUnexpanded = 16, - /** KLDivergence **/ - KLDivergence = 17, - /** RusselRao **/ - RusselRaoExpanded = 18, - /** Dice-Sorensen distance **/ - DiceExpanded = 19, - /** Precomputed (special value) **/ - Precomputed = 100 -}; - namespace { template Date: Mon, 20 Feb 2023 17:38:24 +0100 Subject: [PATCH 38/93] Remove pairwiseDistanceMatKernel Has been replaced by pairwise_matrix_kernel --- .../detail/pairwise_distance_base.cuh | 164 ------------------ 1 file changed, 164 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 140664f394..5acdf91c67 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -268,170 +268,6 @@ struct PairwiseDistances : public BaseClass { } }; // struct PairwiseDistances -/** - * @brief the distance matrix calculation kernel for L1, L2 and cosine - * @tparam useNorms whether norms are needed - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda lambda which implements accumulation operation - * @tparam EpilogueLambda lambda which implements operation for calculating - final value. - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major(default), - false for column major - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] xn row norms of input matrix A. - * @param[in] yn row norms of input matrix B. - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix - * @param core_op the core lambda - * @param epilog_op the epilogue lambda - * @param fin_op the final gemm epilogue lambda - */ - -template -__global__ __launch_bounds__(Policy::Nthreads, 2) - - void pairwiseDistanceMatKernel(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - CoreLambda core_op, - EpilogueLambda epilog_op, - FinalLambda fin_op) -{ - extern __shared__ char smem[]; - auto rowEpilog = raft::void_op(); - - PairwiseDistances - obj( - x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, epilog_op, fin_op, rowEpilog); - obj.run(); -} - -/** - * @brief the distance matrix calculation kernel for L2 and cosine - * for GPU arch < SM 8.0, this version is to make sure we don't recompile - * these kernels for ampere or higher as we use cutlass kernel for it. - * @tparam useNorms whether norms are needed - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda lambda which implements accumulation operation - * @tparam EpilogueLambda lambda which implements operation for calculating - final value. - * @tparam FinalLambda final lambda called on final distance value - * @tparam isRowMajor true if input/output is row major(default), - false for column major - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] xn row norms of input matrix A. - * @param[in] yn row norms of input matrix B. - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix - * @param core_op the core lambda - * @param epilog_op the epilogue lambda - * @param fin_op the final gemm epilogue lambda - */ - -template -__global__ __launch_bounds__(Policy::Nthreads, 2) - - void pairwiseDistanceMatKernelPriorToAmpere(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - CoreLambda core_op, - EpilogueLambda epilog_op, - FinalLambda fin_op) -{ - //#if __CUDA_ARCH__ < 800 - // TODO: re-enable the CUDA_ARCH guard for below Ampere once cutlass based - // kernels are enabled for CUDA 12.0 - extern __shared__ char smem[]; - auto rowEpilog = raft::void_op(); - - PairwiseDistances - obj( - x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, epilog_op, fin_op, rowEpilog); - obj.run(); - //#endif -} - template dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { From c334ba33df43c9eb8e29cde0821aa63d3d531a8e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 20 Feb 2023 19:26:25 +0100 Subject: [PATCH 39/93] Remove distance::detail::pairwise_distance_impl --- cpp/include/raft/distance/detail/distance.cuh | 38 -------- cpp/include/raft/distance/distance.cuh | 88 +++++++++---------- 2 files changed, 41 insertions(+), 85 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 5887155401..8d4155356b 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include #include @@ -648,43 +647,6 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In return worksize; } -/** - * @defgroup pairwise_distance pairwise distance prims - * @{ - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace buffer which can get resized as per the - * needed workspace size - * @param metric distance metric - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - */ -template -void pairwise_distance_impl(const Type* x, - const Type* y, - Type* dist, - Index_ m, - Index_ n, - Index_ k, - rmm::device_uvector& workspace, - cudaStream_t stream, - bool isRowMajor, - Type metric_arg = 2.0f) -{ - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - distance( - x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg); -} -/** @} */ }; // namespace detail }; // namespace distance }; // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 93a5ce7f1a..90eeb90d38 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -250,67 +251,60 @@ void pairwise_distance(raft::device_resources const& handle, bool isRowMajor = true, Type metric_arg = 2.0f) { + auto stream = handle.get_stream(); + + auto dispatch = [&](auto distance_type) { + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg); + }; + switch (metric) { - case raft::distance::DistanceType::L2Expanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::Canberra: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::L2SqrtExpanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::CorrelationExpanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::CosineExpanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::CosineExpanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::L1: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::HammingUnexpanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::L2Unexpanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::HellingerExpanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::L2SqrtUnexpanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::JensenShannon: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::Linf: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::KLDivergence: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::HellingerExpanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::L1: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::LpUnexpanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); + case DistanceType::L2Expanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::Canberra: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::L2SqrtExpanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::HammingUnexpanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::L2SqrtUnexpanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::JensenShannon: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::L2Unexpanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::RusselRaoExpanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::Linf: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::KLDivergence: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::LpUnexpanded: + dispatch(std::integral_constant{}); break; - case raft::distance::DistanceType::CorrelationExpanded: - detail:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + case DistanceType::RusselRaoExpanded: + dispatch(std::integral_constant{}); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; @@ -481,4 +475,4 @@ void pairwise_distance(raft::device_resources const& handle, }; // namespace distance }; // namespace raft -#endif \ No newline at end of file +#endif From 8e432383eae1b2017b0c531270956a6a69638ce9 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 21 Feb 2023 09:57:03 +0100 Subject: [PATCH 40/93] distance_ops: Include cuda_utils.cuh --- cpp/include/raft/distance/detail/distance_ops/correlation.cuh | 2 ++ cpp/include/raft/distance/detail/distance_ops/cosine.cuh | 2 ++ cpp/include/raft/distance/detail/distance_ops/hamming.cuh | 2 ++ cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh | 2 ++ cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh | 2 ++ cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh | 2 ++ cpp/include/raft/distance/detail/distance_ops/template.cuh | 2 ++ 7 files changed, 14 insertions(+) diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index eb18355ca9..f17d67953e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::distance::detail::ops { // Describes the computation the correlation distance diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index bbc1ffcba2..aa2eac01bc 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::distance::detail::ops { // Describes the computation the cosine distance diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index c8b3b7658e..b4f610be0a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::distance::detail::ops { // Describes the computation the hamming distance diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 13a41190c1..523019f417 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::distance::detail::ops { // Describes the computation the l2 expanded distance diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index 31fbd11667..f5e2f278b7 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::distance::detail::ops { // Describes the computation the l2 unexpanded distance diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index f46a1a5e67..e114ef8224 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::distance::detail::ops { // Describes the computation the russel_rao distance diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index d7bbfc7fca..378bcf0c9f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::distance::detail::ops { // Describes the computation the template distance From e176351d9d885d8f0a918851220317f74a587e73 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 21 Feb 2023 10:33:14 +0100 Subject: [PATCH 41/93] Replace DistanceImpl with method overloads --- cpp/include/raft/distance/detail/distance.cuh | 1205 ++++++++++------- .../detail/pairwise_matrix/dispatch.cuh | 26 +- 2 files changed, 705 insertions(+), 526 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 8d4155356b..58c4dbd275 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -16,19 +16,26 @@ #pragma once +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + #include #include #include @@ -37,526 +44,710 @@ namespace raft { namespace distance { namespace detail { -namespace { -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType metric_arg = 2.0f) - { - } -}; +/** + * @brief: A tag type for overload resolution based on DistanceType + * + * It is not possible to partially specialize function templates on a single + * parameter. Intead, it is often easier to use a combination of conventional + * method overloading and a parameter with a specific tag type. The following + * type is used to help method overloading based on the DistanceType enum. + */ +template +using distance_tag = std::integral_constant; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::euclideanAlgo1( - m, n, k, x, y, dist, false, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); - } -}; +/** + * @brief Implement pairwise_matrix for specific distance + * + * There are multiple overloads for this function, one for each distance type. + * They are implemented below. The documentation of this function serves as + * documentation for all functions. The following overloads are defined: + * + * - DistanceType::Canberra: + * - DistanceType::CorrelationExpanded: + * - DistanceType::CosineExpanded: + * - DistanceType::HammingUnexpanded: + * - DistanceType::HellingerExpanded: + * - DistanceType::JensenShannon: + * - DistanceType::KLDivergence: + * - DistanceType::L1: + * - DistanceType::L2Expanded: + * - DistanceType::L2SqrtExpanded: + * - DistanceType::L2Unexpanded: + * - DistanceType::L2SqrtUnexpanded: + * - DistanceType::Linf: + * - DistanceType::LpUnexpanded: + * - DistanceType::RusselRaoExpanded: + * + * @tparam DataT Input data type + * @tparam AccT Accumulation data type + * @tparam OutT Output data type + * @tparam FinOpT Type of final operation + * @tparam IdxT Index type + * + * @param distance_type A tag type to indicate which distance is calculated. + * @param x First set of points + * @param y Second set of points + * @param out Output distance matrix + * @param m Number of points in x + * @param n Number of points in y + * @param k Dimensionality of points in x, y + * @param workspace Temporary workspace needed for computations + * @param worksize Number of bytes of the workspace + * @param stream CUDA stream + * @param is_row_major Whether the matrices are row-major or col-major + * @param metric_arg The `p` argument for Lp. + */ +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, // unused + size_t worksize, // unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT metric_arg) // unused +{ + ops::canberra_distance_op distance_op{}; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::euclideanAlgo1( - m, n, k, x, y, dist, true, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); - } -}; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::cosineAlgo1( - m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); - } -}; + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::euclideanAlgo2( - m, n, k, x, y, dist, false, fin_op, stream, isRowMajor); - } -}; +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // unused +{ + ASSERT(!(((x != y) && (worksize < 2 * (m + n) * sizeof(AccT))) || + (worksize < 2 * m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::euclideanAlgo2( - m, n, k, x, y, dist, true, fin_op, stream, isRowMajor); - } -}; + AccT* norm_col_vec = workspace; + AccT* norm_row_vec = workspace; + AccT* sq_norm_col_vec = workspace; + AccT* sq_norm_row_vec = workspace; + if (x != y) { + norm_row_vec += m; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::l1Impl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); - } -}; + raft::linalg::reduce(norm_col_vec, + x, + k, + m, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + raft::linalg::reduce(norm_row_vec, + y, + k, + n, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::chebyshevImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + sq_norm_col_vec += (m + n); + sq_norm_row_vec = sq_norm_col_vec + m; + raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + raft::linalg::rowNorm(sq_norm_row_vec, y, k, n, raft::linalg::L2Norm, is_row_major, stream); + } else { + raft::linalg::reduce(norm_col_vec, + x, + k, + m, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + sq_norm_col_vec += m; + sq_norm_row_vec = sq_norm_col_vec; + raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); } -}; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::hellingerImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + using CorrOp = ops::correlation_distance_op; + CorrOp corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); + distance_matrix_dispatch( + corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // unused +{ + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), + "OutT can be uint8_t, float, double," + "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); + + ASSERT( + !(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + + DataT* norm_A = workspace; + DataT* norm_B = workspace; + if (x != y) { + norm_B += m; + raft::linalg::rowNorm( + norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + } else { + raft::linalg::rowNorm( + norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } -}; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType metric_arg) - { - raft::distance::detail::minkowskiImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor, metric_arg); + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel otherwise. + + if constexpr (__CUDACC_VER_MAJOR__ == 12) { + // Always execute legacy kernels on CUDA 12 + ops::cosine_distance_op distance_op{}; + distance_matrix_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + } else { + const auto deviceVersion = getComputeCapability(); + if (deviceVersion.first >= 8) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using Op = ops::cosine_cutlass_op; + Op distance_op{}; + + distance_matrix_cutlass_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + } else { + // Else use "legacy" L2 + ops::cosine_distance_op distance_op{}; + distance_matrix_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + } } -}; +} -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::canberraImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::hamming_distance_op distance_op{k}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + // First sqrt x and y + const auto raft_sqrt = raft::linalg::unaryOp; + + raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); + if (x != y) { + raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } -}; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::hammingUnexpandedImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + // Then calculate Hellinger distance + ops::hellinger_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + + // Finally revert sqrt of x and y + raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); + if (x != y) { + raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } -}; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::jensenShannonImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::jensen_shannon_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + auto unaryOp_lambda = [] __device__(DataT input) { + const bool x_zero = (input == 0); + return (!x_zero) * raft::log(input + x_zero); }; + + auto unaryOp_lambda_reverse = [] __device__(DataT input) { + // reverse previous log (x) back to x using (e ^ log(x)) + const bool x_zero = (input == 0); + return (!x_zero) * raft::exp(input); }; + + // This op takes some shortcuts when x equals y. So its behavior changes based + // on this. + ops::kl_divergence_op kl_divergence{is_row_major, x == y}; + + if (x != y) { + raft::linalg::unaryOp( + (DataT*)y, y, n * k, unaryOp_lambda, stream); } -}; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::russellRaoImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + + if (x != y) { + // Now reverse previous log (x) back to x using (e ^ log(x)) + raft::linalg::unaryOp( + (DataT*)y, y, n * k, unaryOp_lambda_reverse, stream); } -}; +} -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void*, - size_t, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::klDivergenceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::l1_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl_l2_expanded( // NOTE: different name + bool perform_sqrt, // dispatch on sqrt + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), + "OutT can be uint8_t, float, double," + "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); + + ASSERT( + !(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + DataT* norm_A = workspace; + DataT* norm_B = workspace; + if (x != y) { + norm_B += m; + raft::linalg::rowNorm( + norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + raft::linalg::rowNorm( + norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + } else { + raft::linalg::rowNorm( + norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } -}; -template -struct DistanceImpl { - void run(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor, - InType) - { - raft::distance::detail::correlationImpl( - m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel otherwise. + + if constexpr (__CUDACC_VER_MAJOR__ == 12) { + // Always execute legacy kernels on CUDA 12 + ops::l2_exp_distance_op l2_op(perform_sqrt); + distance_matrix_dispatch( + l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + } else { + const auto deviceVersion = getComputeCapability(); + if (deviceVersion.first >= 8) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using L2Op = ops::l2_exp_cutlass_op; + L2Op l2_op(perform_sqrt); + + distance_matrix_cutlass_dispatch( + l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + } else { + // Else use "legacy" L2 + ops::l2_exp_distance_op l2_op(perform_sqrt); + distance_matrix_dispatch( + l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + } } -}; +} -} // anonymous namespace +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + bool perform_sqrt = false; + distance_impl_l2_expanded(perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); +} -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccType and returns the output in OutType. It's signature is - * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs - * any other parameters, feel free to pass them via closure. - */ -template -void distance(const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor = true, - InType metric_arg = 2.0f) +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { - DistanceImpl distImpl; - distImpl.run(x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + bool perform_sqrt = true; + distance_impl_l2_expanded(perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + bool perform_sqrt = false; + ops::l2_unexp_distance_op l2_op(perform_sqrt); + + // The unexpanded L2 does not require the norms of a and b to be calculated. + const DataT* norm_A = nullptr; + const DataT* norm_B = nullptr; + + distance_matrix_dispatch( + l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + bool perform_sqrt = true; + ops::l2_unexp_distance_op l2_op(perform_sqrt); + + // The unexpanded L2 does not require the norms of a and b to be calculated. + const DataT* norm_A = nullptr; + const DataT* norm_B = nullptr; + + distance_matrix_dispatch( + l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::chebyshev_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT metric_arg) +{ + ops::minkowski_distance_op distance_op{metric_arg}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl( + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::russel_rao_distance_op distance_op{k}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + distance_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } /** - * @brief Evaluate pairwise distances for the simple use case + * @brief Evaluate pairwise distances and write to matrix + * * @tparam DistanceType which distance to evaluate * @tparam InType input argument type * @tparam AccType accumulation type * @tparam OutType output type * @tparam Index_ Index type + * * @param x first set of points * @param y second set of points - * @param dist output distance matrix + * @param out output distance matrix * @param m number of points in x * @param n number of points in y * @param k dimensionality @@ -568,19 +759,6 @@ void distance(const InType* x, * @note if workspace is passed as nullptr, this will return in * worksize, the number of bytes of workspace required */ - -// Default final op functor which facilitates elementwise operation on -// final distance value if any. -template -struct default_fin_op { - __host__ __device__ default_fin_op() noexcept {}; - // functor signature. - __host__ __device__ OutType operator()(AccType d_val, Index g_d_idx) const noexcept - { - return d_val; - } -}; - template void distance(const InType* x, const InType* y, - OutType* dist, + OutType* out, Index_ m, Index_ n, Index_ k, @@ -598,15 +776,16 @@ void distance(const InType* x, bool isRowMajor = true, InType metric_arg = 2.0f) { - using final_op_type = default_fin_op; - final_op_type fin_op; + auto fin_op = raft::identity_op(); // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); - distance( - x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); + + distance_impl( + distance_tag{}, + x, y, out, m, n, k, reinterpret_cast(workspace), worksize, fin_op, stream, isRowMajor, metric_arg); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index b3362e7647..4a7c1f999f 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -69,19 +69,19 @@ using vec_len_constant = std::integral_constant; template void dispatch(bool row_major, int vec_len, F&& f) { - if (row_major) { - switch (vec_len) { - case 4: f(std::bool_constant(), vec_len_constant<4>()); break; - case 2: f(std::bool_constant(), vec_len_constant<2>()); break; - default: f(std::bool_constant(), vec_len_constant<1>()); break; - } - } else { - switch (vec_len) { - case 4: f(std::bool_constant(), vec_len_constant<4>()); break; - case 2: f(std::bool_constant(), vec_len_constant<2>()); break; - default: f(std::bool_constant(), vec_len_constant<1>()); break; - } - } + // if (row_major) { + // switch (vec_len) { + // case 4: f(std::bool_constant(), vec_len_constant<4>()); break; + // case 2: f(std::bool_constant(), vec_len_constant<2>()); break; + // default: f(std::bool_constant(), vec_len_constant<1>()); break; + // } + // } else { + // switch (vec_len) { + // case 4: f(std::bool_constant(), vec_len_constant<4>()); break; + // case 2: f(std::bool_constant(), vec_len_constant<2>()); break; + // default: f(std::bool_constant(), vec_len_constant<1>()); break; + // } + // } } template Date: Tue, 21 Feb 2023 11:17:24 +0100 Subject: [PATCH 42/93] Remove impl files and move doc strings --- cpp/include/raft/distance/detail/canberra.cuh | 71 --- .../raft/distance/detail/chebyshev.cuh | 68 --- .../raft/distance/detail/correlation.cuh | 127 ---- cpp/include/raft/distance/detail/cosine.cuh | 123 ---- cpp/include/raft/distance/detail/distance.cuh | 569 ++++++++---------- .../distance/detail/distance_ops/canberra.cuh | 9 +- .../detail/distance_ops/chebyshev.cuh | 9 +- .../detail/distance_ops/correlation.cuh | 32 +- .../distance/detail/distance_ops/cosine.cuh | 11 +- .../distance/detail/distance_ops/hamming.cuh | 10 +- .../detail/distance_ops/hellinger.cuh | 12 +- .../detail/distance_ops/jensen_shannon.cuh | 8 + .../detail/distance_ops/kl_divergence.cuh | 19 +- .../raft/distance/detail/distance_ops/l1.cuh | 8 +- .../distance/detail/distance_ops/l2_exp.cuh | 11 +- .../distance/detail/distance_ops/l2_unexp.cuh | 8 +- .../detail/distance_ops/minkowski.cuh | 11 +- .../detail/distance_ops/russel_rao.cuh | 14 +- .../raft/distance/detail/euclidean.cuh | 169 ------ cpp/include/raft/distance/detail/hamming.cuh | 71 --- .../raft/distance/detail/hellinger.cuh | 94 --- .../raft/distance/detail/jensen_shannon.cuh | 72 --- .../raft/distance/detail/kl_divergence.cuh | 98 --- cpp/include/raft/distance/detail/l1.cuh | 51 -- .../raft/distance/detail/minkowski.cuh | 70 --- .../raft/distance/detail/russell_rao.cuh | 70 --- 26 files changed, 361 insertions(+), 1454 deletions(-) delete mode 100644 cpp/include/raft/distance/detail/canberra.cuh delete mode 100644 cpp/include/raft/distance/detail/chebyshev.cuh delete mode 100644 cpp/include/raft/distance/detail/correlation.cuh delete mode 100644 cpp/include/raft/distance/detail/cosine.cuh delete mode 100644 cpp/include/raft/distance/detail/euclidean.cuh delete mode 100644 cpp/include/raft/distance/detail/hamming.cuh delete mode 100644 cpp/include/raft/distance/detail/hellinger.cuh delete mode 100644 cpp/include/raft/distance/detail/jensen_shannon.cuh delete mode 100644 cpp/include/raft/distance/detail/kl_divergence.cuh delete mode 100644 cpp/include/raft/distance/detail/l1.cuh delete mode 100644 cpp/include/raft/distance/detail/minkowski.cuh delete mode 100644 cpp/include/raft/distance/detail/russell_rao.cuh diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh deleted file mode 100644 index 3f0c2fa268..0000000000 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "distance_ops/canberra.cuh" -#include "pairwise_matrix/dispatch.cuh" - -namespace raft { -namespace distance { -namespace detail { - - -/** - * @brief the canberra distance matrix calculation - * It computes the following equation: cij = max(cij, op(ai-bj)) - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam FinOpT user-defined epilogue lamba - * @tparam IdxT Index type - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and cols of C/D - * @param[in] k number of cols of A and B - * @param[in] pA input matrix - * @param[in] pB input matrix - * @param[out] pD output matrix - * @param[in] fin_op the final element-wise epilogue lambda - * @param[in] stream cuda stream to launch work - * @param[in] isRowMajor whether the input and output matrices are row major - */ -template -void canberraImpl(int m, - int n, - int k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - ops::canberra_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh deleted file mode 100644 index 9f49660301..0000000000 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "distance_ops/chebyshev.cuh" -#include "pairwise_matrix/dispatch.cuh" - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the chebyshev distance matrix calculation - * It computes the following equation: cij = max(cij, op(ai-bj)) - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and cols of C/D - * @param[in] k number of cols of A and B - * @param[in] pA input matrix - * @param[in] pB input matrix - * @param[out] pD output matrix - * @param[in] fin_op the final element-wise epilogue lambda - * @param[in] stream cuda stream to launch work - * @param[in] isRowMajor whether the input and output matrices are row major - */ -template -void chebyshevImpl(int m, - int n, - int k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - ops::chebyshev_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh deleted file mode 100644 index 89828c9ba2..0000000000 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include "pairwise_matrix/dispatch.cuh" -#include "distance_ops/correlation.cuh" - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the Correlation distance matrix calculation - * - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void correlationImpl(int m, - int n, - int k, - const InType* pA, - const InType* pB, - OutType* pD, - AccType* workspace, - size_t& worksize, - FinalLambda fin_op, - cudaStream_t stream, - bool isRowMajor) -{ - ASSERT(!(((pA != pB) && (worksize < 2 * (m + n) * sizeof(AccType))) || - (worksize < 2 * m * sizeof(AccType))), - "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - AccType* norm_col_vec = workspace; - AccType* norm_row_vec = workspace; - AccType* sq_norm_col_vec = workspace; - AccType* sq_norm_row_vec = workspace; - if (pA != pB) { - norm_row_vec += m; - - raft::linalg::reduce(norm_col_vec, - pA, - k, - m, - (AccType)0, - isRowMajor, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - raft::linalg::reduce(norm_row_vec, - pB, - k, - n, - (AccType)0, - isRowMajor, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - - sq_norm_col_vec += (m + n); - sq_norm_row_vec = sq_norm_col_vec + m; - raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream); - raft::linalg::rowNorm(sq_norm_row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream); - } else { - raft::linalg::reduce(norm_col_vec, - pA, - k, - m, - (AccType)0, - isRowMajor, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - sq_norm_col_vec += m; - sq_norm_row_vec = sq_norm_col_vec; - raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream); - } - - using CorrOp = ops::correlation_distance_op; - CorrOp corr_op(isRowMajor, sq_norm_col_vec, sq_norm_row_vec, m, n, k); - distance_matrix_dispatch( - corr_op, m, n, k, pA, pB, norm_col_vec, norm_row_vec, pD, fin_op, stream, isRowMajor); -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh deleted file mode 100644 index 4ae0c285f5..0000000000 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include - -#include "pairwise_matrix/dispatch.cuh" -#include "distance_ops/cosine.cuh" - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the expanded cosine distance matrix calculation - * It computes the following equation: - * C = 1 - op(A * B / sqrt(A^2) * sqrt(B^2))) - * @tparam IType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OType output data-type (for C and D matrices) - * @tparam OutputTile_ output tile size for the thread block - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @tparam in_params user-defined input parameter - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void cosineAlgo1(IdxT m, - IdxT n, - IdxT k, - const DataT* pA, - const DataT* pB, - OutT* pD, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - cudaStream_t stream, - bool isRowMajor) -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), - "OutT can be uint8_t, float, double," - "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - - ASSERT( - !(((pA != pB) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - - DataT* norm_A = workspace; - DataT* norm_B = workspace; - if (pA != pB) { - norm_B += m; - raft::linalg::rowNorm( - norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); - raft::linalg::rowNorm( - norm_B, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); - } else { - raft::linalg::rowNorm( - norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::sqrt_op{}); - } - - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel otherwise. - - if constexpr (__CUDACC_VER_MAJOR__ == 12) { - // Always execute legacy kernels on CUDA 12 - ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); - } else { - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { - // If device is SM_80 or later, use CUTLASS-based kernel. - using Op = ops::cosine_cutlass_op; - Op distance_op{}; - - distance_matrix_cutlass_dispatch( - distance_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); - } else { - // Else use "legacy" L2 - ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); - } - } -} - -}; // end namespace detail -}; // end namespace distance -}; // end namespace raft diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 58c4dbd275..5275d26ab2 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,17 @@ #pragma once -#include #include +#include + +#include +#include +#include + #include #include +#include #include #include #include @@ -30,7 +36,6 @@ #include #include #include -#include #include #include @@ -48,7 +53,7 @@ namespace detail { * @brief: A tag type for overload resolution based on DistanceType * * It is not possible to partially specialize function templates on a single - * parameter. Intead, it is often easier to use a combination of conventional + * parameter. Instead, it is often easier to use a combination of conventional * method overloading and a parameter with a specific tag type. The following * type is used to help method overloading based on the DistanceType enum. */ @@ -97,25 +102,20 @@ using distance_tag = std::integral_constant; * @param is_row_major Whether the matrices are row-major or col-major * @param metric_arg The `p` argument for Lp. */ -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, // unused - size_t worksize, // unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT metric_arg) // unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, // unused + size_t worksize, // unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT metric_arg) // unused { ops::canberra_distance_op distance_op{}; @@ -126,29 +126,24 @@ void distance_impl( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // unused { - ASSERT(!(((x != y) && (worksize < 2 * (m + n) * sizeof(AccT))) || - (worksize < 2 * m * sizeof(AccT))), - "workspace size error"); + ASSERT( + !(((x != y) && (worksize < 2 * (m + n) * sizeof(AccT))) || (worksize < 2 * m * sizeof(AccT))), + "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); AccT* norm_col_vec = workspace; @@ -208,37 +203,30 @@ void distance_impl( corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // unused { // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), "OutT can be uint8_t, float, double," "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - ASSERT( - !(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); + ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), + "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); - DataT* norm_A = workspace; DataT* norm_B = workspace; if (x != y) { @@ -282,25 +270,20 @@ void distance_impl( } } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { ops::hamming_distance_op distance_op{k}; @@ -311,33 +294,26 @@ void distance_impl( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { // First sqrt x and y const auto raft_sqrt = raft::linalg::unaryOp; raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { - raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); - } + if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } // Then calculate Hellinger distance ops::hellinger_distance_op distance_op{}; @@ -350,32 +326,25 @@ void distance_impl( // Finally revert sqrt of x and y raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { - raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); - } + if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } RAFT_CUDA_TRY(cudaGetLastError()); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { ops::jensen_shannon_distance_op distance_op{}; @@ -386,34 +355,31 @@ void distance_impl( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { auto unaryOp_lambda = [] __device__(DataT input) { - const bool x_zero = (input == 0); - return (!x_zero) * raft::log(input + x_zero); }; + const bool x_zero = (input == 0); + return (!x_zero) * raft::log(input + x_zero); + }; auto unaryOp_lambda_reverse = [] __device__(DataT input) { - // reverse previous log (x) back to x using (e ^ log(x)) - const bool x_zero = (input == 0); - return (!x_zero) * raft::exp(input); }; + // reverse previous log (x) back to x using (e ^ log(x)) + const bool x_zero = (input == 0); + return (!x_zero) * raft::exp(input); + }; // This op takes some shortcuts when x equals y. So its behavior changes based // on this. @@ -437,26 +403,20 @@ void distance_impl( } } - -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { ops::l1_distance_op distance_op{}; @@ -472,8 +432,8 @@ template -void distance_impl_l2_expanded( // NOTE: different name - bool perform_sqrt, // dispatch on sqrt +void distance_impl_l2_expanded( // NOTE: different name + bool perform_sqrt, // dispatch on sqrt const DataT* x, const DataT* y, OutT* out, @@ -491,9 +451,8 @@ void distance_impl_l2_expanded( // NOTE: different name "OutT can be uint8_t, float, double," "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - ASSERT( - !(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); + ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), + "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); DataT* norm_A = workspace; @@ -539,73 +498,60 @@ void distance_impl_l2_expanded( // NOTE: different name } } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { bool perform_sqrt = false; - distance_impl_l2_expanded(perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); + distance_impl_l2_expanded( + perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { bool perform_sqrt = true; - distance_impl_l2_expanded(perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); + distance_impl_l2_expanded( + perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { bool perform_sqrt = false; ops::l2_unexp_distance_op l2_op(perform_sqrt); @@ -618,25 +564,20 @@ void distance_impl( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { bool perform_sqrt = true; ops::l2_unexp_distance_op l2_op(perform_sqrt); @@ -649,25 +590,20 @@ void distance_impl( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { ops::chebyshev_distance_op distance_op{}; @@ -678,25 +614,20 @@ void distance_impl( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT metric_arg) +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT metric_arg) { ops::minkowski_distance_op distance_op{metric_arg}; @@ -707,25 +638,20 @@ void distance_impl( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } -template -void distance_impl( - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT) // metric_arg unused +template +void distance_impl(distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + DataT) // metric_arg unused { ops::russel_rao_distance_op distance_op{k}; @@ -785,7 +711,18 @@ void distance(const InType* x, distance_impl( distance_tag{}, - x, y, out, m, n, k, reinterpret_cast(workspace), worksize, fin_op, stream, isRowMajor, metric_arg); + x, + y, + out, + m, + n, + k, + reinterpret_cast(workspace), + worksize, + fin_op, + stream, + isRowMajor, + metric_arg); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index e9c16d6d6d..6491b24e3d 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -20,8 +20,13 @@ namespace raft::distance::detail::ops { -// Describes the computation the canberra distance - +/** + * @brief The canberra distance matrix calculation + * + * It computes the following equation: + * + * c_ij = sum_k |x_ik - y_kj| / ( |x_ik| + |y_kj| ) + */ struct canberra_distance_op { // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh b/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh index a68d9fc21c..d390f75460 100644 --- a/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh @@ -20,8 +20,13 @@ namespace raft::distance::detail::ops { -// Describes the computation the chebyshev distance - +/** + * @brief the L_inf (Chebyshev) distance matrix calculation + * + * It computes the following equation: + * + * c_ij = max_k | x_ik - y_kj | + */ struct chebyshev_distance_op { // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index f17d67953e..11cc3ed4f4 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -20,9 +20,14 @@ namespace raft::distance::detail::ops { -// Describes the computation the correlation distance - - +/** @brief The correlation distance + * + * It computes the following equation: + * + * d(x, y) = ((x - mean(x)) â‹… (y - mean(y))) + * / + * (|| x - mean(x) ||_2 || y - mean(y) ||_2) + */ template struct correlation_distance_op { const DataT_struct* x2n; @@ -31,19 +36,13 @@ struct correlation_distance_op { IdxT_struct n; IdxT_struct k; - correlation_distance_op( - bool is_row_major, - const DataT_struct* x2n_, - const DataT_struct* y2n_, - IdxT_struct m_, - IdxT_struct n_, - IdxT_struct k_ - ) noexcept - : x2n(x2n_), - y2n(y2n_), - m(m_), - n(n_), - k(k_) + correlation_distance_op(bool is_row_major, + const DataT_struct* x2n_, + const DataT_struct* y2n_, + IdxT_struct m_, + IdxT_struct n_, + IdxT_struct k_) noexcept + : x2n(x2n_), y2n(y2n_), m(m_), n(n_), k(k_) { // The distance op is typically created before the row-major/col-major // swapping has been done. So we do it here. @@ -53,7 +52,6 @@ struct correlation_distance_op { } } - // Load norms of input data static constexpr bool use_norms = true; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index aa2eac01bc..d26b5aeda0 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -20,8 +20,13 @@ namespace raft::distance::detail::ops { -// Describes the computation the cosine distance - +/** + * @brief the expanded cosine distance matrix calculation + * + * It computes the following equation: + * + * d(x, y) = 1 - (x â‹… y) / ( ||x||_2 ||y||_2) + */ struct cosine_distance_op { // Load norms of input data static constexpr bool use_norms = true; @@ -60,7 +65,6 @@ struct cosine_distance_op { } }; - template struct cosine_cutlass_op { __device__ cosine_cutlass_op() noexcept {} @@ -71,5 +75,4 @@ struct cosine_cutlass_op { __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; - } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index b4f610be0a..02087e2874 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -20,13 +20,17 @@ namespace raft::distance::detail::ops { -// Describes the computation the hamming distance - +/** + * @brief the Hamming Unexpanded distance matrix calculation + * It computes the following equation: + * + * c_ij = sum_k (x_ik != y_kj) / k + */ template struct hamming_distance_op { IdxT_struct k; - hamming_distance_op(IdxT_struct k_) noexcept : k(k_) { } + hamming_distance_op(IdxT_struct k_) noexcept : k(k_) {} // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index b0fae700b5..0314565a03 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -19,10 +19,14 @@ namespace raft::distance::detail::ops { -// Describes the computation the hellinger distance -// -// Fill in the TODO items. - +/** + * @brief the Hellinger distance matrix calculation + * + * It computes the following equation: + * + * c_ij = sqrt(1 - sum_k sqrt(x_ik * y_kj)) + * + */ struct hellinger_distance_op { // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index 124010e96d..5e00faef74 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -21,6 +21,14 @@ namespace raft::distance::detail::ops { // Describes the computation the jensen_shannon distance +/** + * @brief the Jensen Shannon distance matrix calculation + * + * It computes the following equation: + * + * c_ij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) + * + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) + */ struct jensen_shannon_distance_op { // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index a97582aa5a..fe6e0dbbe1 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -19,18 +19,21 @@ namespace raft::distance::detail::ops { -// Describes the computation of the kl_divergence +/** + * @brief the KL Divergence distance matrix calculation + * + * It computes the following equation: + * + * c_ij = 0.5 * sum(x * log (x / y)); + */ struct kl_divergence_op { const bool is_row_major; const bool x_equal_y; - kl_divergence_op( - bool row_major_, - bool x_equal_y_=false - ) noexcept - : is_row_major(row_major_), - x_equal_y(x_equal_y_) - { } + kl_divergence_op(bool row_major_, bool x_equal_y_ = false) noexcept + : is_row_major(row_major_), x_equal_y(x_equal_y_) + { + } // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 4bb4a8796c..bb71a7801f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -19,7 +19,13 @@ namespace raft::distance::detail::ops { -// Describes the computation the l1 distance +/** + * @brief the L1 distance matrix calculation + * + * It computes the following equation: + * + * c_ij = sum_k abs(x_ik - y_kj) + */ struct l1_distance_op { // Do not load norms of data, the computation of L1 distance does not use them. static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 523019f417..d491493a63 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -20,9 +20,14 @@ namespace raft::distance::detail::ops { -// Describes the computation the l2 expanded distance -// -// TODO: more explanation. +/** + * @brief the expanded euclidean distance matrix calculation + * + * It computes the following equation: + * + * c_ij = - 2 sum_k x_ik * y_kj + ||x_i.||_2 + ||y_.j||_2 + * + */ struct l2_exp_distance_op { bool sqrt; diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index f5e2f278b7..6e75cc95e8 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -20,7 +20,13 @@ namespace raft::distance::detail::ops { -// Describes the computation the l2 unexpanded distance +/** + * @brief the unexpanded euclidean distance matrix calculation + * + * It computes the following equation: + * + * c_ij = optional_sqrt ( sum_k (x_ik - y_kj)^2 ) + */ struct l2_unexp_distance_op { bool sqrt; diff --git a/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh b/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh index 8deb42d1fe..0640cc72a7 100644 --- a/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh @@ -19,13 +19,18 @@ namespace raft::distance::detail::ops { -// Describes the computation the minkowski distance - +/** + * @brief the unexpanded Lp (Minkowski) distance matrix calculation + * + * It computes the following equation: + * + * c_ij = (sum_k |x_ik - y_jk|^p)^(1/p) + */ template struct minkowski_distance_op { DataT_struct p; - minkowski_distance_op(DataT_struct p_) noexcept : p(p_) { } + minkowski_distance_op(DataT_struct p_) noexcept : p(p_) {} // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index e114ef8224..f9fbc7221b 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -20,17 +20,19 @@ namespace raft::distance::detail::ops { -// Describes the computation the russel_rao distance - +/** + * @brief the Russell Rao distance matrix calculation + * + * It computes the following equation: + * + * c_ij = (k - (sum_k x_ik * y_kj)) / k + */ template struct russel_rao_distance_op { IdxT_struct k; const float one_over_k; - russel_rao_distance_op(IdxT_struct k_) noexcept - : k(k_), - one_over_k(1.0f / k_) - { } + russel_rao_distance_op(IdxT_struct k_) noexcept : k(k_), one_over_k(1.0f / k_) {} // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh deleted file mode 100644 index 3cdc5489a6..0000000000 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include "pairwise_matrix/dispatch.cuh" -#include "distance_ops/l2_exp.cuh" -#include "distance_ops/l2_unexp.cuh" - -namespace raft { -namespace distance { -namespace detail { - -// /** -// * @brief the expanded euclidean distance matrix calculation -// * It computes the following equation: C = op(A^2 + B^2 - 2AB) -// * @tparam InType input data-type (for A and B matrices) -// * @tparam AccType accumulation data-type -// * @tparam OutType output data-type (for C and D matrices) -// * @tparam FinalLambda the final lambda called by FragmentMultiplyAdd_ -// * @tparam Index_ index type -// * @param m number of rows of A and C/D -// * @param n number of columns of B and C/D -// * @param k number of cols of A and rows of B -// * @param pA input matrix -// * @param pB input matrix -// * @param pD output matrix -// * @param enable_sqrt if the square root is computed or not -// * @param workspace temporary workspace needed for computations -// * @param worksize number of bytes of the workspace -// * @param fin_op the final gemm epilogue lambda -// * @param stream cuda stream where to launch work -// * @param isRowMajor whether the input and output matrices are row major -// */ -template -void euclideanAlgo1(IdxT m, - IdxT n, - IdxT k, - const DataT* pA, - const DataT* pB, - OutT* pD, - bool enable_sqrt, - AccT* workspace, - size_t& worksize, - FinOpT fin_op, - cudaStream_t stream, - bool isRowMajor) -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), - "OutT can be uint8_t, float, double," - "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - - ASSERT( - !(((pA != pB) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), - "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - DataT* norm_A = workspace; - DataT* norm_B = workspace; - if (pA != pB) { - norm_B += m; - raft::linalg::rowNorm( - norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - raft::linalg::rowNorm( - norm_B, pB, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - } else { - raft::linalg::rowNorm( - norm_A, pA, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); - } - - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel otherwise. - - if constexpr (__CUDACC_VER_MAJOR__ == 12) { - // Always execute legacy kernels on CUDA 12 - ops::l2_exp_distance_op l2_op(enable_sqrt); - distance_matrix_dispatch( - l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); - } else { - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { - // If device is SM_80 or later, use CUTLASS-based kernel. - using L2Op = ops::l2_exp_cutlass_op; - L2Op l2_op(enable_sqrt); - - distance_matrix_cutlass_dispatch( - l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); - } else { - // Else use "legacy" L2 - ops::l2_exp_distance_op l2_op(enable_sqrt); - distance_matrix_dispatch( - l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); - } - } -} - - -/** - * @brief the unexpanded euclidean distance matrix calculation - * It computes the following equation: cij = op((ai-bj)^2) - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param enable_sqrt if the square root is computed or not - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void euclideanAlgo2(IdxT m, - IdxT n, - IdxT k, - const DataT* pA, - const DataT* pB, - OutT* pD, - bool enable_sqrt, - FinOpT fin_op, - cudaStream_t stream, - bool isRowMajor) -{ - ops::l2_unexp_distance_op l2_op(enable_sqrt); - - // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; - - distance_matrix_dispatch( - l2_op, m, n, k, pA, pB, norm_A, norm_B, pD, fin_op, stream, isRowMajor); -} - -}; // end namespace detail -}; // end namespace distance -}; // end namespace raft diff --git a/cpp/include/raft/distance/detail/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh deleted file mode 100644 index 824e930023..0000000000 --- a/cpp/include/raft/distance/detail/hamming.cuh +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "distance_ops/hamming.cuh" -#include "pairwise_matrix/dispatch.cuh" - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the Hamming Unexpanded distance matrix calculation - * It computes the following equation: - Cij = sum(x_i != y_i) / k - * - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void hammingUnexpandedImpl(int m, - int n, - int k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - ops::hamming_distance_op distance_op{k}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh deleted file mode 100644 index 306977f266..0000000000 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include - -#include "pairwise_matrix/dispatch.cuh" -#include "distance_ops/hellinger.cuh" - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the Hellinger distance matrix calculation - * It computes the following equation: - sqrt(1 - sum(sqrt(x_k * y_k)) - * This distance computation modifies A and B by computing a sqrt - * and then performing a `pow(x, 2)` to convert it back. Because of this, - * it is possible that the values in A and B might differ slightly - * after this is invoked. - * - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void hellingerImpl(int m, - int n, - int k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // First sqrt x and y - const auto raft_sqrt = raft::linalg::unaryOp; - - raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { - raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); - } - - // Then calculate Hellinger distance - ops::hellinger_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); - - // Finally revert sqrt of x and y - raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { - raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); - } - - RAFT_CUDA_TRY(cudaGetLastError()); -} -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh deleted file mode 100644 index 71339e0c1a..0000000000 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "distance_ops/jensen_shannon.cuh" -#include "pairwise_matrix/dispatch.cuh" - - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the Jensen Shannon distance matrix calculation - * It computes the following equation: - Cij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) - + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) - * - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void jensenShannonImpl(int m, - int n, - int k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - ops::jensen_shannon_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh deleted file mode 100644 index e2f7bf2beb..0000000000 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include - -#include "distance_ops/kl_divergence.cuh" -#include "pairwise_matrix/dispatch.cuh" - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the KL Divergence distance matrix calculation - * It computes the following equation: - Cij = 0.5 * sum(x * log (x / y)); - * This distance computation modifies A or B by computing a log(x) - * and then performing a `pow(e, log(x))` to convert it back. Because of this, - * it is possible that the values in A or B might differ slightly - * after this is invoked. - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void klDivergenceImpl(int m, - int n, - int k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - auto unaryOp_lambda = [] __device__(DataT input) { - const bool x_zero = (input == 0); - return (!x_zero) * raft::log(input + x_zero); }; - - auto unaryOp_lambda_reverse = [] __device__(DataT input) { - // reverse previous log (x) back to x using (e ^ log(x)) - const bool x_zero = (input == 0); - return (!x_zero) * raft::exp(input); }; - - // This op takes some shortcuts when x equals y. So its behavior changes based - // on this. - ops::kl_divergence_op kl_divergence{is_row_major, x == y}; - - if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); - } - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); - - if (x != y) { - // Now reverse previous log (x) back to x using (e ^ log(x)) - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda_reverse, stream); - } -} -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh deleted file mode 100644 index cceb432c7d..0000000000 --- a/cpp/include/raft/distance/detail/l1.cuh +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "distance_ops/l1.cuh" -#include "pairwise_matrix/dispatch.cuh" - -namespace raft { -namespace distance { -namespace detail { - -template -void l1Impl(int m, - int n, - int k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - ops::l1_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh deleted file mode 100644 index 778ceb45cf..0000000000 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "pairwise_matrix/dispatch.cuh" -#include "distance_ops/minkowski.cuh" - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the unexpanded minkowski distance matrix calculation - * It computes the following equation: cij = sum(|x - y|^p)^(1/p) - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ index type - * @param[in] m number of rows of A and C/D - * @param[in] n number of rows of B and cols of C/D - * @param[in] k number of cols of A and B - * @param[in] pA input matrix - * @param[in] pB input matrix - * @param[out] pD output matrix - * @param[in] fin_op the final gemm epilogue lambda - * @param[in] stream cuda stream to launch work - * @param[in] isRowMajor whether the input and output matrices are row major - * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. - */ -template -void minkowskiImpl(IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - DataT metric_arg) -{ - ops::minkowski_distance_op distance_op{metric_arg}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} -}; // end namespace detail -}; // end namespace distance -}; // end namespace raft diff --git a/cpp/include/raft/distance/detail/russell_rao.cuh b/cpp/include/raft/distance/detail/russell_rao.cuh deleted file mode 100644 index 6bf5ae04bb..0000000000 --- a/cpp/include/raft/distance/detail/russell_rao.cuh +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "distance_ops/russel_rao.cuh" -#include "pairwise_matrix/dispatch.cuh" - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief the Russell Rao distance matrix calculation - * It computes the following equation: - Cij = (k - sum(x_i * y_i)) / k - * - * @tparam InType input data-type (for A and B matrices) - * @tparam AccType accumulation data-type - * @tparam OutType output data-type (for C and D matrices) - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major - */ -template -void russellRaoImpl(int m, - int n, - int k, - const DataT* x, - const DataT* y, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - ops::russel_rao_distance_op distance_op{k}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - distance_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} -} // namespace detail -} // namespace distance -} // namespace raft From 34ccddc0f09c36190f6da783249925cbd4cd2791 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 21 Feb 2023 11:18:33 +0100 Subject: [PATCH 43/93] Update readme --- cpp/include/raft/distance/detail/README.org | 53 ++++++++++++--------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org index 03d540cb84..99e59547d0 100644 --- a/cpp/include/raft/distance/detail/README.org +++ b/cpp/include/raft/distance/detail/README.org @@ -1,25 +1,32 @@ #+title: Readme -- [X] Euclidean - - *Notes*: - - enable_sqrt is now a runtime parameter. Was it a compile time - parameter before? - - CUTLASS fails on CUDA 12 (but prior to refactoring CUDA 12 did not work - either). I have not yet tested if everything works correctly on CUDA 11. -- [X] canberra.cuh -- [X] chebyshev.cuh -- [X] correlation.cuh -- [X] cosine.cuh - - *Notes*: cutlass fails on CUDA 12 (but prior to refactoring CUDA 12 did not - work either). I have not yet tested if everything works correctly on - CUDA 11. -- [X] hamming.cuh -- [X] hellinger.cuh -- [X] jensen_shannon.cuh -- [X] kl_divergence.cuh - - *Notes*: the isRowMajor and x_equal_y boolean parameters where previously - template / constexpr parameters. Now they are passed by value. This greatly - reduces the number of kernels, but may have negative consequences for run - time. -- [X] minkowski.cuh -- [X] russell_rao.cuh +* Overview + +| Metric | Epilog | Uses norms | Has params | Pre- & post-processing | Expensive inner loop | Depends on row_major | CUTLASS | +|----------------+--------+------------+---------------------------+------------------------+----------------------+----------------------+---------| +| Canberra | | | | | x | | | +| Chebyshev | | | | | | | | +| Correlation | x | x (twice) | x (many) | compute norms | | x | | +| Cosine | x | x | | compute norms | | | x | +| Hamming | x | | x (k) | | | | | +| Hellinger | x | | | sqrt and square | | | | +| Jensen Shannon | x | | | | x | | | +| KL divergence | x | | x (row_major, x_equals_y) | yes | x | x | | +| L1 | | | | | | | | +| L2 expanded | x | x | x (sqrt) | compute norms | | | x | +| L2 unexpanded | x | | x (sqrt) | | | | | +| Minkowski | x | | x (p) | | x | | | +| Russel-Rao | x | | x (k, 1/k) | | | | | + +* Tasks + +** TODO Architecture-conditional compilation +** TODO Clean up template arguments for kernel +** TODO Can we remove DataT_struct? +** TODO Include raft_cuda_utils +** TODO rename chebyshev -> Linf +** TODO remove this note about workspace + +: * @note if workspace is passed as nullptr, this will return in +: * worksize, the number of bytes of workspace required +** TODO Think of something wrt templates of distance_ops From 6a12ded58ba14f14b66b5d54dab266888b9b6e2e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 21 Feb 2023 12:23:45 +0100 Subject: [PATCH 44/93] Reenable device code generation Some code in dispatch was commented out in a futile attempt to keep compile times limited. --- .../detail/pairwise_matrix/dispatch.cuh | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 4a7c1f999f..b3362e7647 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -69,19 +69,19 @@ using vec_len_constant = std::integral_constant; template void dispatch(bool row_major, int vec_len, F&& f) { - // if (row_major) { - // switch (vec_len) { - // case 4: f(std::bool_constant(), vec_len_constant<4>()); break; - // case 2: f(std::bool_constant(), vec_len_constant<2>()); break; - // default: f(std::bool_constant(), vec_len_constant<1>()); break; - // } - // } else { - // switch (vec_len) { - // case 4: f(std::bool_constant(), vec_len_constant<4>()); break; - // case 2: f(std::bool_constant(), vec_len_constant<2>()); break; - // default: f(std::bool_constant(), vec_len_constant<1>()); break; - // } - // } + if (row_major) { + switch (vec_len) { + case 4: f(std::bool_constant(), vec_len_constant<4>()); break; + case 2: f(std::bool_constant(), vec_len_constant<2>()); break; + default: f(std::bool_constant(), vec_len_constant<1>()); break; + } + } else { + switch (vec_len) { + case 4: f(std::bool_constant(), vec_len_constant<4>()); break; + case 2: f(std::bool_constant(), vec_len_constant<2>()); break; + default: f(std::bool_constant(), vec_len_constant<1>()); break; + } + } } template Date: Tue, 21 Feb 2023 21:37:14 +0100 Subject: [PATCH 45/93] Readd overload of raft::distance::detail::distance --- cpp/include/raft/distance/detail/distance.cuh | 62 ++++++++++++++++--- cpp/include/raft/distance/distance.cuh | 2 +- 2 files changed, 56 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 35a5b798b3..3e5b676294 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -722,12 +722,12 @@ void distance_impl(raft::resources const& handle, } /** - * @brief Evaluate pairwise distances and write to matrix - * + * @brief Evaluate pairwise distances with the user epilogue lamba allowed * @tparam DistanceType which distance to evaluate * @tparam InType input argument type * @tparam AccType accumulation type * @tparam OutType output type + * @tparam FinalLambda user-defined epilogue lamba * @tparam Index_ Index type * * @param x first set of points @@ -738,15 +738,20 @@ void distance_impl(raft::resources const& handle, * @param k dimensionality * @param workspace temporary workspace needed for computations * @param worksize number of bytes of the workspace + * @param fin_op the final gemm epilogue lambda + * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * - * @note if workspace is passed as nullptr, this will return in - * worksize, the number of bytes of workspace required + * @note fin_op: This is a device lambda which is supposed to operate upon the + * input which is AccType and returns the output in OutType. It's signature is + * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs + * any other parameters, feel free to pass them via closure. */ template void distance(raft::resources const& handle, const InType* x, @@ -757,17 +762,16 @@ void distance(raft::resources const& handle, Index_ k, void* workspace, size_t worksize, + FinalLambda fin_op, bool isRowMajor = true, InType metric_arg = 2.0f) { - auto fin_op = raft::identity_op(); - // raft distance support inputs as float/double and output as uint8_t/float/double. static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), "OutType can be uint8_t, float, double," "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); - distance_impl( + distance_impl( handle, distance_tag{}, x, @@ -784,6 +788,50 @@ void distance(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); } +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + * + * @note if workspace is passed as nullptr, this will return in + * worksize, the number of bytes of workspace required + */ +template +void distance(raft::resources const& handle, + const InType* x, + const InType* y, + OutType* out, + Index_ m, + Index_ n, + Index_ k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + InType metric_arg = 2.0f) +{ + auto fin_op = raft::identity_op(); + + distance( + handle, x, y, out, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); +} + /** * @brief Return the exact workspace size to compute the distance * @tparam DistanceType which distance to evaluate diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 59bf52a2ca..ddda68f789 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -253,7 +253,7 @@ void pairwise_distance(raft::resources const& handle, bool isRowMajor = true, Type metric_arg = 2.0f) { - auto stream = handle.get_stream(); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto dispatch = [&](auto distance_type) { auto worksize = getWorkspaceSize(x, y, m, n, k); From ca29e2d008827743748c0e4416b2717b1ea844e3 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 21 Feb 2023 22:02:55 +0100 Subject: [PATCH 46/93] Fix style --- cpp/include/raft/distance/detail/distance.cuh | 29 +++++++++---------- .../distance/detail/distance_ops/template.cuh | 2 +- .../detail/pairwise_distance_cutlass_base.cuh | 6 ++-- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 3e5b676294..573d5c2778 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -304,7 +304,6 @@ void distance_impl(raft::resources const& handle, distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } - template void distance_impl(raft::resources const& handle, distance_tag distance_type, @@ -320,18 +319,18 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - raft::linalg::gemm(handle, - out, - const_cast(x), - const_cast(y), - m, - n, - k, - !is_row_major, - !is_row_major, - is_row_major, - stream); + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + raft::linalg::gemm(handle, + out, + const_cast(x), + const_cast(y), + m, + n, + k, + !is_row_major, + !is_row_major, + is_row_major, + stream); } template @@ -560,7 +559,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - bool perform_sqrt = false; + bool perform_sqrt = false; cudaStream_t stream = raft::resource::get_cuda_stream(handle); distance_impl_l2_expanded( perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); @@ -581,7 +580,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - bool perform_sqrt = true; + bool perform_sqrt = true; cudaStream_t stream = raft::resource::get_cuda_stream(handle); distance_impl_l2_expanded( perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index 378bcf0c9f..1d2d681b18 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -27,7 +27,7 @@ namespace raft::distance::detail::ops { struct template_distance_op { TODO member; - template_distance_op(TODO member_) noexcept : member(member_) { } + template_distance_op(TODO member_) noexcept : member(member_) {} // Load norms of input data static constexpr bool use_norms = TODO; diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index 0d26d940b3..2ab5c69b0d 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -169,8 +169,8 @@ void cutlassDistanceKernel(const DataT* x, CUTLASS_CHECK(status); } -}; // namespace detail -}; // namespace distance -}; // namespace raft +}; // namespace detail +}; // namespace distance +}; // namespace raft #pragma GCC diagnostic pop From 28c95a12901cc311d73f30ab9fe2b5596ff00ff3 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 11:31:11 +0100 Subject: [PATCH 47/93] Fix 11.8 compilation error --- cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index b3362e7647..23d0f34489 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -198,7 +198,7 @@ void distance_matrix_cutlass_dispatch(opT cutlass_op, // respectively. // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_aligned, static_cast(16 / sizeof(DataT))); + constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); cutlassDistanceKernel( x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); From a5592b9d2165fbd49efecbcf08c94d83b305200c Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 11:40:41 +0100 Subject: [PATCH 48/93] Rename minkowski -> lp_unexp --- cpp/include/raft/distance/detail/distance.cuh | 4 ++-- .../detail/distance_ops/{minkowski.cuh => lp_unexp.cuh} | 4 ++-- cpp/test/CMakeLists.txt | 2 +- cpp/test/distance/{dist_minkowski.cu => dist_lp_unexp.cu} | 0 4 files changed, 5 insertions(+), 5 deletions(-) rename cpp/include/raft/distance/detail/distance_ops/{minkowski.cuh => lp_unexp.cuh} (96%) rename cpp/test/distance/{dist_minkowski.cu => dist_lp_unexp.cu} (100%) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 573d5c2778..6d14fcca28 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -38,7 +38,7 @@ #include #include #include -#include +#include #include #include @@ -683,7 +683,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT metric_arg) { - ops::minkowski_distance_op distance_op{metric_arg}; + ops::lp_unexp_distance_op distance_op{metric_arg}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; diff --git a/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh similarity index 96% rename from cpp/include/raft/distance/detail/distance_ops/minkowski.cuh rename to cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index 0640cc72a7..4af6888ddf 100644 --- a/cpp/include/raft/distance/detail/distance_ops/minkowski.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -27,10 +27,10 @@ namespace raft::distance::detail::ops { * c_ij = (sum_k |x_ik - y_jk|^p)^(1/p) */ template -struct minkowski_distance_op { +struct lp_unexp_distance_op { DataT_struct p; - minkowski_distance_op(DataT_struct p_) noexcept : p(p_) {} + lp_unexp_distance_op(DataT_struct p_) noexcept : p(p_) {} // Load norms of input data static constexpr bool use_norms = false; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 575e8cf84b..928412568a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -121,7 +121,7 @@ if(BUILD_TESTS) test/distance/dist_jensen_shannon.cu test/distance/dist_kl_divergence.cu test/distance/dist_l1.cu - test/distance/dist_minkowski.cu + test/distance/dist_lp_unexp.cu test/distance/dist_russell_rao.cu test/distance/masked_nn.cu test/distance/masked_nn_compress_to_bits.cu diff --git a/cpp/test/distance/dist_minkowski.cu b/cpp/test/distance/dist_lp_unexp.cu similarity index 100% rename from cpp/test/distance/dist_minkowski.cu rename to cpp/test/distance/dist_lp_unexp.cu From 265ba0718aed82f8d5107119040ed8f7e4e53888 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 12:11:08 +0100 Subject: [PATCH 49/93] Rename Chebyshev -> l_inf --- cpp/CMakeLists.txt | 4 ++-- cpp/include/raft/distance/detail/distance.cuh | 4 ++-- .../distance/detail/distance_ops/{chebyshev.cuh => l_inf.cuh} | 2 +- .../specializations/detail/{chebyshev.cuh => l_inf.cuh} | 0 cpp/include/raft/distance/specializations/distance.cuh | 2 +- ...double_double_int.cu => l_inf_double_double_double_int.cu} | 0 ...loat_float_float_int.cu => l_inf_float_float_float_int.cu} | 0 cpp/test/CMakeLists.txt | 2 +- cpp/test/distance/{dist_chebyshev.cu => dist_l_inf.cu} | 0 9 files changed, 7 insertions(+), 7 deletions(-) rename cpp/include/raft/distance/detail/distance_ops/{chebyshev.cuh => l_inf.cuh} (98%) rename cpp/include/raft/distance/specializations/detail/{chebyshev.cuh => l_inf.cuh} (100%) rename cpp/src/distance/distance/specializations/detail/{chebyshev_double_double_double_int.cu => l_inf_double_double_double_int.cu} (100%) rename cpp/src/distance/distance/specializations/detail/{chebyshev_float_float_float_int.cu => l_inf_float_float_float_int.cu} (100%) rename cpp/test/distance/{dist_chebyshev.cu => dist_l_inf.cu} (100%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7e5b10b227..679a1747c1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -317,8 +317,6 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/cluster/kmeans_init_plus_plus_float.cu src/distance/distance/specializations/detail/canberra_double_double_double_int.cu src/distance/distance/specializations/detail/canberra_float_float_float_int.cu - src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu - src/distance/distance/specializations/detail/chebyshev_float_float_float_int.cu src/distance/distance/specializations/detail/correlation_double_double_double_int.cu src/distance/distance/specializations/detail/correlation_float_float_float_int.cu src/distance/distance/specializations/detail/cosine_double_double_double_int.cu @@ -352,6 +350,8 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu + src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu + src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 6d14fcca28..95cc9afa42 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -28,7 +28,6 @@ #include #include -#include #include #include #include @@ -38,6 +37,7 @@ #include #include #include +#include #include #include @@ -657,7 +657,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - ops::chebyshev_distance_op distance_op{}; + ops::l_inf_distance_op distance_op{}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; diff --git a/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh similarity index 98% rename from cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh rename to cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index d390f75460..0d515faa23 100644 --- a/cpp/include/raft/distance/detail/distance_ops/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -27,7 +27,7 @@ namespace raft::distance::detail::ops { * * c_ij = max_k | x_ik - y_kj | */ -struct chebyshev_distance_op { +struct l_inf_distance_op { // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/specializations/detail/chebyshev.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh similarity index 100% rename from cpp/include/raft/distance/specializations/detail/chebyshev.cuh rename to cpp/include/raft/distance/specializations/detail/l_inf.cuh diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index a0c35ca9a8..8daa398b49 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include #include @@ -31,6 +30,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu similarity index 100% rename from cpp/src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu rename to cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu diff --git a/cpp/src/distance/distance/specializations/detail/chebyshev_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu similarity index 100% rename from cpp/src/distance/distance/specializations/detail/chebyshev_float_float_float_int.cu rename to cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 928412568a..f0347b09be 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -109,7 +109,6 @@ if(BUILD_TESTS) PATH test/distance/dist_adj.cu test/distance/dist_canberra.cu - test/distance/dist_chebyshev.cu test/distance/dist_correlation.cu test/distance/dist_cos.cu test/distance/dist_euc_exp.cu @@ -121,6 +120,7 @@ if(BUILD_TESTS) test/distance/dist_jensen_shannon.cu test/distance/dist_kl_divergence.cu test/distance/dist_l1.cu + test/distance/dist_l_inf.cu test/distance/dist_lp_unexp.cu test/distance/dist_russell_rao.cu test/distance/masked_nn.cu diff --git a/cpp/test/distance/dist_chebyshev.cu b/cpp/test/distance/dist_l_inf.cu similarity index 100% rename from cpp/test/distance/dist_chebyshev.cu rename to cpp/test/distance/dist_l_inf.cu From 7ccb8a7e5c428c7d8acda207f5194a82a2e275a5 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 12:14:27 +0100 Subject: [PATCH 50/93] Rename euc -> l2 --- cpp/test/CMakeLists.txt | 6 +++--- cpp/test/distance/{dist_euc_exp.cu => dist_l2_exp.cu} | 0 .../distance/{dist_eucsqrt_exp.cu => dist_l2_sqrt_exp.cu} | 0 cpp/test/distance/{dist_euc_unexp.cu => dist_l2_unexp.cu} | 0 4 files changed, 3 insertions(+), 3 deletions(-) rename cpp/test/distance/{dist_euc_exp.cu => dist_l2_exp.cu} (100%) rename cpp/test/distance/{dist_eucsqrt_exp.cu => dist_l2_sqrt_exp.cu} (100%) rename cpp/test/distance/{dist_euc_unexp.cu => dist_l2_unexp.cu} (100%) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index f0347b09be..aa4487e9d5 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -111,15 +111,15 @@ if(BUILD_TESTS) test/distance/dist_canberra.cu test/distance/dist_correlation.cu test/distance/dist_cos.cu - test/distance/dist_euc_exp.cu - test/distance/dist_euc_unexp.cu - test/distance/dist_eucsqrt_exp.cu test/distance/dist_hamming.cu test/distance/dist_hellinger.cu test/distance/dist_inner_product.cu test/distance/dist_jensen_shannon.cu test/distance/dist_kl_divergence.cu test/distance/dist_l1.cu + test/distance/dist_l2_exp.cu + test/distance/dist_l2_unexp.cu + test/distance/dist_l2_sqrt_exp.cu test/distance/dist_l_inf.cu test/distance/dist_lp_unexp.cu test/distance/dist_russell_rao.cu diff --git a/cpp/test/distance/dist_euc_exp.cu b/cpp/test/distance/dist_l2_exp.cu similarity index 100% rename from cpp/test/distance/dist_euc_exp.cu rename to cpp/test/distance/dist_l2_exp.cu diff --git a/cpp/test/distance/dist_eucsqrt_exp.cu b/cpp/test/distance/dist_l2_sqrt_exp.cu similarity index 100% rename from cpp/test/distance/dist_eucsqrt_exp.cu rename to cpp/test/distance/dist_l2_sqrt_exp.cu diff --git a/cpp/test/distance/dist_euc_unexp.cu b/cpp/test/distance/dist_l2_unexp.cu similarity index 100% rename from cpp/test/distance/dist_euc_unexp.cu rename to cpp/test/distance/dist_l2_unexp.cu From 874d014ccb6c7e6816ffdfedeb3927c91f6883f2 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 12:22:40 +0100 Subject: [PATCH 51/93] Update copyright headers Files have moved --- cpp/test/distance/dist_l2_exp.cu | 2 +- cpp/test/distance/dist_l2_sqrt_exp.cu | 2 +- cpp/test/distance/dist_l2_unexp.cu | 2 +- cpp/test/distance/dist_l_inf.cu | 2 +- cpp/test/distance/dist_lp_unexp.cu | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/test/distance/dist_l2_exp.cu b/cpp/test/distance/dist_l2_exp.cu index 567e279691..ae67215e51 100644 --- a/cpp/test/distance/dist_l2_exp.cu +++ b/cpp/test/distance/dist_l2_exp.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/distance/dist_l2_sqrt_exp.cu b/cpp/test/distance/dist_l2_sqrt_exp.cu index d717158649..94d254f44b 100644 --- a/cpp/test/distance/dist_l2_sqrt_exp.cu +++ b/cpp/test/distance/dist_l2_sqrt_exp.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/distance/dist_l2_unexp.cu b/cpp/test/distance/dist_l2_unexp.cu index 311ad190e2..d74a41d2a4 100644 --- a/cpp/test/distance/dist_l2_unexp.cu +++ b/cpp/test/distance/dist_l2_unexp.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/distance/dist_l_inf.cu b/cpp/test/distance/dist_l_inf.cu index abad828de7..b9d6413a10 100644 --- a/cpp/test/distance/dist_l_inf.cu +++ b/cpp/test/distance/dist_l_inf.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/distance/dist_lp_unexp.cu b/cpp/test/distance/dist_lp_unexp.cu index af2661da3a..9d6f5921a7 100644 --- a/cpp/test/distance/dist_lp_unexp.cu +++ b/cpp/test/distance/dist_lp_unexp.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 757fb44f304979e6a4b4dcacb0adccba905f4952 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 12:25:29 +0100 Subject: [PATCH 52/93] Remove misleading note about workspace nullptr --- cpp/include/raft/distance/detail/distance.cuh | 3 --- cpp/include/raft/distance/distance.cuh | 3 --- 2 files changed, 6 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 95cc9afa42..bea5ced976 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -804,9 +804,6 @@ void distance(raft::resources const& handle, * @param worksize number of bytes of the workspace * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major - * - * @note if workspace is passed as nullptr, this will return in - * worksize, the number of bytes of workspace required */ template Date: Wed, 22 Feb 2023 12:26:51 +0100 Subject: [PATCH 53/93] Remove notes file --- cpp/include/raft/distance/detail/README.org | 32 --------------------- 1 file changed, 32 deletions(-) delete mode 100644 cpp/include/raft/distance/detail/README.org diff --git a/cpp/include/raft/distance/detail/README.org b/cpp/include/raft/distance/detail/README.org deleted file mode 100644 index 99e59547d0..0000000000 --- a/cpp/include/raft/distance/detail/README.org +++ /dev/null @@ -1,32 +0,0 @@ -#+title: Readme - -* Overview - -| Metric | Epilog | Uses norms | Has params | Pre- & post-processing | Expensive inner loop | Depends on row_major | CUTLASS | -|----------------+--------+------------+---------------------------+------------------------+----------------------+----------------------+---------| -| Canberra | | | | | x | | | -| Chebyshev | | | | | | | | -| Correlation | x | x (twice) | x (many) | compute norms | | x | | -| Cosine | x | x | | compute norms | | | x | -| Hamming | x | | x (k) | | | | | -| Hellinger | x | | | sqrt and square | | | | -| Jensen Shannon | x | | | | x | | | -| KL divergence | x | | x (row_major, x_equals_y) | yes | x | x | | -| L1 | | | | | | | | -| L2 expanded | x | x | x (sqrt) | compute norms | | | x | -| L2 unexpanded | x | | x (sqrt) | | | | | -| Minkowski | x | | x (p) | | x | | | -| Russel-Rao | x | | x (k, 1/k) | | | | | - -* Tasks - -** TODO Architecture-conditional compilation -** TODO Clean up template arguments for kernel -** TODO Can we remove DataT_struct? -** TODO Include raft_cuda_utils -** TODO rename chebyshev -> Linf -** TODO remove this note about workspace - -: * @note if workspace is passed as nullptr, this will return in -: * worksize, the number of bytes of workspace required -** TODO Think of something wrt templates of distance_ops From 885bda66bd94f39b0e39053cedf627d1d3a6e6c2 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 14:02:17 +0100 Subject: [PATCH 54/93] Put template on struct instead of methods --- cpp/include/raft/distance/detail/distance.cuh | 36 +++++++++---------- .../distance/detail/distance_ops/canberra.cuh | 6 ++-- .../detail/distance_ops/correlation.cuh | 29 ++++++++------- .../distance/detail/distance_ops/cosine.cuh | 6 ++-- .../distance/detail/distance_ops/hamming.cuh | 11 +++--- .../detail/distance_ops/hellinger.cuh | 6 ++-- .../detail/distance_ops/jensen_shannon.cuh | 6 ++-- .../detail/distance_ops/kl_divergence.cuh | 6 ++-- .../raft/distance/detail/distance_ops/l1.cuh | 6 ++-- .../distance/detail/distance_ops/l2_exp.cuh | 6 ++-- .../distance/detail/distance_ops/l2_unexp.cuh | 6 ++-- .../distance/detail/distance_ops/l_inf.cuh | 6 ++-- .../distance/detail/distance_ops/lp_unexp.cuh | 11 +++--- .../detail/distance_ops/russel_rao.cuh | 11 +++--- .../distance/detail/distance_ops/template.cuh | 4 +-- .../detail/pairwise_matrix/kernel_sm60.cuh | 4 +-- 16 files changed, 78 insertions(+), 82 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index bea5ced976..621e2d15b9 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -120,7 +120,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT metric_arg) // unused { - ops::canberra_distance_op distance_op{}; + ops::canberra_distance_op distance_op{}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; @@ -203,8 +203,8 @@ void distance_impl(raft::resources const& handle, raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); } - using CorrOp = ops::correlation_distance_op; - CorrOp corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); + using OpT = ops::correlation_distance_op; + OpT corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); distance_matrix_dispatch( corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); } @@ -257,7 +257,7 @@ void distance_impl(raft::resources const& handle, if constexpr (__CUDACC_VER_MAJOR__ == 12) { // Always execute legacy kernels on CUDA 12 - ops::cosine_distance_op distance_op{}; + ops::cosine_distance_op distance_op{}; distance_matrix_dispatch( distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { @@ -271,7 +271,7 @@ void distance_impl(raft::resources const& handle, distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { // Else use "legacy" L2 - ops::cosine_distance_op distance_op{}; + ops::cosine_distance_op distance_op{}; distance_matrix_dispatch( distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } @@ -293,7 +293,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - ops::hamming_distance_op distance_op{k}; + ops::hamming_distance_op distance_op{k}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; @@ -357,7 +357,7 @@ void distance_impl(raft::resources const& handle, if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } // Then calculate Hellinger distance - ops::hellinger_distance_op distance_op{}; + ops::hellinger_distance_op distance_op{}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; @@ -387,7 +387,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - ops::jensen_shannon_distance_op distance_op{}; + ops::jensen_shannon_distance_op distance_op{}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; @@ -428,7 +428,7 @@ void distance_impl(raft::resources const& handle, // This op takes some shortcuts when x equals y. So its behavior changes based // on this. - ops::kl_divergence_op kl_divergence{is_row_major, x == y}; + ops::kl_divergence_op kl_divergence{is_row_major, x == y}; if (x != y) { raft::linalg::unaryOp( @@ -463,13 +463,13 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - ops::l1_distance_op distance_op{}; + ops::l1_distance_op distance_op{}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + distance_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -523,7 +523,7 @@ void distance_impl_l2_expanded( // NOTE: different name if constexpr (__CUDACC_VER_MAJOR__ == 12) { // Always execute legacy kernels on CUDA 12 - ops::l2_exp_distance_op l2_op(perform_sqrt); + ops::l2_exp_distance_op l2_op(perform_sqrt); distance_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { @@ -537,7 +537,7 @@ void distance_impl_l2_expanded( // NOTE: different name l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { // Else use "legacy" L2 - ops::l2_exp_distance_op l2_op(perform_sqrt); + ops::l2_exp_distance_op l2_op(perform_sqrt); distance_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } @@ -602,7 +602,7 @@ void distance_impl(raft::resources const& handle, DataT) // metric_arg unused { bool perform_sqrt = false; - ops::l2_unexp_distance_op l2_op(perform_sqrt); + ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. const DataT* norm_A = nullptr; @@ -630,7 +630,7 @@ void distance_impl(raft::resources const& handle, DataT) // metric_arg unused { bool perform_sqrt = true; - ops::l2_unexp_distance_op l2_op(perform_sqrt); + ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. const DataT* norm_A = nullptr; @@ -657,7 +657,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - ops::l_inf_distance_op distance_op{}; + ops::l_inf_distance_op distance_op{}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; @@ -683,7 +683,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT metric_arg) { - ops::lp_unexp_distance_op distance_op{metric_arg}; + ops::lp_unexp_distance_op distance_op{metric_arg}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; @@ -709,7 +709,7 @@ void distance_impl(raft::resources const& handle, bool is_row_major, DataT) // metric_arg unused { - ops::russel_rao_distance_op distance_op{k}; + ops::russel_rao_distance_op distance_op{k}; const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 6491b24e3d..5ddf02e705 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -27,6 +27,7 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k |x_ik - y_kj| / ( |x_ik| + |y_kj| ) */ +template struct canberra_distance_op { // Load norms of input data static constexpr bool use_norms = false; @@ -36,13 +37,12 @@ struct canberra_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { const auto diff = raft::abs(x - y); @@ -52,7 +52,7 @@ struct canberra_distance_op { acc += ((add != 0) * diff / (add + (add == 0))); }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 11cc3ed4f4..d46cbf6718 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -28,26 +28,26 @@ namespace raft::distance::detail::ops { * / * (|| x - mean(x) ||_2 || y - mean(y) ||_2) */ -template +template struct correlation_distance_op { - const DataT_struct* x2n; - const DataT_struct* y2n; - IdxT_struct m; - IdxT_struct n; - IdxT_struct k; + const DataT* x2n; + const DataT* y2n; + IdxT m; + IdxT n; + IdxT k; correlation_distance_op(bool is_row_major, - const DataT_struct* x2n_, - const DataT_struct* y2n_, - IdxT_struct m_, - IdxT_struct n_, - IdxT_struct k_) noexcept + const DataT* x2n_, + const DataT* y2n_, + IdxT m_, + IdxT n_, + IdxT k_) noexcept : x2n(x2n_), y2n(y2n_), m(m_), n(n_), k(k_) { // The distance op is typically created before the row-major/col-major // swapping has been done. So we do it here. if (!is_row_major) { - std::swap(x2n, y2n); + std::swap(x2n, y2n); std::swap(m, n); } } @@ -60,19 +60,18 @@ struct correlation_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } - template DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index d26b5aeda0..422ec4a3aa 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -27,6 +27,7 @@ namespace raft::distance::detail::ops { * * d(x, y) = 1 - (x â‹… y) / ( ||x||_2 ||y||_2) */ +template struct cosine_distance_op { // Load norms of input data static constexpr bool use_norms = true; @@ -36,19 +37,18 @@ struct cosine_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } - template DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 02087e2874..6d050154d7 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -26,11 +26,11 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k (x_ik != y_kj) / k */ -template +template struct hamming_distance_op { - IdxT_struct k; + IdxT k; - hamming_distance_op(IdxT_struct k_) noexcept : k(k_) {} + hamming_distance_op(IdxT k_) noexcept : k(k_) {} // Load norms of input data static constexpr bool use_norms = false; @@ -40,19 +40,18 @@ struct hamming_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { acc += (x != y); }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index 0314565a03..c5e2b84ac2 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -27,6 +27,7 @@ namespace raft::distance::detail::ops { * c_ij = sqrt(1 - sum_k sqrt(x_ik * y_kj)) * */ +template struct hellinger_distance_op { // Load norms of input data static constexpr bool use_norms = false; @@ -36,13 +37,12 @@ struct hellinger_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { // This is sqrt(x) * sqrt(y). @@ -50,7 +50,7 @@ struct hellinger_distance_op { acc += product; }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index 5e00faef74..df5aadcf3b 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -29,6 +29,7 @@ namespace raft::distance::detail::ops { * c_ij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) * + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) */ +template struct jensen_shannon_distance_op { // Load norms of input data static constexpr bool use_norms = false; @@ -38,13 +39,12 @@ struct jensen_shannon_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { const DataT m = 0.5f * (x + y); @@ -56,7 +56,7 @@ struct jensen_shannon_distance_op { acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index fe6e0dbbe1..526927243f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -26,6 +26,7 @@ namespace raft::distance::detail::ops { * * c_ij = 0.5 * sum(x * log (x / y)); */ +template struct kl_divergence_op { const bool is_row_major; const bool x_equal_y; @@ -43,13 +44,12 @@ struct kl_divergence_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { // TODO: make sure that these branches get hoisted out of main loop.. Could @@ -75,7 +75,7 @@ struct kl_divergence_op { } }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index bb71a7801f..f152f1d83a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -26,6 +26,7 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k abs(x_ik - y_kj) */ +template struct l1_distance_op { // Do not load norms of data, the computation of L1 distance does not use them. static constexpr bool use_norms = false; @@ -35,19 +36,18 @@ struct l1_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { acc += raft::abs(x - y); }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index d491493a63..785e7804d6 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -28,6 +28,7 @@ namespace raft::distance::detail::ops { * c_ij = - 2 sum_k x_ik * y_kj + ||x_i.||_2 + ||y_.j||_2 * */ +template struct l2_exp_distance_op { bool sqrt; @@ -41,19 +42,18 @@ struct l2_exp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } - template DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index 6e75cc95e8..e03eb0a97e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -27,6 +27,7 @@ namespace raft::distance::detail::ops { * * c_ij = optional_sqrt ( sum_k (x_ik - y_kj)^2 ) */ +template struct l2_unexp_distance_op { bool sqrt; @@ -40,20 +41,19 @@ struct l2_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { const auto diff = x - y; acc += diff * diff; }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index 0d515faa23..caa1379133 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -27,6 +27,7 @@ namespace raft::distance::detail::ops { * * c_ij = max_k | x_ik - y_kj | */ +template struct l_inf_distance_op { // Load norms of input data static constexpr bool use_norms = false; @@ -36,20 +37,19 @@ struct l_inf_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { const auto diff = raft::abs(x - y); acc = raft::max(acc, diff); }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index 4af6888ddf..a4a090d058 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -26,11 +26,11 @@ namespace raft::distance::detail::ops { * * c_ij = (sum_k |x_ik - y_jk|^p)^(1/p) */ -template +template struct lp_unexp_distance_op { - DataT_struct p; + DataT p; - lp_unexp_distance_op(DataT_struct p_) noexcept : p(p_) {} + lp_unexp_distance_op(DataT p_) noexcept : p(p_) {} // Load norms of input data static constexpr bool use_norms = false; @@ -40,20 +40,19 @@ struct lp_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { const auto diff = raft::abs(x - y); acc += raft::pow(diff, p); }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index f9fbc7221b..0bac3beaff 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -27,12 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = (k - (sum_k x_ik * y_kj)) / k */ -template +template struct russel_rao_distance_op { - IdxT_struct k; + IdxT k; const float one_over_k; - russel_rao_distance_op(IdxT_struct k_) noexcept : k(k_), one_over_k(1.0f / k_) {} + russel_rao_distance_op(IdxT k_) noexcept : k(k_), one_over_k(1.0f / k_) {} // Load norms of input data static constexpr bool use_norms = false; @@ -42,19 +42,18 @@ struct russel_rao_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template + template constexpr size_t shared_mem_size() { return Policy::SmemSize; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index 1d2d681b18..b978cf2a36 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -24,6 +24,7 @@ namespace raft::distance::detail::ops { // // Fill in the TODO items. +template struct template_distance_op { TODO member; @@ -43,13 +44,12 @@ struct template_distance_op { return Policy::SmemSize + TODO; } - template DI void core(AccT& acc, DataT& x, DataT& y) const { TODO; }; - template + template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT* regxn, DataT* regyn, diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 1e450f9289..6856c09c37 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -56,7 +56,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(co IdxT gridStrideY) { // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog( + distance_op.template epilog( acc, regxn, regyn, gridStrideX, gridStrideY); }; @@ -123,7 +123,7 @@ void pairwise_matrix(OpT distance_op, dim3 blk(Policy::Nthreads); // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - size_t smem_size = distance_op.template shared_mem_size(); + size_t smem_size = distance_op.template shared_mem_size(); // Obtain function pointer to kernel auto kernel = pairwise_matrix_kernel; dim3 grid = launchConfigGenerator(m, n, smem_size, kernel); From cd38ec646a6166f8263b8e0bad98aef39ed65898 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 14:35:24 +0100 Subject: [PATCH 55/93] Fix style --- .../distance/detail/distance_ops/correlation.cuh | 13 +++---------- .../raft/distance/detail/distance_ops/cosine.cuh | 5 +---- .../raft/distance/detail/distance_ops/hamming.cuh | 5 +---- .../raft/distance/detail/distance_ops/l1.cuh | 5 +---- .../raft/distance/detail/distance_ops/l2_exp.cuh | 5 +---- .../distance/detail/distance_ops/russel_rao.cuh | 5 +---- .../raft/distance/detail/distance_ops/template.cuh | 5 +---- .../distance/detail/pairwise_matrix/kernel_sm60.cuh | 3 +-- 8 files changed, 10 insertions(+), 36 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index d46cbf6718..3832104280 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -36,12 +36,8 @@ struct correlation_distance_op { IdxT n; IdxT k; - correlation_distance_op(bool is_row_major, - const DataT* x2n_, - const DataT* y2n_, - IdxT m_, - IdxT n_, - IdxT k_) noexcept + correlation_distance_op( + bool is_row_major, const DataT* x2n_, const DataT* y2n_, IdxT m_, IdxT n_, IdxT k_) noexcept : x2n(x2n_), y2n(y2n_), m(m_), n(n_), k(k_) { // The distance op is typically created before the row-major/col-major @@ -66,10 +62,7 @@ struct correlation_distance_op { return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } - DI void core(AccT& acc, DataT& x, DataT& y) const - { - acc += x * y; - }; + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index 422ec4a3aa..c3f3b75e62 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -43,10 +43,7 @@ struct cosine_distance_op { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } - DI void core(AccT& acc, DataT& x, DataT& y) const - { - acc += x * y; - }; + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 6d050154d7..98acf11560 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -46,10 +46,7 @@ struct hamming_distance_op { return Policy::SmemSize; } - DI void core(AccT& acc, DataT& x, DataT& y) const - { - acc += (x != y); - }; + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += (x != y); }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index f152f1d83a..b02971bac7 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -42,10 +42,7 @@ struct l1_distance_op { return Policy::SmemSize; } - DI void core(AccT& acc, DataT& x, DataT& y) const - { - acc += raft::abs(x - y); - }; + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += raft::abs(x - y); }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 785e7804d6..b68c44c8ba 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -48,10 +48,7 @@ struct l2_exp_distance_op { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } - DI void core(AccT& acc, DataT& x, DataT& y) const - { - acc += x * y; - }; + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index 0bac3beaff..7acd858e49 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -48,10 +48,7 @@ struct russel_rao_distance_op { return Policy::SmemSize; } - DI void core(AccT& acc, DataT& x, DataT& y) const - { - acc += x * y; - }; + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index b978cf2a36..b0f40123aa 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -44,10 +44,7 @@ struct template_distance_op { return Policy::SmemSize + TODO; } - DI void core(AccT& acc, DataT& x, DataT& y) const - { - TODO; - }; + DI void core(AccT& acc, DataT& x, DataT& y) const { TODO; }; template DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 6856c09c37..db7ceb64f4 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -56,8 +56,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(co IdxT gridStrideY) { // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog( - acc, regxn, regyn, gridStrideX, gridStrideY); + distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); }; // No support for row_epilog_op. From 749d000dbfd05a6c429eec3ab8475c21b6e38e65 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 16:03:46 +0100 Subject: [PATCH 56/93] Add dispatch based on compute architecture --- cpp/include/raft/distance/detail/distance.cuh | 26 ++-- .../detail/pairwise_matrix/dispatch.cuh | 9 +- .../detail/pairwise_matrix/kernel_sm60.cuh | 23 ++- cpp/include/raft/util/arch.cuh | 133 ++++++++++++++++++ 4 files changed, 173 insertions(+), 18 deletions(-) create mode 100644 cpp/include/raft/util/arch.cuh diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 621e2d15b9..da119b6a45 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -46,6 +46,7 @@ #include #include #include +#include #include namespace raft { @@ -261,8 +262,11 @@ void distance_impl(raft::resources const& handle, distance_matrix_dispatch( distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { + auto runtime_arch = raft::arch::kernel_runtime_arch(); + auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + + if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. using Op = ops::cosine_cutlass_op; Op distance_op{}; @@ -272,8 +276,8 @@ void distance_impl(raft::resources const& handle, } else { // Else use "legacy" L2 ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_matrix_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); } } } @@ -527,8 +531,11 @@ void distance_impl_l2_expanded( // NOTE: different name distance_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - const auto deviceVersion = getComputeCapability(); - if (deviceVersion.first >= 8) { + auto runtime_arch = raft::arch::kernel_runtime_arch(); + auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + + if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. using L2Op = ops::l2_exp_cutlass_op; L2Op l2_op(perform_sqrt); @@ -536,10 +543,11 @@ void distance_impl_l2_expanded( // NOTE: different name distance_matrix_cutlass_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - // Else use "legacy" L2 + // Else use "legacy" L2. Compile *only* for architectures in the legacy + // range. For newer architectures, compile empty kernels. ops::l2_exp_distance_op l2_op(perform_sqrt); - distance_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_matrix_dispatch( + l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 23d0f34489..75680027d8 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -17,6 +17,7 @@ #include "kernel_sm60.cuh" #include +#include #include #include #include @@ -89,7 +90,8 @@ template + typename IdxT = int, + typename SM_compat_t = raft::arch::SM_range> void distance_matrix_dispatch(OpT distance_op, IdxT m, IdxT n, @@ -101,7 +103,8 @@ void distance_matrix_dispatch(OpT distance_op, OutT* out, FinOpT fin_op, cudaStream_t stream, - bool is_row_major) + bool is_row_major, + SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future())) { // Determine leading dimensions and, if column-major, flip order of passing x // and y. @@ -145,7 +148,7 @@ void distance_matrix_dispatch(OpT distance_op, typedef typename std::conditional::type Policy; return pairwise_matrix( - distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream); + distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream, sm_compat_range); }); } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index db7ceb64f4..7404c47e66 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include @@ -28,7 +29,8 @@ template + typename FinOpT, + typename SM_compat_t> __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(const DataT* x, const DataT* y, const DataT* _xn, @@ -41,8 +43,15 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(co IdxT ldd, OutT* dOutput, opT distance_op, - FinOpT fin_op) + FinOpT fin_op, + SM_compat_t sm_compat_range) { + // Early exit to minimize the size of the kernel when it is not supposed to be compiled. + if constexpr(! sm_compat_range.contains(raft::arch::SM_compute_arch())) { + assert(false); + return; + } + extern __shared__ char smem[]; // Wrap operator back into lambdas. This is temporary and should be removed. (TODO) @@ -103,7 +112,8 @@ template + typename FinOpT, + typename SM_compat_t> void pairwise_matrix(OpT distance_op, FinOpT fin_op, const DataT* x, @@ -117,18 +127,19 @@ void pairwise_matrix(OpT distance_op, IdxT ldb, IdxT ldd, OutT* dOutput, - cudaStream_t stream) + cudaStream_t stream, + SM_compat_t sm_compat_range) { dim3 blk(Policy::Nthreads); // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) size_t smem_size = distance_op.template shared_mem_size(); // Obtain function pointer to kernel - auto kernel = pairwise_matrix_kernel; + auto kernel = pairwise_matrix_kernel; dim3 grid = launchConfigGenerator(m, n, smem_size, kernel); kernel<<>>( - x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op); + x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op, sm_compat_range); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh new file mode 100644 index 0000000000..554805da9e --- /dev/null +++ b/cpp/include/raft/util/arch.cuh @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace raft::arch { + +/* raft::arch provides the following facilities: + * + * - raft::arch::SM_XX : hardcoded compile-time constants for various compute + * architectures. The values raft::arch::SM_min and raft::arch::SM_future + * represent architectures that are always smaller and larger (respectively) + * than any architecture that can be encountered in practice. + * + * - raft::arch::SM_compute_arch : a compile-time value for the *current* + * compute architecture that a kernel is compiled with. It can only be used + * inside kernels with a template argument. + * + * - raft::arch::kernel_runtime_arch : a function that computes at *run-time* + * which version of a kernel will launch (i.e., it will return the compute + * architecture of the version of the kernel that will be launched by the + * driver). + * + * - raft::arch::SM_range : a compile-time value to represent an open interval + * of compute architectures. This can be used to check if the current + * compile-time architecture is in a specified compatibility range. + */ + +// inner::SM_generic is a template to create a generic compile-time SM +// architecture constant. +namespace inner { +template +struct SM_generic { +public: + __host__ __device__ constexpr int value() const { + return n; + } +}; + +// A +__global__ inline void dummy_runtime_kernel() {} +} + +// A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) +// and SM_MIN and SM_FUTURE, that allow specifying an open interval of +// compatible compute architectures. +using SM_min = inner::SM_generic<350>; +using SM_60 = inner::SM_generic<600>; +using SM_70 = inner::SM_generic<700>; +using SM_75 = inner::SM_generic<750>; +using SM_80 = inner::SM_generic<800>; +using SM_86 = inner::SM_generic<860>; +using SM_90 = inner::SM_generic<900>; +using SM_future = inner::SM_generic<99999>; + +// This is a type that uses the __CUDA_ARCH__ macro to obtain the compile-time +// compute architecture. It can only be used where __CUDA_ARCH__ is defined, +// i.e., inside a __global__ function template. +struct SM_compute_arch { + template + __host__ __device__ constexpr int value() const { +#ifdef __CUDA_ARCH__ + return __CUDA_ARCH__; +#else + static_assert(dummy != 0, + "SM_compute_arch.value() is only callable from a __global__ function template. " + "A way to create a function template is by adding 'template '."); + return -1; +#endif + } +}; + +// A runtime value for the actual compute architecture of a kernel. +// +// A single kernel can be compiled for several "virtual" compute architectures. +// When a program runs, the driver picks the version of the kernel that most +// closely matches the current hardware. This struct reflects the virtual +// compute architecture of the version of the kernel that the driver picks when +// the kernel runs. +struct SM_runtime { + friend SM_runtime kernel_runtime_arch(); +private: + const int _version; + SM_runtime(int version) + : _version (version) {} + +public: + __host__ __device__ int value() const { + return _version; + } +}; + +// Computes which compute architecture of a kernel will run +// +// Semantics are described above in the documentation of SM_runtime. +SM_runtime kernel_runtime_arch() { + auto kernel = inner::dummy_runtime_kernel; + cudaFuncAttributes attributes; + cudaFuncGetAttributes(&attributes, kernel); + + return SM_runtime(10 * attributes.ptxVersion); +} + +// SM_range represents a range of SM architectures. It can be used to +// conditionally compile a kernel. +template +struct SM_range { +private: + const SM_MIN _min; + const SM_MAX _max; +public: + __host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) + : _min(min), _max(max) {} + + template + __host__ __device__ constexpr bool contains(SM_t current) const { + return _min.value() <= current.value() && current.value() < _max.value(); + } +}; + +} // namespace raft::arch From 72628613991f55c08e8e70e96c82ed050d84936b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 22 Feb 2023 16:39:09 +0100 Subject: [PATCH 57/93] Fix style --- cpp/include/raft/distance/detail/distance.cuh | 39 +++++++++++--- .../detail/pairwise_matrix/dispatch.cuh | 47 ++++++++++------ .../detail/pairwise_matrix/kernel_sm60.cuh | 43 +++++++++------ cpp/include/raft/util/arch.cuh | 53 +++++++++---------- 4 files changed, 114 insertions(+), 68 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index da119b6a45..7ebc7b3414 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -45,8 +45,8 @@ #include #include -#include #include +#include #include namespace raft { @@ -262,9 +262,9 @@ void distance_impl(raft::resources const& handle, distance_matrix_dispatch( distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - auto runtime_arch = raft::arch::kernel_runtime_arch(); + auto runtime_arch = raft::arch::kernel_runtime_arch(); auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. @@ -276,8 +276,25 @@ void distance_impl(raft::resources const& handle, } else { // Else use "legacy" L2 ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); + distance_matrix_dispatch(distance_op, + m, + n, + k, + x, + y, + norm_A, + norm_B, + out, + fin_op, + stream, + is_row_major, + legacy_range); } } } @@ -531,9 +548,9 @@ void distance_impl_l2_expanded( // NOTE: different name distance_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - auto runtime_arch = raft::arch::kernel_runtime_arch(); + auto runtime_arch = raft::arch::kernel_runtime_arch(); auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. @@ -546,7 +563,13 @@ void distance_impl_l2_expanded( // NOTE: different name // Else use "legacy" L2. Compile *only* for architectures in the legacy // range. For newer architectures, compile empty kernels. ops::l2_exp_distance_op l2_op(perform_sqrt); - distance_matrix_dispatch( + distance_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 75680027d8..9bbeca1e90 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -17,9 +17,9 @@ #include "kernel_sm60.cuh" #include -#include #include #include +#include #include namespace raft::distance::detail { @@ -90,21 +90,22 @@ template > -void distance_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future())) +void distance_matrix_dispatch( + OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major, + SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future())) { // Determine leading dimensions and, if column-major, flip order of passing x // and y. @@ -148,7 +149,21 @@ void distance_matrix_dispatch(OpT distance_op, typedef typename std::conditional::type Policy; return pairwise_matrix( - distance_op, fin_op, x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, stream, sm_compat_range); + distance_op, + fin_op, + x, + y, + x_norm, + y_norm, + m, + n, + k, + ldx, + ldy, + ld_out, + out, + stream, + sm_compat_range); }); } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 7404c47e66..3f6474deeb 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -16,9 +16,9 @@ #pragma once #include -#include #include #include +#include namespace raft::distance::detail { @@ -31,23 +31,24 @@ template -__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - opT distance_op, - FinOpT fin_op, - SM_compat_t sm_compat_range) +__global__ __launch_bounds__(Policy::Nthreads, + 2) void pairwise_matrix_kernel(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + opT distance_op, + FinOpT fin_op, + SM_compat_t sm_compat_range) { // Early exit to minimize the size of the kernel when it is not supposed to be compiled. - if constexpr(! sm_compat_range.contains(raft::arch::SM_compute_arch())) { + if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) { assert(false); return; } @@ -135,7 +136,15 @@ void pairwise_matrix(OpT distance_op, // https://en.cppreference.com/w/cpp/language/dependent_name) size_t smem_size = distance_op.template shared_mem_size(); // Obtain function pointer to kernel - auto kernel = pairwise_matrix_kernel; + auto kernel = pairwise_matrix_kernel; dim3 grid = launchConfigGenerator(m, n, smem_size, kernel); kernel<<>>( diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 554805da9e..5103c2c591 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -43,26 +43,24 @@ namespace raft::arch { namespace inner { template struct SM_generic { -public: - __host__ __device__ constexpr int value() const { - return n; - } + public: + __host__ __device__ constexpr int value() const { return n; } }; // A __global__ inline void dummy_runtime_kernel() {} -} +} // namespace inner // A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) // and SM_MIN and SM_FUTURE, that allow specifying an open interval of // compatible compute architectures. -using SM_min = inner::SM_generic<350>; -using SM_60 = inner::SM_generic<600>; -using SM_70 = inner::SM_generic<700>; -using SM_75 = inner::SM_generic<750>; -using SM_80 = inner::SM_generic<800>; -using SM_86 = inner::SM_generic<860>; -using SM_90 = inner::SM_generic<900>; +using SM_min = inner::SM_generic<350>; +using SM_60 = inner::SM_generic<600>; +using SM_70 = inner::SM_generic<700>; +using SM_75 = inner::SM_generic<750>; +using SM_80 = inner::SM_generic<800>; +using SM_86 = inner::SM_generic<860>; +using SM_90 = inner::SM_generic<900>; using SM_future = inner::SM_generic<99999>; // This is a type that uses the __CUDA_ARCH__ macro to obtain the compile-time @@ -70,7 +68,8 @@ using SM_future = inner::SM_generic<99999>; // i.e., inside a __global__ function template. struct SM_compute_arch { template - __host__ __device__ constexpr int value() const { + __host__ __device__ constexpr int value() const + { #ifdef __CUDA_ARCH__ return __CUDA_ARCH__; #else @@ -91,21 +90,20 @@ struct SM_compute_arch { // the kernel runs. struct SM_runtime { friend SM_runtime kernel_runtime_arch(); -private: + + private: const int _version; - SM_runtime(int version) - : _version (version) {} + SM_runtime(int version) : _version(version) {} -public: - __host__ __device__ int value() const { - return _version; - } + public: + __host__ __device__ int value() const { return _version; } }; // Computes which compute architecture of a kernel will run // // Semantics are described above in the documentation of SM_runtime. -SM_runtime kernel_runtime_arch() { +SM_runtime kernel_runtime_arch() +{ auto kernel = inner::dummy_runtime_kernel; cudaFuncAttributes attributes; cudaFuncGetAttributes(&attributes, kernel); @@ -117,17 +115,18 @@ SM_runtime kernel_runtime_arch() { // conditionally compile a kernel. template struct SM_range { -private: + private: const SM_MIN _min; const SM_MAX _max; -public: - __host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) - : _min(min), _max(max) {} + + public: + __host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) : _min(min), _max(max) {} template - __host__ __device__ constexpr bool contains(SM_t current) const { + __host__ __device__ constexpr bool contains(SM_t current) const + { return _min.value() <= current.value() && current.value() < _max.value(); } }; -} // namespace raft::arch +} // namespace raft::arch From 09a30501c00b36af296fe6513135c5b8b95a69a6 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 6 Mar 2023 16:55:28 +0100 Subject: [PATCH 58/93] Fix linker error: multiple definition.. --- cpp/include/raft/util/arch.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 5103c2c591..ef703a8486 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -47,7 +47,7 @@ struct SM_generic { __host__ __device__ constexpr int value() const { return n; } }; -// A +// A dummy kernel that is used to determine the runtime architecture. __global__ inline void dummy_runtime_kernel() {} } // namespace inner @@ -102,7 +102,7 @@ struct SM_runtime { // Computes which compute architecture of a kernel will run // // Semantics are described above in the documentation of SM_runtime. -SM_runtime kernel_runtime_arch() +inline SM_runtime kernel_runtime_arch() { auto kernel = inner::dummy_runtime_kernel; cudaFuncAttributes attributes; From 646722114f872d13f5c0fcdc3d911e3f0abbad66 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 6 Mar 2023 17:59:04 +0100 Subject: [PATCH 59/93] Update cpp/include/raft/distance/detail/distance_ops/canberra.cuh Co-authored-by: Tamas Bela Feher --- cpp/include/raft/distance/detail/distance_ops/canberra.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 5ddf02e705..45bea08a95 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -48,7 +48,7 @@ struct canberra_distance_op { const auto diff = raft::abs(x - y); const auto add = raft::abs(x) + raft::abs(y); // deal with potential for 0 in denominator by - // forcing 1/0 instead + // forcing 0/1 instead acc += ((add != 0) * diff / (add + (add == 0))); }; From a83461e816417da12e313018e9b5a08437207044 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 6 Mar 2023 18:04:39 +0100 Subject: [PATCH 60/93] Update cpp/include/raft/distance/detail/distance.cuh Co-authored-by: Tamas Bela Feher --- cpp/include/raft/distance/detail/distance.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 621e2d15b9..7887eb96be 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -270,7 +270,7 @@ void distance_impl(raft::resources const& handle, distance_matrix_cutlass_dispatch( distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } else { - // Else use "legacy" L2 + // Else use "legacy" cosine kernel ops::cosine_distance_op distance_op{}; distance_matrix_dispatch( distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); From 393edf337c43594ef233799ae63e73ccc3fd3451 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 6 Mar 2023 18:25:14 +0100 Subject: [PATCH 61/93] Add note about alignment in case of byte input --- .../raft/distance/detail/pairwise_matrix/dispatch.cuh | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 23d0f34489..c95241cd0d 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -127,7 +127,15 @@ void distance_matrix_dispatch(OpT distance_op, // Compute number of elements that can be loaded in one instruction // without causing misalignent errors. - int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; + int vec_len_aligned; + if (byte_alignment % sizeof(DataT) == 0) { + // In the future, we might support `int8_t` input. In that case, + // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to + // prevent adding more cases in dispatch (which are expensive to compile). + vec_len_aligned = min(4, byte_alignment / sizeof(DataT)); + } else { + vec_len_aligned = 1; + } dispatch(is_row_major, vec_len_aligned, [&](auto row_major, auto vec_len_aligned) { // row_major and vec_len are std::integral_constants of type bool and int From 48a0c21ce78454b1cb10d7875bcf3985c27df3f7 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 7 Mar 2023 09:46:03 +0100 Subject: [PATCH 62/93] Fix --- cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index c95241cd0d..9def354600 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -16,6 +16,7 @@ #pragma once #include "kernel_sm60.cuh" +#include #include #include #include @@ -132,7 +133,7 @@ void distance_matrix_dispatch(OpT distance_op, // In the future, we might support `int8_t` input. In that case, // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to // prevent adding more cases in dispatch (which are expensive to compile). - vec_len_aligned = min(4, byte_alignment / sizeof(DataT)); + vec_len_aligned = std::min(4, int(byte_alignment / sizeof(DataT))); } else { vec_len_aligned = 1; } From 1a6636f994d3e6e48cd955ecd9ab59c159a54a1c Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 7 Mar 2023 14:20:43 +0100 Subject: [PATCH 63/93] Implement review feedback --- cpp/include/raft/util/arch.cuh | 35 +++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index ef703a8486..dfa29334f5 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -38,9 +38,9 @@ namespace raft::arch { * compile-time architecture is in a specified compatibility range. */ -// inner::SM_generic is a template to create a generic compile-time SM +// detail::SM_generic is a template to create a generic compile-time SM // architecture constant. -namespace inner { +namespace detail { template struct SM_generic { public: @@ -49,30 +49,39 @@ struct SM_generic { // A dummy kernel that is used to determine the runtime architecture. __global__ inline void dummy_runtime_kernel() {} -} // namespace inner +} // namespace detail // A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) // and SM_MIN and SM_FUTURE, that allow specifying an open interval of // compatible compute architectures. -using SM_min = inner::SM_generic<350>; -using SM_60 = inner::SM_generic<600>; -using SM_70 = inner::SM_generic<700>; -using SM_75 = inner::SM_generic<750>; -using SM_80 = inner::SM_generic<800>; -using SM_86 = inner::SM_generic<860>; -using SM_90 = inner::SM_generic<900>; -using SM_future = inner::SM_generic<99999>; +using SM_min = detail::SM_generic<350>; +using SM_60 = detail::SM_generic<600>; +using SM_70 = detail::SM_generic<700>; +using SM_75 = detail::SM_generic<750>; +using SM_80 = detail::SM_generic<800>; +using SM_86 = detail::SM_generic<860>; +using SM_90 = detail::SM_generic<900>; +using SM_future = detail::SM_generic<99999>; // This is a type that uses the __CUDA_ARCH__ macro to obtain the compile-time // compute architecture. It can only be used where __CUDA_ARCH__ is defined, // i.e., inside a __global__ function template. struct SM_compute_arch { template - __host__ __device__ constexpr int value() const + __device__ constexpr int value() const { #ifdef __CUDA_ARCH__ return __CUDA_ARCH__; #else + // This function should not be called in host code (because __CUDA_ARCH__ is + // not defined). This function is constexpr and thus can be called in host + // code (due to the --expt-relaxed-constexpr compile flag). We would like to + // provide an intelligible error message when this function is called in + // host code, which we do below. + // + // To make sure the static_assert only fires in host code, we use a dummy + // template parameter as described in P2593: + // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html static_assert(dummy != 0, "SM_compute_arch.value() is only callable from a __global__ function template. " "A way to create a function template is by adding 'template '."); @@ -104,7 +113,7 @@ struct SM_runtime { // Semantics are described above in the documentation of SM_runtime. inline SM_runtime kernel_runtime_arch() { - auto kernel = inner::dummy_runtime_kernel; + auto kernel = detail::dummy_runtime_kernel; cudaFuncAttributes attributes; cudaFuncGetAttributes(&attributes, kernel); From 31648028720265e66c5f705137257c870e7faad9 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 13 Mar 2023 18:55:27 +0100 Subject: [PATCH 64/93] Determine runtime arch using kernel pointer Determine runtime architecture using a kernel pointer that will be run. This avoids some problems with using a dummy kernel. It does require some reshuffling of the code to make sure that dispatch continues to work smoothly. --- cpp/include/raft/distance/detail/distance.cuh | 131 ++------- .../distance/detail/distance_ops/all_ops.cuh | 35 +++ .../distance/detail/distance_ops/canberra.cuh | 6 +- .../detail/distance_ops/correlation.cuh | 6 +- .../distance/detail/distance_ops/cosine.cuh | 27 +- .../distance/detail/distance_ops/cutlass.cuh | 40 +++ .../distance/detail/distance_ops/hamming.cuh | 6 +- .../detail/distance_ops/hellinger.cuh | 6 +- .../detail/distance_ops/jensen_shannon.cuh | 6 +- .../detail/distance_ops/kl_divergence.cuh | 6 +- .../raft/distance/detail/distance_ops/l1.cuh | 6 +- .../distance/detail/distance_ops/l2_exp.cuh | 48 ++-- .../distance/detail/distance_ops/l2_unexp.cuh | 6 +- .../distance/detail/distance_ops/l_inf.cuh | 6 +- .../distance/detail/distance_ops/lp_unexp.cuh | 6 +- .../detail/distance_ops/russel_rao.cuh | 6 +- .../distance/detail/distance_ops/template.cuh | 6 +- .../detail/pairwise_distance_cutlass_base.cuh | 36 +-- .../detail/pairwise_matrix/dispatch.cuh | 251 ++++-------------- .../pairwise_matrix/dispatch_layout.cuh | 115 ++++++++ .../detail/pairwise_matrix/dispatch_sm60.cuh | 76 ++++++ .../detail/pairwise_matrix/dispatch_sm80.cuh | 62 +++++ .../detail/pairwise_matrix/kernel_sm60.cuh | 170 +++++++----- .../detail/pairwise_matrix/params.cuh | 61 +++++ cpp/include/raft/util/arch.cuh | 12 +- 25 files changed, 703 insertions(+), 433 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/all_ops.cuh create mode 100644 cpp/include/raft/distance/detail/distance_ops/cutlass.cuh create mode 100644 cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh create mode 100644 cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh create mode 100644 cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh create mode 100644 cpp/include/raft/distance/detail/pairwise_matrix/params.cuh diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index f0c550ed43..f469250b45 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -27,20 +27,7 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - +#include #include #include @@ -127,7 +114,7 @@ void distance_impl(raft::resources const& handle, const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -206,7 +193,7 @@ void distance_impl(raft::resources const& handle, using OpT = ops::correlation_distance_op; OpT corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); - distance_matrix_dispatch( + pairwise_matrix_dispatch( corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); } @@ -249,54 +236,9 @@ void distance_impl(raft::resources const& handle, norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel otherwise. - - if constexpr (__CUDACC_VER_MAJOR__ == 12) { - // Always execute legacy kernels on CUDA 12 - ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } else { - auto runtime_arch = raft::arch::kernel_runtime_arch(); - auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - using Op = ops::cosine_cutlass_op; - Op distance_op{}; - - distance_matrix_cutlass_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } else { - // Else use "legacy" cosine kernel - ops::cosine_distance_op distance_op{}; - distance_matrix_dispatch(distance_op, - m, - n, - k, - x, - y, - norm_A, - norm_B, - out, - fin_op, - stream, - is_row_major, - legacy_range); - } - } + ops::cosine_distance_op distance_op{}; + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } template @@ -321,7 +263,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -383,7 +325,7 @@ void distance_impl(raft::resources const& handle, const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); // Finally revert sqrt of x and y @@ -415,7 +357,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -459,7 +401,7 @@ void distance_impl(raft::resources const& handle, const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; - distance_matrix_dispatch( + pairwise_matrix_dispatch( kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); if (x != y) { @@ -490,7 +432,7 @@ void distance_impl(raft::resources const& handle, const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -535,44 +477,9 @@ void distance_impl_l2_expanded( // NOTE: different name norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel otherwise. - - if constexpr (__CUDACC_VER_MAJOR__ == 12) { - // Always execute legacy kernels on CUDA 12 - ops::l2_exp_distance_op l2_op(perform_sqrt); - distance_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } else { - auto runtime_arch = raft::arch::kernel_runtime_arch(); - auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - using L2Op = ops::l2_exp_cutlass_op; - L2Op l2_op(perform_sqrt); - - distance_matrix_cutlass_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); - } else { - // Else use "legacy" L2. Compile *only* for architectures in the legacy - // range. For newer architectures, compile empty kernels. - ops::l2_exp_distance_op l2_op(perform_sqrt); - distance_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major, legacy_range); - } - } + ops::l2_exp_distance_op distance_op{perform_sqrt}; + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } template @@ -641,7 +548,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } @@ -669,7 +576,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); } @@ -695,7 +602,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -721,7 +628,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } @@ -747,7 +654,7 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_matrix_dispatch( + pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } diff --git a/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh new file mode 100644 index 0000000000..3e8f4e86fb --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// Defines a named requirement "has_cutlass_op" +#include + +// The distance operations: +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 45bea08a95..930294ce31 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k |x_ik - y_kj| / ( |x_ik| + |y_kj| ) */ -template +template struct canberra_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 3832104280..289b69070a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -28,8 +28,12 @@ namespace raft::distance::detail::ops { * / * (|| x - mean(x) ||_2 || y - mean(y) ||_2) */ -template +template struct correlation_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + const DataT* x2n; const DataT* y2n; IdxT m; diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index c3f3b75e62..7c37c27b4e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -20,6 +20,17 @@ namespace raft::distance::detail::ops { +// Epilogue operator for CUTLASS based kernel +template +struct cosine_cutlass_op { + __device__ cosine_cutlass_op() noexcept {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); + } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + /** * @brief the expanded cosine distance matrix calculation * @@ -27,8 +38,12 @@ namespace raft::distance::detail::ops { * * d(x, y) = 1 - (x â‹… y) / ( ||x||_2 ||y||_2) */ -template +template struct cosine_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = true; // Whether the core function requires so many instructions that it makes sense @@ -60,16 +75,8 @@ struct cosine_distance_op { } } } -}; -template -struct cosine_cutlass_op { - __device__ cosine_cutlass_op() noexcept {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); - } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } + cosine_cutlass_op get_cutlass_op() { return cosine_cutlass_op(); } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh new file mode 100644 index 0000000000..d3eb90467b --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance::detail::ops { + +// This file defines the named requirement "has_cutlass_op" that can be used to +// determine if a distance operation has a CUTLASS op that can be used to pass +// to CUTLASS. Examples of distance operations that satisfy this requirement are +// cosine_distance_op and l2_exp_distance_op. + +// Primary template handles types that do not support CUTLASS. +// This pattern is described in: +// https://en.cppreference.com/w/cpp/types/void_t +template +struct has_cutlass_op : std::false_type { +}; + +// Specialization recognizes types that do support CUTLASS +template +struct has_cutlass_op> : std::true_type { +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 98acf11560..1cfdcfdc73 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -26,8 +26,12 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k (x_ik != y_kj) / k */ -template +template struct hamming_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + IdxT k; hamming_distance_op(IdxT k_) noexcept : k(k_) {} diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index c5e2b84ac2..c4aecc7a6f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * c_ij = sqrt(1 - sum_k sqrt(x_ik * y_kj)) * */ -template +template struct hellinger_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index df5aadcf3b..41eeb9dd83 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -29,8 +29,12 @@ namespace raft::distance::detail::ops { * c_ij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) * + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) */ -template +template struct jensen_shannon_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index 526927243f..d046b62c30 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -26,8 +26,12 @@ namespace raft::distance::detail::ops { * * c_ij = 0.5 * sum(x * log (x / y)); */ -template +template struct kl_divergence_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + const bool is_row_major; const bool x_equal_y; diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index b02971bac7..8ec4000827 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -26,8 +26,12 @@ namespace raft::distance::detail::ops { * * c_ij = sum_k abs(x_ik - y_kj) */ -template +template struct l1_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Do not load norms of data, the computation of L1 distance does not use them. static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index fb00f8d66a..2a7af53813 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -20,6 +20,26 @@ namespace raft::distance::detail::ops { +// Epilogue operator for CUTLASS based kernel +template +struct l2_exp_cutlass_op { + bool sqrt; + + __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} + __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + // outVal could be negative due to numerical instability, especially when + // calculating self distance. + // clamp to 0 to avoid potential NaN in sqrt + outVal = outVal * (outVal > DataT(0.0)); + return sqrt ? raft::sqrt(outVal) : outVal; + } + + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + /** * @brief the expanded euclidean distance matrix calculation * @@ -28,8 +48,12 @@ namespace raft::distance::detail::ops { * c_ij = - 2 sum_k x_ik * y_kj + ||x_i.||_2 + ||y_.j||_2 * */ -template +template struct l2_exp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + bool sqrt; l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} @@ -62,6 +86,8 @@ struct l2_exp_distance_op { #pragma unroll for (int j = 0; j < Policy::AccColsPerTh; ++j) { DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + // val could be negative due to numerical instability, especially when + // calculating self distance. Clamp to 0 to avoid potential NaN in sqrt acc[i][j] = val * (val > DataT(0.0)); } } @@ -75,26 +101,8 @@ struct l2_exp_distance_op { } } } -}; - -// Epilogue operator for CUTLASS based kernel -template -struct l2_exp_cutlass_op { - bool sqrt; - - __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} - __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - // outVal could be negative due to numerical instability, especially when - // calculating self distance. - // clamp to 0 to avoid potential NaN in sqrt - outVal = outVal * (outVal > DataT(0.0)); - return sqrt ? raft::sqrt(outVal) : outVal; - } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } + l2_exp_cutlass_op get_cutlass_op() { return l2_exp_cutlass_op(sqrt); } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index e03eb0a97e..f0ea591eaf 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = optional_sqrt ( sum_k (x_ik - y_kj)^2 ) */ -template +template struct l2_unexp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + bool sqrt; l2_unexp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index caa1379133..fb21fb1a21 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = max_k | x_ik - y_kj | */ -template +template struct l_inf_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + // Load norms of input data static constexpr bool use_norms = false; // Whether the core function requires so many instructions that it makes sense diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index a4a090d058..71dfd51a6e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -26,8 +26,12 @@ namespace raft::distance::detail::ops { * * c_ij = (sum_k |x_ik - y_jk|^p)^(1/p) */ -template +template struct lp_unexp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + DataT p; lp_unexp_distance_op(DataT p_) noexcept : p(p_) {} diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index 7acd858e49..ea09e4d1db 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -27,8 +27,12 @@ namespace raft::distance::detail::ops { * * c_ij = (k - (sum_k x_ik * y_kj)) / k */ -template +template struct russel_rao_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + IdxT k; const float one_over_k; diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index b0f40123aa..6998f3cad4 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -24,8 +24,12 @@ namespace raft::distance::detail::ops { // // Fill in the TODO items. -template +template struct template_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + TODO member; template_distance_op(TODO member_) noexcept : member(member_) {} diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index 2ab5c69b0d..c5fdd28117 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -26,6 +26,7 @@ #endif #include +#include #include #include @@ -36,6 +37,8 @@ #include #include +#include + #include "./pairwise_distance_epilogue_elementwise.h" #include "./pairwise_distance_gemm.h" @@ -59,26 +62,29 @@ template -void cutlassDistanceKernel(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - DistanceFn dist_op, - cudaStream_t stream) +typename std::enable_if::value>::type cutlassDistanceKernel( + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + OpT distance_op, + cudaStream_t stream) { static_assert(!(std::is_same::value), "OutType bool is not supported use uint8_t instead"); + auto dist_op = distance_op.get_cutlass_op(); + using DistanceFn = decltype(dist_op); using EpilogueOutputOp = cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise -#include -#include +#include +#include +#include +#include +#include #include #include -#include +#include namespace raft::distance::detail { -/** - * @brief: Computes minimal common alignment of the rows in a 2D array in bytes - * - * The 2D matrix `x` is assumed to be row-major. This function computes the - * minimal alignment in bytes of the first elements of each row. - * Output can be 16, 8, 4, 2, 1. - * - * @param x Base pointer of row-major input matrix - * @param stride Stride in number of element between consecutive rows. - */ -template -size_t alignment_of_2d_array(const DataT* x, size_t stride) -{ - auto base = reinterpret_cast(x); - size_t stride_bytes = sizeof(DataT) * stride; - - for (int align = 16; align >= 0; align /= 2) { - bool base_aligned = base % align == 0; - bool stride_aligned = stride_bytes % align == 0; - if (base_aligned && stride_aligned) { return align; } - } - return 1; -} - -template -using vec_len_constant = std::integral_constant; - -/** - * @brief: Converts run-time arguments to compile-time arguments - * - * Converts run-time arguments row_major and vec_len to compile-time arguments - * and dispatches a lambda f with these compile-time arguments. - * - * This is equivalent to copying and pasting the lambda function `f` in each of - * the switch case statements. - * - * @tparam F Type of lambda f. - * @param row_major Boolean indicating whether input arrays have row-major layout. - * @param vec_len Integer value 1, 2, or 4 specifying the Veclen template parameter of - * the KernelPolicy. - * @param f Lambda that takes two std::integral_constant parameters representing - * row_major and vec_len. - */ -template -void dispatch(bool row_major, int vec_len, F&& f) -{ - if (row_major) { - switch (vec_len) { - case 4: f(std::bool_constant(), vec_len_constant<4>()); break; - case 2: f(std::bool_constant(), vec_len_constant<2>()); break; - default: f(std::bool_constant(), vec_len_constant<1>()); break; - } - } else { - switch (vec_len) { - case 4: f(std::bool_constant(), vec_len_constant<4>()); break; - case 2: f(std::bool_constant(), vec_len_constant<2>()); break; - default: f(std::bool_constant(), vec_len_constant<1>()); break; - } - } -} - template > -void distance_matrix_dispatch( - OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major, - SM_compat_t sm_compat_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future())) + typename IdxT = int> +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) { - // Determine leading dimensions and, if column-major, flip order of passing x - // and y. - IdxT ldx, ldy, ld_out; + // Create kernel parameter struct + pairwise_matrix_params params; if (is_row_major) { - ldx = k, ldy = k, ld_out = n; - } else { - // Flip x, y, and m, n. - std::swap(x, y); - std::swap(x_norm, y_norm); - std::swap(m, n); - ldx = m, ldy = n, ld_out = n; - } - - size_t align_x = alignment_of_2d_array(x, ldx); - size_t align_y = alignment_of_2d_array(y, ldy); - size_t byte_alignment = min(align_x, align_y); - - // Since alignment is in bytes, it could be smaller than sizeof(DataT). - // Handle this (unlikely) case here. - RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, - "Input matrix must be aligned to size of elements."); - - // Compute number of elements that can be loaded in one instruction - // without causing misalignent errors. - int vec_len_aligned; - if (byte_alignment % sizeof(DataT) == 0) { - // In the future, we might support `int8_t` input. In that case, - // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to - // prevent adding more cases in dispatch (which are expensive to compile). - vec_len_aligned = std::min(4, int(byte_alignment / sizeof(DataT))); + params = make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, is_row_major); } else { - vec_len_aligned = 1; + // Flip x and y + params = make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, is_row_major); } - dispatch(is_row_major, vec_len_aligned, [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // To keep compile times in check, we only specialize on veclen > 1 when - // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel below SM_80 - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; + constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; + constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); - return pairwise_matrix( - distance_op, - fin_op, - x, - y, - x_norm, - y_norm, - m, - n, - k, - ldx, - ldy, - ld_out, - out, - stream, - sm_compat_range); - }); -} - -template -void distance_matrix_cutlass_dispatch(opT cutlass_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Determine leading dimensions and possibly flip order of passing x and y if - // column_major. - IdxT ldx, ldy, ld_out; - if (is_row_major) { - ldx = k, ldy = k, ld_out = n; + if constexpr (is_ctk_12 || cutlass_op_unavailable) { + // Always execute legacy kernels on CUDA 12 + auto any_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()); + pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); } else { - std::swap(x, y); - std::swap(x_norm, y_norm); - std::swap(m, n); - ldx = m, ldy = n, ld_out = n; + auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); + auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + + // Get pointer to SM60 kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); + void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); + auto runtime_arch = raft::arch::kernel_runtime_arch(kernel_ptr); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); + } else { + // Reuse kernel wrapper that we obtained above. This avoids performing the + // dispatch twice. + sm60_wrapper.launch(distance_op, params, stream); + } } - - size_t align_x = alignment_of_2d_array(x, ldx); - size_t align_y = alignment_of_2d_array(y, ldy); - size_t byte_alignment = min(align_x, align_y); - - // Since alignment is in bytes, it could be smaller than sizeof(DataT). - // Handle this (unlikely) case here. - RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, - "Input matrix must be aligned to size of elements."); - - // Compute number of elements that can be loaded in one instruction - // without causing misalignent errors. - int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; - - dispatch(is_row_major, vec_len_aligned, [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); - - cutlassDistanceKernel( - x, y, x_norm, y_norm, m, n, k, ldx, ldy, ld_out, out, fin_op, cutlass_op, stream); - }); } }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh new file mode 100644 index 0000000000..c1e4c08af4 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "kernel_sm60.cuh" +#include +#include + +namespace raft::distance::detail { + +/** + * @brief: Computes minimal common alignment of the rows in a 2D array in bytes + * + * The 2D matrix `x` is assumed to be row-major. This function computes the + * minimal alignment in bytes of the first elements of each row. + * Output can be 16, 8, 4, 2, 1. + * + * @param x Base pointer of row-major input matrix + * @param stride Stride in number of element between consecutive rows. + */ +template +size_t alignment_of_2d_array(const DataT* x, size_t stride) +{ + auto base = reinterpret_cast(x); + size_t stride_bytes = sizeof(DataT) * stride; + + for (int align = 16; align >= 0; align /= 2) { + bool base_aligned = base % align == 0; + bool stride_aligned = stride_bytes % align == 0; + if (base_aligned && stride_aligned) { return align; } + } + return 1; +} + +/** + * @brief: Computes the vec_len parameter kernel policy parameter + * + * @param params Kernel parameters + */ +template +int determine_vec_len(pairwise_matrix_params params) +{ + size_t align_x = alignment_of_2d_array(params.x, params.ldx); + size_t align_y = alignment_of_2d_array(params.y, params.ldy); + size_t byte_alignment = min(align_x, align_y); + + // Since alignment is in bytes, it could be smaller than sizeof(DataT). + // Handle this (unlikely) case here. + RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, + "Input matrix must be aligned to size of elements."); + + // Compute number of elements that can be loaded in one instruction + // without causing misalignent errors. + int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; + + // In the future, pairwise_matrix might support `int8_t` input. In that case, + // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to + // prevent adding more cases in dispatch_layout below (which are expensive to + // compile). + vec_len_aligned = std::min(vec_len_aligned, 4); + + return vec_len_aligned; +} + +template +using vec_len_constant = std::integral_constant; + +/** + * @brief: Converts run-time arguments to compile-time arguments + * + * Converts run-time arguments row_major and vec_len to compile-time arguments + * and dispatches a lambda f with these compile-time arguments. + * + * This is equivalent to copying and pasting the lambda function `f` in each of + * the switch case statements. + * + * @tparam F Type of lambda f. + * @param row_major Boolean indicating whether input arrays have row-major layout. + * @param vec_len Integer value 1, 2, or 4 specifying the Veclen template parameter of + * the KernelPolicy. + * @param f Lambda that takes two std::integral_constant parameters representing + * row_major and vec_len. + */ +template +auto dispatch_layout(bool row_major, int vec_len, F&& f) +{ + if (row_major) { + switch (vec_len) { + case 4: return f(std::bool_constant(), vec_len_constant<4>()); + case 2: return f(std::bool_constant(), vec_len_constant<2>()); + default: return f(std::bool_constant(), vec_len_constant<1>()); + } + } else { + switch (vec_len) { + case 4: return f(std::bool_constant(), vec_len_constant<4>()); + case 2: return f(std::bool_constant(), vec_len_constant<2>()); + default: return f(std::bool_constant(), vec_len_constant<1>()); + } + } +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh new file mode 100644 index 0000000000..6e284007ea --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::distance::detail { + +template +pairwise_matrix_sm60_wrapper pairwise_matrix_sm60_get_wrapper( + OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range) +{ + int vec_len = determine_vec_len(params); + + return dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // row_major and vec_len are std::integral_constants of type bool and int + // respectively. + + // To keep compile times in check, we only specialize on veclen > 1 when + // the inner loop is relatively cheap (< 5 flops). + constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); + + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); + + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + typedef typename std::conditional::type Policy; + + auto wrapper = + make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); + + return wrapper; + }); +} + +template +void pairwise_matrix_sm60_dispatch(OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range, + cudaStream_t stream) +{ + auto wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, sm_compat_range); + + wrapper.launch(distance_op, params, stream); +} + +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh new file mode 100644 index 0000000000..ec2d522c25 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // std::min +#include +#include + +namespace raft::distance::detail { + +template +void pairwise_matrix_sm80_dispatch(OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range, + cudaStream_t stream) +{ + int vec_len = determine_vec_len(params); + + dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // row_major and vec_len are std::integral_constants of type bool and int + // respectively. + + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); + + using AccT = typename OpT::AccT; + cutlassDistanceKernel(params.x, + params.y, + params.x_norm, + params.y_norm, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.out, + params.fin_op, + distance_op, + stream); + }); +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index adf1efc65c..6e3ab7b26b 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -18,36 +18,24 @@ #include #include #include +#include #include namespace raft::distance::detail { template -__global__ __launch_bounds__(Policy::Nthreads, - 2) void pairwise_matrix_kernel(const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - opT distance_op, - FinOpT fin_op, - SM_compat_t sm_compat_range) + typename FinOpT> +__global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( + OpT distance_op, pairwise_matrix_params params) { // Early exit to minimize the size of the kernel when it is not supposed to be compiled. + constexpr SM_compat_t sm_compat_range{}; if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) { assert(false); return; @@ -55,6 +43,8 @@ __global__ __launch_bounds__(Policy::Nthreads, extern __shared__ char smem[]; + using AccT = typename OpT::AccT; + // Wrap operator back into lambdas. This is temporary and should be removed. // See: https://github.com/rapidsai/raft/issues/1323 auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { @@ -84,73 +74,123 @@ __global__ __launch_bounds__(Policy::Nthreads, Policy, decltype(core_op), decltype(epilog_op), - decltype(fin_op), + decltype(params.fin_op), decltype(row_epilog_op), row_major, write_out> - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - _xn, - _yn, - dOutput, + obj(params.x, + params.y, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.x_norm, + params.y_norm, + params.out, smem, core_op, epilog_op, - fin_op, + params.fin_op, row_epilog_op); obj.run(); } template + typename FinOpT> void pairwise_matrix(OpT distance_op, - FinOpT fin_op, - const DataT* x, - const DataT* y, - const DataT* _xn, - const DataT* _yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - cudaStream_t stream, - SM_compat_t sm_compat_range) + pairwise_matrix_params params, + cudaStream_t stream) { dim3 blk(Policy::Nthreads); // Use .template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) size_t smem_size = distance_op.template shared_mem_size(); // Obtain function pointer to kernel - auto kernel = pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(m, n, smem_size, kernel); - - kernel<<>>( - x, y, _xn, _yn, m, n, k, lda, ldb, ldd, dOutput, distance_op, fin_op, sm_compat_range); + auto kernel = + pairwise_matrix_kernel; + dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); + + kernel<<>>(distance_op, params); RAFT_CUDA_TRY(cudaGetLastError()); } +// The type of a pointer to the pairwise matrix kernel. The following template +// arguments are type-erased: +// +// - The kernel policy +// - row_major +// - SM_compat_t +template +using pairwise_matrix_kernel_t = void (*)(OpT, pairwise_matrix_params); + +// A wrapper for the pairwise matrix kernel launch. Includes kernel launch +// parameters. +template +struct pairwise_matrix_sm60_wrapper { + dim3 grid; + dim3 block; + int smem_size; + pairwise_matrix_kernel_t kernel_ptr; + + void launch(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) + { + kernel_ptr<<>>(distance_op, params); + RAFT_CUDA_TRY(cudaGetLastError()); + } +}; + +/** @brief: Create kernel launch wrapper for pairwise matrix kernel + * + * This can be used to type-erase the kernel execution policy, row_major, and SM + * compatibility range. + * + * @tparam Policy: Kernel execution policy + * @tparam row_major: Indicates whether input matrices are row major + * @tparam OpT: Type of distance operation + * @tparam IdxT: Index type + * @tparam DataT: Data type + * @tparam OutT: Output data type + * @tparam FinOpT: Final operation type + * @tparam SM_compat_t: Type of the SM architecture compatibility + * + * @param distance_op: Distance operation + * @param params: Parameters + * @param sm_compat_range: Which SM architectures to compile for. + */ +template +pairwise_matrix_sm60_wrapper make_pairwise_matrix_sm60_wrapper( + OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range) +{ + dim3 block(Policy::Nthreads); + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_size = distance_op.template shared_mem_size(); + // Obtain function pointer to kernel + auto kernel = + pairwise_matrix_kernel; + dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); + + return pairwise_matrix_sm60_wrapper{ + grid, block, smem_size, kernel}; +} + }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh new file mode 100644 index 0000000000..d7fc2b28c3 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace raft::distance::detail { + +template +struct pairwise_matrix_params { + IdxT m; + IdxT n; + IdxT k; + IdxT ldx; + IdxT ldy; + IdxT ld_out; + const DataT* x; + const DataT* y; + const DataT* x_norm; + const DataT* y_norm; + OutT* out; + FinOpT fin_op; + bool is_row_major; +}; + +template +pairwise_matrix_params make_params(IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + bool is_row_major) +{ + // Determine leading dimensions. + IdxT ldx, ldy, ld_out; + if (is_row_major) { + ldx = k, ldy = k, ld_out = n; + } else { + ldx = m, ldy = n, ld_out = m; + } + + return pairwise_matrix_params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; +} + +} // namespace raft::distance::detail diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index dfa29334f5..13294196a7 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -98,7 +98,7 @@ struct SM_compute_arch { // compute architecture of the version of the kernel that the driver picks when // the kernel runs. struct SM_runtime { - friend SM_runtime kernel_runtime_arch(); + friend SM_runtime kernel_runtime_arch(void*); private: const int _version; @@ -111,9 +111,14 @@ struct SM_runtime { // Computes which compute architecture of a kernel will run // // Semantics are described above in the documentation of SM_runtime. -inline SM_runtime kernel_runtime_arch() +// +// This function requires a pointer to the kernel that will run. Other methods +// to determine the architecture (that do not require a pointer) can be error +// prone. See: +// // https://github.com/NVIDIA/cub/issues/545 +inline SM_runtime kernel_runtime_arch(void* kernel) { - auto kernel = detail::dummy_runtime_kernel; + // TODO: consider error handling... cudaFuncAttributes attributes; cudaFuncGetAttributes(&attributes, kernel); @@ -130,6 +135,7 @@ struct SM_range { public: __host__ __device__ constexpr SM_range(SM_MIN min, SM_MAX max) : _min(min), _max(max) {} + __host__ __device__ constexpr SM_range() : _min(SM_MIN()), _max(SM_MAX()) {} template __host__ __device__ constexpr bool contains(SM_t current) const From 35014e659a5d9037c7b89b9a1572e58f11812d60 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 08:57:10 +0100 Subject: [PATCH 65/93] Fix Gram compilation error --- .../distance/detail/pairwise_matrix/dispatch.cuh | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 78d5a92cc7..651490b6be 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -45,14 +45,10 @@ void pairwise_matrix_dispatch(OpT distance_op, cudaStream_t stream, bool is_row_major) { - // Create kernel parameter struct - pairwise_matrix_params params; - if (is_row_major) { - params = make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, is_row_major); - } else { - // Flip x and y - params = make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, is_row_major); - } + // Create kernel parameter struct. Flip x and y if column major. + pairwise_matrix_params params = + is_row_major ? make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, is_row_major) + : make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, is_row_major); // On CUDA 12: // - always execute normal kernel From 15161980ae67fa2c47808a8b5092c17ef5c2ebd7 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 11:57:51 +0100 Subject: [PATCH 66/93] Reformat comments --- cpp/include/raft/util/arch.cuh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 13294196a7..8c48b87269 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -115,10 +115,9 @@ struct SM_runtime { // This function requires a pointer to the kernel that will run. Other methods // to determine the architecture (that do not require a pointer) can be error // prone. See: -// // https://github.com/NVIDIA/cub/issues/545 +// https://github.com/NVIDIA/cub/issues/545 inline SM_runtime kernel_runtime_arch(void* kernel) { - // TODO: consider error handling... cudaFuncAttributes attributes; cudaFuncGetAttributes(&attributes, kernel); From e399afaecf789d707d0e5fa25c08ea8eb799c0df Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 13:21:04 +0100 Subject: [PATCH 67/93] Fix kl_divergence index type --- cpp/include/raft/distance/detail/distance.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index f469250b45..3abbf05cfe 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -391,7 +391,7 @@ void distance_impl(raft::resources const& handle, // This op takes some shortcuts when x equals y. So its behavior changes based // on this. - ops::kl_divergence_op kl_divergence{is_row_major, x == y}; + ops::kl_divergence_op kl_divergence{is_row_major, x == y}; if (x != y) { raft::linalg::unaryOp( From f738d0dda4af56bef5e7a80b24469e2622770a25 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 14:17:44 +0100 Subject: [PATCH 68/93] Remove spurious includes from pairwise_distance_base --- cpp/include/raft/distance/detail/pairwise_distance_base.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 0293f10c29..d5779b2eaf 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -16,10 +16,8 @@ #pragma once #include #include -#include #include #include -#include #include From fa09bf7d692aaac7e41e60e25faafd3cc827585b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 14:43:14 +0100 Subject: [PATCH 69/93] Instantiate kernel launch code --- cpp/CMakeLists.txt | 4 - .../detail/00_write_template.py | 172 ++++++++++++++++++ .../specializations/detail/canberra.cuh | 73 ++++---- .../specializations/detail/correlation.cuh | 72 ++++---- .../specializations/detail/cosine.cuh | 72 ++++---- .../detail/hamming_unexpanded.cuh | 72 ++++---- .../detail/hellinger_expanded.cuh | 73 ++++---- .../specializations/detail/jensen_shannon.cuh | 73 ++++---- .../specializations/detail/kl_divergence.cuh | 72 ++++---- .../distance/specializations/detail/l1.cuh | 73 ++++---- .../specializations/detail/l2_expanded.cuh | 72 ++++---- .../detail/l2_sqrt_expanded.cuh | 54 ------ .../detail/l2_sqrt_unexpanded.cuh | 54 ------ .../specializations/detail/l2_unexpanded.cuh | 72 ++++---- .../distance/specializations/detail/l_inf.cuh | 73 ++++---- .../specializations/detail/lp_unexpanded.cuh | 72 ++++---- .../specializations/detail/russel_rao.cuh | 73 ++++---- .../distance/specializations/distance.cuh | 2 - .../detail/00_write_template.py | 164 +++++++++++++++++ .../canberra_double_double_double_int.cu | 46 +++-- .../detail/canberra_float_float_float_int.cu | 45 +++-- .../correlation_double_double_double_int.cu | 45 ++--- .../correlation_float_float_float_int.cu | 45 +++-- .../detail/cosine_double_double_double_int.cu | 44 +++-- .../detail/cosine_float_float_float_int.cu | 44 +++-- ...ing_unexpanded_double_double_double_int.cu | 45 ++--- ...amming_unexpanded_float_float_float_int.cu | 45 +++-- ...inger_expanded_double_double_double_int.cu | 45 ++--- ...ellinger_expanded_float_float_float_int.cu | 44 +++-- ...jensen_shannon_double_double_double_int.cu | 45 +++-- .../jensen_shannon_float_float_float_int.cu | 45 +++-- .../kl_divergence_double_double_double_int.cu | 45 +++-- .../kl_divergence_float_float_float_int.cu | 45 +++-- .../detail/l1_double_double_double_int.cu | 45 +++-- .../detail/l1_float_float_float_int.cu | 45 +++-- .../l2_expanded_double_double_double_int.cu | 46 +++-- .../l2_expanded_float_float_float_int.cu | 45 +++-- ..._sqrt_expanded_double_double_double_int.cu | 38 ---- .../l2_sqrt_expanded_float_float_float_int.cu | 38 ---- ...qrt_unexpanded_double_double_double_int.cu | 38 ---- ...2_sqrt_unexpanded_float_float_float_int.cu | 38 ---- .../l2_unexpanded_double_double_double_int.cu | 45 +++-- .../l2_unexpanded_float_float_float_int.cu | 45 +++-- .../detail/l_inf_double_double_double_int.cu | 44 +++-- .../detail/l_inf_float_float_float_int.cu | 45 +++-- .../lp_unexpanded_double_double_double_int.cu | 45 +++-- .../lp_unexpanded_float_float_float_int.cu | 45 +++-- .../russel_rao_double_double_double_int.cu | 46 ++--- .../russel_rao_float_float_float_int.cu | 45 +++-- 49 files changed, 1517 insertions(+), 1196 deletions(-) create mode 100644 cpp/include/raft/distance/specializations/detail/00_write_template.py delete mode 100644 cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh create mode 100644 cpp/src/distance/distance/specializations/detail/00_write_template.py delete mode 100644 cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu delete mode 100644 cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 83390ea881..5253534dd1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -344,10 +344,6 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/distance/specializations/detail/l1_double_double_double_int.cu src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu - src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu - src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu - src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py new file mode 100644 index 0000000000..364d3cb6cb --- /dev/null +++ b/cpp/include/raft/distance/specializations/detail/00_write_template.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +# This template manages all files in this directory, apart from +# inner_product.cuh and kernels.cuh. + + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +start_template = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance::detail { + +""" + +extern_template = """extern template void +pairwise_matrix_dispatch( + OpT, + IdxT, + IdxT, + IdxT, + const DataT*, + const DataT*, + const DataT*, + const DataT*, + OutT*, + FinopT, + cudaStream_t , + bool); +""" + +end_template = """} // namespace raft::distance::detail +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + + + + +op_instances = [ + dict( + path_prefix="canberra", + OpT="ops::canberra_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="correlation", + OpT="ops::correlation_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="cosine", + OpT="ops::cosine_distance_op", + # cosine uses CUTLASS for SM80+ + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="hamming_unexpanded", + OpT="ops::hamming_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="hellinger_expanded", + OpT="ops::hellinger_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="ops::jensen_shannon_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="kl_divergence", + OpT="ops::kl_divergence_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="l1", + OpT="ops::l1_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="l2_expanded", + OpT="ops::l2_exp_distance_op", + # L2 expanded uses CUTLASS for SM80+ + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="l2_unexpanded", + OpT="ops::l2_unexp_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="l_inf", + OpT="ops::l_inf_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="lp_unexpanded", + OpT="ops::lp_unexp_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="russel_rao", + OpT="ops::russel_rao_distance_op", + SM_compat_t="raft::arch::SM_range", + ), +] + +def fill_in(s, template): + for k, v in template.items(): + s = s.replace(k, v) + return s + +for op_instance in op_instances: + path = fill_in("path_prefix.cuh", op_instance) + with open(path, "w") as f: + f.write(start_template) + + for data_type_instance in data_type_instances: + op_data_instance = { + k : fill_in(v, data_type_instance) + for k, v in op_instance.items() + } + instance = { + **op_data_instance, + **data_type_instance, + "FinopT": "decltype(raft::identity_op())", + } + + text = fill_in(extern_template, instance) + + f.write(text) + + f.write(end_template) diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh index badce715a5..6f86b8bce5 100644 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ b/cpp/include/raft/distance/specializations/detail/canberra.cuh @@ -16,37 +16,48 @@ #pragma once -#include #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::canberra_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::canberra_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh index 013a0d43a3..ec1fe25e1d 100644 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ b/cpp/include/raft/distance/specializations/detail/correlation.cuh @@ -18,36 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::correlation_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::correlation_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh index c88bd1b0f6..39326c1f79 100644 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ b/cpp/include/raft/distance/specializations/detail/cosine.cuh @@ -18,36 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::cosine_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::cosine_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh index 3c5cad3315..7f1147c948 100644 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh @@ -18,36 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::hamming_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::hamming_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh index bf214c046f..e7ae32cd16 100644 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh @@ -18,37 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::hellinger_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::hellinger_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index 145834fb70..95158358bb 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -18,37 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::jensen_shannon_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::jensen_shannon_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh index f0928916cd..7dd5898ba7 100644 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh @@ -18,36 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::kl_divergence_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::kl_divergence_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh index 23261a2571..f3378af6a4 100644 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ b/cpp/include/raft/distance/specializations/detail/l1.cuh @@ -18,35 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +extern template void +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::l1_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::l1_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh index f953018b7d..9dd7d3ec6f 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh @@ -18,36 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::l2_exp_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::l2_exp_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh deleted file mode 100644 index 9f5f6a3706..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh deleted file mode 100644 index 94531ddc33..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh index 224b21fce8..d9f69bd426 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh @@ -18,36 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::l2_unexp_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::l2_unexp_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh index 9a46d7b488..e42271908d 100644 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ b/cpp/include/raft/distance/specializations/detail/l_inf.cuh @@ -18,35 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +extern template void +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::l_inf_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::l_inf_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh index e05ef02c42..58599c3a80 100644 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh @@ -18,36 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::lp_unexp_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::lp_unexp_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh index afc87997c0..4228b85ce6 100644 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh @@ -18,37 +18,46 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft +pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::russel_rao_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +extern template void +pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int, + raft::arch::SM_range>( + ops::russel_rao_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index 8daa398b49..a34f696e9e 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -27,8 +27,6 @@ #include #include #include -#include -#include #include #include #include diff --git a/cpp/src/distance/distance/specializations/detail/00_write_template.py b/cpp/src/distance/distance/specializations/detail/00_write_template.py new file mode 100644 index 0000000000..a3ea1b92e3 --- /dev/null +++ b/cpp/src/distance/distance/specializations/detail/00_write_template.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +template = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::identity_op +#include + +#include +#include // raft::arch::SM_compat_range + +namespace raft::distance::detail { + +template void +pairwise_matrix_dispatch( + OpT, + IdxT, + IdxT, + IdxT, + const DataT*, + const DataT*, + const DataT*, + const DataT*, + OutT*, + FinopT, + cudaStream_t , + bool); + + +} // namespace raft::distance::detail +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), + +] + + + + +op_instances = [ + dict( + path_prefix="canberra", + OpT="ops::canberra_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="correlation", + OpT="ops::correlation_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="cosine", + OpT="ops::cosine_distance_op", + # cosine uses CUTLASS for SM80+ + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="hamming_unexpanded", + OpT="ops::hamming_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="hellinger_expanded", + OpT="ops::hellinger_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="ops::jensen_shannon_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="kl_divergence", + OpT="ops::kl_divergence_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="l1", + OpT="ops::l1_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="l2_expanded", + OpT="ops::l2_exp_distance_op", + # L2 expanded uses CUTLASS for SM80+ + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="l2_unexpanded", + OpT="ops::l2_unexp_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="l_inf", + OpT="ops::l_inf_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="lp_unexpanded", + OpT="ops::lp_unexp_distance_op", + SM_compat_t="raft::arch::SM_range", + ), + dict( + path_prefix="russel_rao", + OpT="ops::russel_rao_distance_op", + SM_compat_t="raft::arch::SM_range", + ), +] + +def fill_in(s, template): + for k, v in template.items(): + s = s.replace(k, v) + return s + +for op_instance in op_instances: + for data_type_instance in data_type_instances: + op_data_instance = { + k : fill_in(v, data_type_instance) + for k, v in op_instance.items() + } + instance = { + **op_data_instance, + **data_type_instance, + "FinopT": "decltype(raft::identity_op())", + } + + text = fill_in(template, instance) + + path = fill_in("path_prefix_DataT_AccT_OutT_IdxT.cu", instance) + with open(path, "w") as f: + f.write(text) diff --git a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu index 4e9e608792..e575b6e6f7 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu @@ -14,24 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +#include +#include // raft::arch::SM_compat_range + +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::canberra_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu index 6dfc385e55..c2f94e5a32 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::canberra_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu index 2df77a4b5d..1661c8c968 100644 --- a/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu @@ -14,27 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { +#include +#include // raft::arch::SM_compat_range -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::correlation_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu index 76ed00afa6..672809c681 100644 --- a/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::correlation_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu index 3e0bcb92ed..23180715ed 100644 --- a/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu @@ -14,26 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { +#include +#include // raft::arch::SM_compat_range -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::cosine_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu index 23131ce2c7..609ad0cba9 100644 --- a/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu @@ -14,26 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { +#include +#include // raft::arch::SM_compat_range -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::cosine_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu index b618fd024c..07aeb6d160 100644 --- a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu @@ -14,27 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { +#include +#include // raft::arch::SM_compat_range -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::hamming_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu index 18e7aad9e9..dc8cf7f11d 100644 --- a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::hamming_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu index 08ab20cfe5..88ab818301 100644 --- a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu @@ -14,27 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { +#include +#include // raft::arch::SM_compat_range -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::hellinger_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu index 79eed075fb..ab5682c634 100644 --- a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu @@ -14,26 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { +#include +#include // raft::arch::SM_compat_range -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::hellinger_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index ed84ee6dc4..d19daf2409 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::jensen_shannon_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu index a241af767c..9a2db5fc4b 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::jensen_shannon_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu index c4c944d123..8a95650814 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::kl_divergence_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu index aa1db5a837..379cdd0ab7 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::kl_divergence_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu index 391a1c2aa4..db427c10da 100644 --- a/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::l1_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu index 7b45e52ca1..672f53fa8c 100644 --- a/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::l1_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu index 8c5f746fa2..f1428608ae 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu @@ -14,24 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +#include +#include // raft::arch::SM_compat_range + +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::l2_exp_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu index c266125f98..f8810be55f 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::l2_exp_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu deleted file mode 100644 index 399b120527..0000000000 --- a/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu deleted file mode 100644 index 66de212b8e..0000000000 --- a/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu deleted file mode 100644 index 562d93b2de..0000000000 --- a/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { - -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu deleted file mode 100644 index 386bbafc5f..0000000000 --- a/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu index 7733c3af48..b73104dcad 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::l2_unexp_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu index 4ea18d31de..99412148b8 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::l2_unexp_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu index 74414f8fd6..bbba0ee026 100644 --- a/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu @@ -14,26 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { +#include +#include // raft::arch::SM_compat_range -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::l_inf_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu index e418fc455f..d25774dffa 100644 --- a/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::l_inf_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu index 402cb51b7e..7e7d5e851d 100644 --- a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { +#include +#include // raft::arch::SM_compat_range -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::lp_unexp_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu index 7efe2b3349..401cfaada9 100644 --- a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::lp_unexp_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu index b1e6f5e1f4..4775b8b93e 100644 --- a/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu @@ -14,26 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + double, + double, + double, + decltype(raft::identity_op()), + int>(ops::russel_rao_distance_op, + int, + int, + int, + const double*, + const double*, + const double*, + const double*, + double*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu index 1e12bcd705..dcca59b6ca 100644 --- a/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu @@ -14,25 +14,30 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +#include +#include // raft::arch::SM_compat_range -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_dispatch, + float, + float, + float, + decltype(raft::identity_op()), + int>(ops::russel_rao_distance_op, + int, + int, + int, + const float*, + const float*, + const float*, + const float*, + float*, + decltype(raft::identity_op()), + cudaStream_t, + bool); + +} // namespace raft::distance::detail From cf5b23654d0d96d0bf5591493e3e196cbd95f576 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 15:43:40 +0100 Subject: [PATCH 70/93] Add instantiation point Also: - limit and document includes The following test (add this to cpp/test/CMakeLists.txt): configureTest( NAME pairwise_test PATH test/distance/gram.cu src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu test/distance/dist_adj.cu test/distance/dist_canberra.cu test/distance/dist_correlation.cu test/distance/dist_cos.cu test/distance/dist_hamming.cu test/distance/dist_hellinger.cu test/distance/dist_inner_product.cu test/distance/dist_jensen_shannon.cu test/distance/dist_kl_divergence.cu test/distance/dist_l1.cu test/distance/dist_l2_exp.cu test/distance/dist_l2_unexp.cu test/distance/dist_l2_sqrt_exp.cu test/distance/dist_l_inf.cu test/distance/dist_lp_unexp.cu test/distance/dist_russell_rao.cu src/distance/distance/specializations/detail/canberra_double_double_double_int.cu src/distance/distance/specializations/detail/canberra_float_float_float_int.cu src/distance/distance/specializations/detail/correlation_double_double_double_int.cu src/distance/distance/specializations/detail/correlation_float_float_float_int.cu src/distance/distance/specializations/detail/cosine_double_double_double_int.cu src/distance/distance/specializations/detail/cosine_float_float_float_int.cu src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu src/distance/distance/specializations/detail/l1_float_float_float_int.cu src/distance/distance/specializations/detail/l1_double_double_double_int.cu src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu ) target_compile_definitions(pairwise_test PUBLIC "RAFT_DISTANCE_COMPILED") has the following compile times: pairwise_test 0.5 seconds build.ninja 4.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu.o 18.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu.o 20.3 seconds akeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu.o 21.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_double_double_double_int.cu.o 22.3 seconds akeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu.o 23.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu.o 23.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu.o 24.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu.o 24.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu.o 26.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu.o 26.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu.o 27.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_float_float_float_int.cu.o 28.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu.o 29.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu.o 29.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu.o 30.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu.o 33.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu.o 34.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu.o 34.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu.o 37.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu.o 37.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu.o 38.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu.o 39.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu.o 39.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu.o 40.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu.o 44.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu.o 45.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu.o 46.2 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_canberra.cu.o 48.4 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_hellinger.cu.o 48.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu.o 48.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_russell_rao.cu.o 48.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_inner_product.cu.o 50.2 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_lp_unexp.cu.o 50.5 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_unexp.cu.o 50.8 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l1.cu.o 50.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l_inf.cu.o 52.0 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_jensen_shannon.cu.o 52.0 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_hamming.cu.o 52.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_kl_divergence.cu.o 54.3 seconds CMakeFiles/pairwise_test.dir/test/distance/gram.cu.o 54.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu.o 54.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu.o 59.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu.o 59.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu.o 59.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_sqrt_exp.cu.o 59.8 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_cos.cu.o 60.1 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_exp.cu.o 61.0 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_correlation.cu.o 61.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu.o 68.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu.o 77.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_adj.cu.o 122.4 seconds Generating script: from pathlib import Path from collections import Counter def parse_ninja_log(log_path): text = Path(log_path).read_text() start, end, mtime, path, cmd = list(zip(*[line.split("\t") for line in text.splitlines()[1:]])) start = list(map(int, start)) end = list(map(int, end)) seconds = [(e - s) / 1000. for e, s in zip(end, start)] mtime = list(map(int, mtime)) return dict( start=start, end=end, seconds=seconds, mtime=mtime, path=path, cmd=cmd ) def discard_earlier_builds(d): prev_end = 0 start_index = 0 # end must be monotonically increasing. If we find and end value that is # lower than the end value on the previous row, we know that a new build has # started. for i, end in enumerate(d['end']): if end < prev_end: start_index = i prev_end = end return {k: v[start_index:] for k, v in d.items()} log = discard_earlier_builds(parse_ninja_log("./cpp/build/.ninja_log")) times = dict(zip(log['path'], log['seconds'])) times = sorted(zip(log['path'], log['seconds']), key=lambda x: x[1]) for p, s in times: print(f"{p[-120:]:<120} {s:6.1f} seconds") --- cpp/include/raft/distance/detail/distance.cuh | 107 ++++++++---------- .../detail/pairwise_matrix/dispatch.cuh | 95 +++++++++++----- .../pairwise_matrix/dispatch_layout.cuh | 7 +- .../detail/pairwise_matrix/dispatch_sm60.cuh | 8 +- .../detail/pairwise_matrix/dispatch_sm80.cuh | 6 +- .../detail/00_write_template.py | 29 ++--- .../specializations/detail/canberra.cuh | 55 +++------ .../specializations/detail/correlation.cuh | 55 +++------ .../specializations/detail/cosine.cuh | 54 +++------ .../detail/hamming_unexpanded.cuh | 55 +++------ .../detail/hellinger_expanded.cuh | 55 +++------ .../specializations/detail/jensen_shannon.cuh | 55 +++------ .../specializations/detail/kl_divergence.cuh | 53 +++------ .../distance/specializations/detail/l1.cuh | 53 +++------ .../specializations/detail/l2_expanded.cuh | 54 +++------ .../specializations/detail/l2_unexpanded.cuh | 55 +++------ .../distance/specializations/detail/l_inf.cuh | 54 +++------ .../specializations/detail/lp_unexpanded.cuh | 55 +++------ .../specializations/detail/russel_rao.cuh | 55 +++------ .../detail/00_write_template.py | 81 +++++++------ .../canberra_double_double_double_int.cu | 32 ++---- .../detail/canberra_float_float_float_int.cu | 32 ++---- .../correlation_double_double_double_int.cu | 32 ++---- .../correlation_float_float_float_int.cu | 32 ++---- .../detail/cosine_double_double_double_int.cu | 33 +++--- .../detail/cosine_float_float_float_int.cu | 33 +++--- ...ing_unexpanded_double_double_double_int.cu | 32 ++---- ...amming_unexpanded_float_float_float_int.cu | 32 ++---- ...inger_expanded_double_double_double_int.cu | 32 ++---- ...ellinger_expanded_float_float_float_int.cu | 32 ++---- ...jensen_shannon_double_double_double_int.cu | 33 +++--- .../jensen_shannon_float_float_float_int.cu | 33 +++--- .../kl_divergence_double_double_double_int.cu | 32 ++---- .../kl_divergence_float_float_float_int.cu | 32 ++---- .../detail/l1_double_double_double_int.cu | 32 ++---- .../detail/l1_float_float_float_int.cu | 32 ++---- .../l2_expanded_double_double_double_int.cu | 33 +++--- .../l2_expanded_float_float_float_int.cu | 33 +++--- .../l2_unexpanded_double_double_double_int.cu | 32 ++---- .../l2_unexpanded_float_float_float_int.cu | 32 ++---- .../detail/l_inf_double_double_double_int.cu | 32 ++---- .../detail/l_inf_float_float_float_int.cu | 32 ++---- .../lp_unexpanded_double_double_double_int.cu | 32 ++---- .../lp_unexpanded_float_float_float_int.cu | 32 ++---- .../russel_rao_double_double_double_int.cu | 32 ++---- .../russel_rao_float_float_float_int.cu | 32 ++---- 46 files changed, 709 insertions(+), 1170 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 3abbf05cfe..7493c4e558 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -16,25 +16,18 @@ #pragma once -#include -#include - -#include -#include -#include -#include -#include - #include - +#include #include #include - +#include +#include #include #include -#include -#include -#include +#include +#include +#include +#include namespace raft { namespace distance { @@ -140,14 +133,14 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - AccT* norm_col_vec = workspace; - AccT* norm_row_vec = workspace; - AccT* sq_norm_col_vec = workspace; - AccT* sq_norm_row_vec = workspace; + AccT* x_norm = workspace; + AccT* y_norm = workspace; + AccT* sq_x_norm = workspace; + AccT* sq_y_norm = workspace; if (x != y) { - norm_row_vec += m; + y_norm += m; - raft::linalg::reduce(norm_col_vec, + raft::linalg::reduce(x_norm, x, k, m, @@ -158,7 +151,7 @@ void distance_impl(raft::resources const& handle, false, raft::identity_op(), raft::add_op()); - raft::linalg::reduce(norm_row_vec, + raft::linalg::reduce(y_norm, y, k, n, @@ -170,12 +163,12 @@ void distance_impl(raft::resources const& handle, raft::identity_op(), raft::add_op()); - sq_norm_col_vec += (m + n); - sq_norm_row_vec = sq_norm_col_vec + m; - raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); - raft::linalg::rowNorm(sq_norm_row_vec, y, k, n, raft::linalg::L2Norm, is_row_major, stream); + sq_x_norm += (m + n); + sq_y_norm = sq_x_norm + m; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); } else { - raft::linalg::reduce(norm_col_vec, + raft::linalg::reduce(x_norm, x, k, m, @@ -186,15 +179,15 @@ void distance_impl(raft::resources const& handle, false, raft::identity_op(), raft::add_op()); - sq_norm_col_vec += m; - sq_norm_row_vec = sq_norm_col_vec; - raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + sq_x_norm += m; + sq_y_norm = sq_x_norm; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); } using OpT = ops::correlation_distance_op; - OpT corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); + OpT corr_op(is_row_major, sq_x_norm, sq_y_norm, m, n, k); pairwise_matrix_dispatch( - corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); + corr_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -223,22 +216,22 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - DataT* norm_A = workspace; - DataT* norm_B = workspace; + DataT* x_norm = workspace; + DataT* y_norm = workspace; if (x != y) { - norm_B += m; + y_norm += m; raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); raft::linalg::rowNorm( - norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } else { raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } ops::cosine_distance_op distance_op{}; pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -389,10 +382,6 @@ void distance_impl(raft::resources const& handle, return (!x_zero) * raft::exp(input); }; - // This op takes some shortcuts when x equals y. So its behavior changes based - // on this. - ops::kl_divergence_op kl_divergence{is_row_major, x == y}; - if (x != y) { raft::linalg::unaryOp( (DataT*)y, y, n * k, unaryOp_lambda, stream); @@ -401,8 +390,12 @@ void distance_impl(raft::resources const& handle, const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; - pairwise_matrix_dispatch( - kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + // This op takes some shortcuts when x equals y. So its behavior changes based + // on this. + ops::kl_divergence_op distance_op{is_row_major, x == y}; + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); if (x != y) { // Now reverse previous log (x) back to x using (e ^ log(x)) @@ -464,22 +457,22 @@ void distance_impl_l2_expanded( // NOTE: different name "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); - DataT* norm_A = workspace; - DataT* norm_B = workspace; + DataT* x_norm = workspace; + DataT* y_norm = workspace; if (x != y) { - norm_B += m; + y_norm += m; raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); raft::linalg::rowNorm( - norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } else { raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } ops::l2_exp_distance_op distance_op{perform_sqrt}; pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -543,13 +536,13 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -571,13 +564,13 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 651490b6be..b5bed6e53d 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,41 +15,55 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the normal implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/distance/specializations/detail/. Previously, + * raft::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions + * do not require as large an include files set, which speeds up the build. + */ + +#include // ops::has_cutlass_op +#include // dispatch_sm60 +#include // pairwise_matrix_params + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include the correct +// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh +// and the specializations in src/distance/distance/specializations/detail/. namespace raft::distance::detail { +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance specializations faster. For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - pairwise_matrix_params params = - is_row_major ? make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, is_row_major) - : make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, is_row_major); + typename SM_compat_t> +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params, + SM_compat_t, + cudaStream_t); +template +void pairwise_matrix_instantiation_point(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) +{ // On CUDA 12: // - always execute normal kernel // @@ -87,4 +101,31 @@ void pairwise_matrix_dispatch(OpT distance_op, } } +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Create kernel parameter struct. Flip x and y if column major. + pairwise_matrix_params params = + is_row_major ? make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, is_row_major) + : make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, is_row_major); + + pairwise_matrix_instantiation_point(distance_op, params, stream); +} + }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh index c1e4c08af4..dc58d0e2bf 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -15,9 +15,10 @@ */ #pragma once -#include "kernel_sm60.cuh" -#include -#include +#include // std::min +#include // size_t +#include // pairwise_matrix_params +#include // std::integral_constant namespace raft::distance::detail { diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh index 6e284007ea..cb0fd59da2 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh @@ -15,10 +15,10 @@ */ #pragma once -#include -#include -#include -#include +#include // std::min +#include // dispatch_layout +#include // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 namespace raft::distance::detail { diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh index ec2d522c25..6fafe381e5 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh @@ -15,9 +15,9 @@ */ #pragma once -#include // std::min -#include -#include +#include // std::min +#include // cutlassDistanceKernel +#include // dispatch_layout namespace raft::distance::detail { diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py index 364d3cb6cb..861264e3a0 100644 --- a/cpp/include/raft/distance/specializations/detail/00_write_template.py +++ b/cpp/include/raft/distance/specializations/detail/00_write_template.py @@ -30,26 +30,15 @@ """ -extern_template = """extern template void -pairwise_matrix_dispatch( - OpT, - IdxT, - IdxT, - IdxT, - const DataT*, - const DataT*, - const DataT*, - const DataT*, - OutT*, - FinopT, - cudaStream_t , - bool); +extern_template = """ +extern template void pairwise_matrix_instantiation_point( + OpT, + pairwise_matrix_params, + cudaStream_t); """ end_template = """} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh index 6f86b8bce5..c1eb140a45 100644 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ b/cpp/include/raft/distance/specializations/detail/canberra.cuh @@ -21,43 +21,22 @@ namespace raft::distance::detail { extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::canberra_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::canberra_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh index ec1fe25e1d..2aec977be4 100644 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ b/cpp/include/raft/distance/specializations/detail/correlation.cuh @@ -21,43 +21,22 @@ namespace raft::distance::detail { extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::correlation_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::correlation_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh index 39326c1f79..92317f0de6 100644 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ b/cpp/include/raft/distance/specializations/detail/cosine.cuh @@ -20,44 +20,22 @@ namespace raft::distance::detail { -extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( ops::cosine_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::cosine_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh index 7f1147c948..be06070514 100644 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh @@ -21,43 +21,22 @@ namespace raft::distance::detail { extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::hamming_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::hamming_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh index e7ae32cd16..b7d9dac1a1 100644 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh @@ -21,43 +21,22 @@ namespace raft::distance::detail { extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::hellinger_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::hellinger_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index 95158358bb..b51cc32b62 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -21,43 +21,22 @@ namespace raft::distance::detail { extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::jensen_shannon_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::jensen_shannon_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh index 7dd5898ba7..5e1a125dea 100644 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh @@ -20,44 +20,21 @@ namespace raft::distance::detail { -extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( ops::kl_divergence_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); -extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( ops::kl_divergence_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh index f3378af6a4..c44953bf02 100644 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ b/cpp/include/raft/distance/specializations/detail/l1.cuh @@ -20,44 +20,21 @@ namespace raft::distance::detail { -extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( ops::l1_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); -extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( ops::l1_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh index 9dd7d3ec6f..5e427af021 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh @@ -20,44 +20,22 @@ namespace raft::distance::detail { -extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( ops::l2_exp_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::l2_exp_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh index d9f69bd426..840760c4db 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh @@ -21,43 +21,22 @@ namespace raft::distance::detail { extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::l2_unexp_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::l2_unexp_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh index e42271908d..b10d1b8098 100644 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ b/cpp/include/raft/distance/specializations/detail/l_inf.cuh @@ -20,44 +20,22 @@ namespace raft::distance::detail { -extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( ops::l_inf_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::l_inf_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh index 58599c3a80..e7632ead6c 100644 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh @@ -21,43 +21,22 @@ namespace raft::distance::detail { extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::lp_unexp_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::lp_unexp_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh index 4228b85ce6..0c6f4c993e 100644 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh @@ -21,43 +21,22 @@ namespace raft::distance::detail { extern template void -pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::russel_rao_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); + extern template void -pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int, - raft::arch::SM_range>( - ops::russel_rao_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/00_write_template.py b/cpp/src/distance/distance/specializations/detail/00_write_template.py index a3ea1b92e3..81b8731546 100644 --- a/cpp/src/distance/distance/specializations/detail/00_write_template.py +++ b/cpp/src/distance/distance/specializations/detail/00_write_template.py @@ -19,33 +19,22 @@ */ #include // raft::identity_op -#include +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +INCLUDE_SM_HEADERS #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void -pairwise_matrix_dispatch( - OpT, - IdxT, - IdxT, - IdxT, - const DataT*, - const DataT*, - const DataT*, - const DataT*, - OutT*, - FinopT, - cudaStream_t , - bool); - +template void pairwise_matrix_instantiation_point( + OpT, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail """ @@ -63,81 +52,75 @@ OutT="double", IdxT="int", ), - ] - - - op_instances = [ dict( path_prefix="canberra", OpT="ops::canberra_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="correlation", OpT="ops::correlation_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="cosine", OpT="ops::cosine_distance_op", - # cosine uses CUTLASS for SM80+ - SM_compat_t="raft::arch::SM_range", + archs = [60, 80], ), dict( path_prefix="hamming_unexpanded", OpT="ops::hamming_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="hellinger_expanded", OpT="ops::hellinger_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), # inner product is handled by cublas. dict( path_prefix="jensen_shannon", OpT="ops::jensen_shannon_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="kl_divergence", OpT="ops::kl_divergence_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="l1", OpT="ops::l1_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="l2_expanded", OpT="ops::l2_exp_distance_op", - # L2 expanded uses CUTLASS for SM80+ - SM_compat_t="raft::arch::SM_range", + archs = [60, 80], ), dict( path_prefix="l2_unexpanded", OpT="ops::l2_unexp_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="l_inf", OpT="ops::l_inf_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="lp_unexpanded", OpT="ops::lp_unexp_distance_op", - SM_compat_t="raft::arch::SM_range", + archs = [60], ), dict( path_prefix="russel_rao", OpT="ops::russel_rao_distance_op", - SM_compat_t="raft::arch::SM_range", - ), + archs = [60], + ), ] def fill_in(s, template): @@ -145,7 +128,21 @@ def fill_in(s, template): s = s.replace(k, v) return s +def fill_include_sm_headers(op_instance): + include_headers ="\n".join([ + f"#include " + for arch in op_instance["archs"] + ]) + + return { + "path_prefix": op_instance["path_prefix"], + "OpT": op_instance["OpT"], + "INCLUDE_SM_HEADERS": include_headers + } + for op_instance in op_instances: + op_instance = fill_include_sm_headers(op_instance) + for data_type_instance in data_type_instances: op_data_instance = { k : fill_in(v, data_type_instance) diff --git a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu index e575b6e6f7..71ce79ad28 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::canberra_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu index c2f94e5a32..84c1cfe4e2 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::canberra_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu index 1661c8c968..d684273826 100644 --- a/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::correlation_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu index 672809c681..c83bb2b204 100644 --- a/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::correlation_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu index 23180715ed..202ee96ee5 100644 --- a/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu @@ -14,30 +14,23 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::cosine_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu index 609ad0cba9..6b221aa2b5 100644 --- a/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu @@ -14,30 +14,23 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::cosine_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu index 07aeb6d160..a1a3ebc601 100644 --- a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::hamming_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu index dc8cf7f11d..8d596db93b 100644 --- a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::hamming_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu index 88ab818301..cd1b37de7e 100644 --- a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::hellinger_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu index ab5682c634..b67121f6af 100644 --- a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::hellinger_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index d19daf2409..738a9406be 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -14,30 +14,23 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::jensen_shannon_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu index 9a2db5fc4b..1685494010 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu @@ -14,30 +14,23 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::jensen_shannon_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu index 8a95650814..c3a77c7a8f 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::kl_divergence_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu index 379cdd0ab7..75c17fdb10 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::kl_divergence_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu index db427c10da..516384c967 100644 --- a/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::l1_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu index 672f53fa8c..a3535a75a6 100644 --- a/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::l1_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu index f1428608ae..474c031e01 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu @@ -14,30 +14,23 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::l2_exp_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu index f8810be55f..334a367453 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu @@ -14,30 +14,23 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::l2_exp_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu index b73104dcad..41a70341d0 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::l2_unexp_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu index 99412148b8..ac27e35d01 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::l2_unexp_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu index bbba0ee026..4e06d0264a 100644 --- a/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::l_inf_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu index d25774dffa..c19a8e6016 100644 --- a/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::l_inf_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu index 7e7d5e851d..c3c8d2b96f 100644 --- a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::lp_unexp_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu index 401cfaada9..ec8317d9d4 100644 --- a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::lp_unexp_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu index 4775b8b93e..d842cebd44 100644 --- a/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - double, - double, - double, - decltype(raft::identity_op()), - int>(ops::russel_rao_distance_op, - int, - int, - int, - const double*, - const double*, - const double*, - const double*, - double*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu index dcca59b6ca..179599f549 100644 --- a/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu @@ -14,30 +14,22 @@ * limitations under the License. */ -#include // raft::identity_op -#include +#include // raft::identity_op +#include // ops::* -#include +#include // pairwise_matrix_instantiation_point +#include #include // raft::arch::SM_compat_range namespace raft::distance::detail { -template void pairwise_matrix_dispatch, - float, - float, - float, - decltype(raft::identity_op()), - int>(ops::russel_rao_distance_op, - int, - int, - int, - const float*, - const float*, - const float*, - const float*, - float*, - decltype(raft::identity_op()), - cudaStream_t, - bool); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail From da2eb6958ff3c0a8ac16a9c2ddaf0d78e28ac90f Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 17:22:30 +0100 Subject: [PATCH 71/93] Add *_essentials headers This limits the number of unneeded headers included in the specializations. Results: pairwise_test 0.5 seconds build.ninja 4.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu.o 7.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_double_double_double_int.cu.o 7.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu.o 8.1 seconds akeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu.o 8.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu.o 8.5 seconds akeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu.o 9.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu.o 9.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu.o 9.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu.o 13.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu.o 13.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu.o 13.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu.o 14.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu.o 14.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu.o 15.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu.o 20.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu.o 24.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu.o 28.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu.o 31.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu.o 32.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu.o 34.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu.o 35.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu.o 37.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu.o 38.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu.o 39.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu.o 40.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu.o 42.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu.o 42.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu.o 43.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu.o 44.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu.o 45.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu.o 46.5 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_inner_product.cu.o 48.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_lp_unexp.cu.o 50.4 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_hamming.cu.o 50.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_hellinger.cu.o 51.0 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l_inf.cu.o 51.2 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_canberra.cu.o 51.5 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_unexp.cu.o 51.8 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l1.cu.o 52.2 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_russell_rao.cu.o 52.8 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_jensen_shannon.cu.o 53.4 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_kl_divergence.cu.o 53.9 seconds CMakeFiles/pairwise_test.dir/test/distance/gram.cu.o 55.6 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_exp.cu.o 58.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu.o 59.6 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_sqrt_exp.cu.o 59.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_cos.cu.o 60.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_correlation.cu.o 61.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu.o 72.6 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_adj.cu.o 120.5 seconds --- .../distance/detail/distance_ops/canberra.cuh | 2 +- .../detail/distance_ops/correlation.cuh | 2 +- .../distance/detail/distance_ops/cosine.cuh | 2 +- .../distance/detail/distance_ops/hamming.cuh | 2 +- .../detail/distance_ops/hellinger.cuh | 2 +- .../detail/distance_ops/jensen_shannon.cuh | 2 +- .../detail/distance_ops/kl_divergence.cuh | 2 +- .../raft/distance/detail/distance_ops/l1.cuh | 2 +- .../distance/detail/distance_ops/l2_exp.cuh | 2 +- .../distance/detail/distance_ops/l2_unexp.cuh | 2 +- .../distance/detail/distance_ops/l_inf.cuh | 2 +- .../distance/detail/distance_ops/lp_unexp.cuh | 2 +- .../detail/distance_ops/russel_rao.cuh | 2 +- .../distance/detail/distance_ops/template.cuh | 2 +- .../detail/pairwise_distance_base.cuh | 15 ++- .../pairwise_matrix/dispatch_layout.cuh | 2 +- .../detail/pairwise_matrix/kernel_sm60.cuh | 35 +----- cpp/include/raft/util/cuda_dev_essentials.cuh | 91 +++++++++++++++ cpp/include/raft/util/cuda_rt_essentials.hpp | 60 ++++++++++ cpp/include/raft/util/cuda_utils.cuh | 105 +----------------- cpp/include/raft/util/cudart_utils.hpp | 38 +------ cpp/include/raft/util/device_loads_stores.cuh | 4 +- 22 files changed, 185 insertions(+), 193 deletions(-) create mode 100644 cpp/include/raft/util/cuda_dev_essentials.cuh create mode 100644 cpp/include/raft/util/cuda_rt_essentials.hpp diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 930294ce31..2215ded8e2 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 289b69070a..8cbca6ef75 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index 7c37c27b4e..c103cf6121 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 1cfdcfdc73..2495233dee 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index c4aecc7a6f..0b01a0e967 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index 41eeb9dd83..d82dfe8463 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index d046b62c30..8f3260a799 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 8ec4000827..5330be4f0c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 2a7af53813..cb7702396a 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index f0ea591eaf..f8105462a1 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index fb21fb1a21..108c0cd8ef 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index 71dfd51a6e..fa1048b753 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index ea09e4d1db..745251771f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index 6998f3cad4..e4aa281776 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index d5779b2eaf..a051bdf4cd 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -14,12 +14,11 @@ * limitations under the License. */ #pragma once -#include -#include -#include -#include +#include // raft::linalg::Contractions_NT +#include // ceildiv +#include // RAFT_CUDA_TRY -#include +#include // size_t namespace raft { namespace distance { @@ -272,7 +271,11 @@ struct PairwiseDistances : public BaseClass { template dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { - const auto numSMs = raft::getMultiProcessorCount(); + int devId; + RAFT_CUDA_TRY(cudaGetDevice(&devId)); + int numSMs; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId)); + int numBlocksPerSm = 0; dim3 grid; diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh index dc58d0e2bf..08f5155cf4 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -17,9 +17,9 @@ #include // std::min #include // size_t +#include // RAFT_EXPECTS #include // pairwise_matrix_params #include // std::integral_constant - namespace raft::distance::detail { /** diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 6e3ab7b26b..410dfa1080 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -15,11 +15,11 @@ */ #pragma once -#include -#include -#include -#include -#include +#include // assert +#include // raft::void_op +#include // PairwiseDistances +#include // pairwise_matrix_params +#include // raft::arch::SM_compute_arch namespace raft::distance::detail { @@ -97,31 +97,6 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( obj.run(); } -template -void pairwise_matrix(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) -{ - dim3 blk(Policy::Nthreads); - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - size_t smem_size = distance_op.template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - kernel<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); -} - // The type of a pointer to the pairwise matrix kernel. The following template // arguments are type-erased: // diff --git a/cpp/include/raft/util/cuda_dev_essentials.cuh b/cpp/include/raft/util/cuda_dev_essentials.cuh new file mode 100644 index 0000000000..5080dc33ee --- /dev/null +++ b/cpp/include/raft/util/cuda_dev_essentials.cuh @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file provides a few essential functions for use in __device__ code. The +// scope is necessarily limited to ensure that compilation times are minimized. +// Please make sure not to include large / expensive files from here. + +namespace raft { + +/** helper macro for device inlined functions */ +#define DI inline __device__ +#define HDI inline __host__ __device__ +#define HD __host__ __device__ + +/** + * @brief Provide a ceiling division operation ie. ceil(a / b) + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType ceildiv(IntType a, IntType b) +{ + return (a + b - 1) / b; +} + +/** + * @brief Provide an alignment function ie. ceil(a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType alignTo(IntType a, IntType b) +{ + return ceildiv(a, b) * b; +} + +/** + * @brief Provide an alignment function ie. (a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType alignDown(IntType a, IntType b) +{ + return (a / b) * b; +} + +/** + * @brief Check if the input is a power of 2 + * @tparam IntType data type (checked only for integers) + */ +template +constexpr HDI bool isPo2(IntType num) +{ + return (num && !(num & (num - 1))); +} + +/** + * @brief Give logarithm of the number to base-2 + * @tparam IntType data type (checked only for integers) + */ +template +constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) +{ + return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret); +} + +/** number of threads per warp */ +static const int WarpSize = 32; + +/** get the laneId of the current thread */ +DI int laneId() +{ + int id; + asm("mov.s32 %0, %%laneid;" : "=r"(id)); + return id; +} + +} // namespace raft diff --git a/cpp/include/raft/util/cuda_rt_essentials.hpp b/cpp/include/raft/util/cuda_rt_essentials.hpp new file mode 100644 index 0000000000..e5f3af4e61 --- /dev/null +++ b/cpp/include/raft/util/cuda_rt_essentials.hpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file provides a few essential functions that wrap the CUDA runtime API. +// The scope is necessarily limited to ensure that compilation times are +// minimized. Please make sure not to include large / expensive files from here. + +#include +#include + +namespace raft { + +/** + * @brief Exception thrown when a CUDA error is encountered. + */ +struct cuda_error : public raft::exception { + explicit cuda_error(char const* const message) : raft::exception(message) {} + explicit cuda_error(std::string const& message) : raft::exception(message) {} +}; + +} // namespace raft + +/** + * @brief Error checking macro for CUDA runtime API functions. + * + * Invokes a CUDA runtime API function call, if the call does not return + * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an + * exception detailing the CUDA error that occurred + * + */ +#define RAFT_CUDA_TRY(call) \ + do { \ + cudaError_t const status = call; \ + if (status != cudaSuccess) { \ + cudaGetLastError(); \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "CUDA error encountered at: ", \ + "call='%s', Reason=%s:%s", \ + #call, \ + cudaGetErrorName(status), \ + cudaGetErrorString(status)); \ + throw raft::cuda_error(msg); \ + } \ + } while (0) diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 5be9dc999a..687a6b4651 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -23,113 +23,10 @@ #include #include #include - -#ifndef ENABLE_MEMCPY_ASYNC -// enable memcpy_async interface by default for newer GPUs -#if __CUDA_ARCH__ >= 800 -#define ENABLE_MEMCPY_ASYNC 1 -#endif -#else // ENABLE_MEMCPY_ASYNC -// disable memcpy_async for all older GPUs -#if __CUDA_ARCH__ < 800 -#define ENABLE_MEMCPY_ASYNC 0 -#endif -#endif // ENABLE_MEMCPY_ASYNC +#include namespace raft { -/** helper macro for device inlined functions */ -#define DI inline __device__ -#define HDI inline __host__ __device__ -#define HD __host__ __device__ - -/** - * @brief Provide a ceiling division operation ie. ceil(a / b) - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType ceildiv(IntType a, IntType b) -{ - return (a + b - 1) / b; -} - -/** - * @brief Provide an alignment function ie. ceil(a / b) * b - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType alignTo(IntType a, IntType b) -{ - return ceildiv(a, b) * b; -} - -/** - * @brief Provide an alignment function ie. (a / b) * b - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType alignDown(IntType a, IntType b) -{ - return (a / b) * b; -} - -/** - * @brief Check if the input is a power of 2 - * @tparam IntType data type (checked only for integers) - */ -template -constexpr HDI bool isPo2(IntType num) -{ - return (num && !(num & (num - 1))); -} - -/** - * @brief Give logarithm of the number to base-2 - * @tparam IntType data type (checked only for integers) - */ -template -constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) -{ - return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret); -} - -/** Device function to apply the input lambda across threads in the grid */ -template -DI void forEach(int num, L lambda) -{ - int idx = (blockDim.x * blockIdx.x) + threadIdx.x; - const int numThreads = blockDim.x * gridDim.x; -#pragma unroll - for (int itr = 0; itr < ItemsPerThread; ++itr, idx += numThreads) { - if (idx < num) lambda(idx, itr); - } -} - -/** number of threads per warp */ -static const int WarpSize = 32; - -/** get the laneId of the current thread */ -DI int laneId() -{ - int id; - asm("mov.s32 %0, %%laneid;" : "=r"(id)); - return id; -} - -/** - * @brief Swap two values - * @tparam T the datatype of the values - * @param a first input - * @param b second input - */ -template -HDI void swapVals(T& a, T& b) -{ - T tmp = a; - a = b; - b = tmp; -} - /** Device function to have atomic add support for older archs */ template DI void myAtomicAdd(Type* address, Type val) diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 0feb188ad8..0a7ca23028 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -25,6 +25,7 @@ #pragma once #include +#include #include #include #include @@ -40,42 +41,7 @@ #include #include #include - -namespace raft { - -/** - * @brief Exception thrown when a CUDA error is encountered. - */ -struct cuda_error : public raft::exception { - explicit cuda_error(char const* const message) : raft::exception(message) {} - explicit cuda_error(std::string const& message) : raft::exception(message) {} -}; - -} // namespace raft - -/** - * @brief Error checking macro for CUDA runtime API functions. - * - * Invokes a CUDA runtime API function call, if the call does not return - * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an - * exception detailing the CUDA error that occurred - * - */ -#define RAFT_CUDA_TRY(call) \ - do { \ - cudaError_t const status = call; \ - if (status != cudaSuccess) { \ - cudaGetLastError(); \ - std::string msg{}; \ - SET_ERROR_MSG(msg, \ - "CUDA error encountered at: ", \ - "call='%s', Reason=%s:%s", \ - #call, \ - cudaGetErrorName(status), \ - cudaGetErrorString(status)); \ - throw raft::cuda_error(msg); \ - } \ - } while (0) +#include // FIXME: Remove after consumers rename #ifndef CUDA_TRY diff --git a/cpp/include/raft/util/device_loads_stores.cuh b/cpp/include/raft/util/device_loads_stores.cuh index 2b87c44d60..4344201fa4 100644 --- a/cpp/include/raft/util/device_loads_stores.cuh +++ b/cpp/include/raft/util/device_loads_stores.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include +#include namespace raft { From 9370a2980d5f9f7da4af89b20663ec8297b8cee8 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 18:26:21 +0100 Subject: [PATCH 72/93] Decouple test and pairwise distance code A change in the pairwise_distance code will not trigger a rebuild of all the tests any more. pairwise_test 0.5 seconds build.ninja 4.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu.o 7.3 seconds akeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu.o 7.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu.o 8.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_double_double_double_int.cu.o 8.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu.o 8.3 seconds akeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu.o 8.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu.o 9.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu.o 10.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu.o 11.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_float_float_float_int.cu.o 12.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu.o 13.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu.o 13.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu.o 13.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu.o 13.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu.o 14.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu.o 19.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu.o 22.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu.o 26.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu.o 31.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu.o 31.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu.o 34.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu.o 35.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu.o 36.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu.o 39.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu.o 39.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu.o 39.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu.o 42.4 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_inner_product.cu.o 42.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_cos.cu.o 42.8 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_sqrt_exp.cu.o 42.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_exp.cu.o 43.0 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_lp_unexp.cu.o 43.4 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l1.cu.o 43.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu.o 43.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_unexp.cu.o 43.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_russell_rao.cu.o 43.8 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_canberra.cu.o 43.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu.o 44.2 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l_inf.cu.o 44.2 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_hamming.cu.o 44.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu.o 44.5 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_jensen_shannon.cu.o 44.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu.o 44.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_correlation.cu.o 44.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_kl_divergence.cu.o 45.0 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_hellinger.cu.o 45.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu.o 45.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/pairwise_distance.cu.o 56.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu.o 57.2 seconds CMakeFiles/pairwise_test.dir/test/distance/gram.cu.o 57.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu.o 70.1 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_adj.cu.o 116.8 seconds --- cpp/test/distance/distance_base.cuh | 73 +++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 10 deletions(-) diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 5fcaf07539..ae8230984a 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -16,16 +16,23 @@ #include "../test_utils.cuh" #include -#include -#include -#include -#include -#include -#include +#include // common::nvtx::range + +#include // raft::device_resources +#include // raft::sqrt +#include // raft::distance::DistanceType +#include +#include // rmm::device_uvector + +// When the distance library is precompiled, include only the raft_runtime +// headers. This way, a small change in one of the kernel internals does not +// trigger a rebuild of the test files (it of course still triggers a rebuild of +// the raft specializations) #if defined RAFT_DISTANCE_COMPILED -#include +#include +#else +#include #endif -#include namespace raft { namespace distance { @@ -409,6 +416,25 @@ template return os; } +// TODO: Remove when mdspan-based raft::runtime::distance::pairwise_distance is +// implemented. +// +// Context: +// https://github.com/rapidsai/raft/issues/1338 +template +constexpr bool layout_to_row_major(); + +template <> +constexpr bool layout_to_row_major() +{ + return true; +} +template <> +constexpr bool layout_to_row_major() +{ + return false; +} + template void distanceLauncher(raft::device_resources const& handle, DataType* x, @@ -422,12 +448,23 @@ void distanceLauncher(raft::device_resources const& handle, DataType threshold, DataType metric_arg = 2.0f) { +#if defined RAFT_DISTANCE_COMPILED + // TODO: Implement and use mdspan-based + // raft::runtime::distance::pairwise_distance here. + // + // Context: + // https://github.com/rapidsai/raft/issues/1338 + bool row_major = layout_to_row_major(); + raft::runtime::distance::pairwise_distance( + handle, x, y, dist, m, n, k, distanceType, row_major, metric_arg); +#else auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); raft::distance::distance( handle, x_v, y_v, dist_v, metric_arg); +#endif } template @@ -523,9 +560,25 @@ class BigMatrixDistanceTest : public ::testing::Test { auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + void pairwise_distance(raft::device_resources const& handle, + float* x, + float* y, + float* dists, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + bool row_major = true; + float metric_arg = 0.0f; +#if defined RAFT_DISTANCE_COMPILED + raft::runtime::distance::pairwise_distance( + handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); +#else raft::distance::distance( - handle, x.data(), x.data(), dist.data(), m, n, k, true, 0.0f); - + handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); +#endif RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } From 14a9477c6ac84ae7a955831d79d358f32d7086ff Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 14 Mar 2023 20:50:23 +0100 Subject: [PATCH 73/93] Take distance_op in pairwise_distance_base Fixes issue #1323 --- cpp/include/raft/core/kvp.hpp | 2 +- .../raft/distance/detail/fused_l2_nn.cuh | 133 ++--- .../detail/pairwise_distance_base.cuh | 37 +- .../detail/pairwise_matrix/kernel_sm60.cuh | 28 +- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 520 +++++++++--------- cpp/test/distance/fused_l2_nn.cu | 30 +- 6 files changed, 325 insertions(+), 425 deletions(-) diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index 8d3321eb77..192d160d45 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -20,7 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include -#include +#include // raft::shfl_xor #endif namespace raft { /** diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 8fbd7a9c69..be6fed9f10 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -16,23 +16,20 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy +#include // raft::ceildiv, raft::shfl namespace raft { namespace distance { namespace detail { -#if (ENABLE_MEMCPY_ASYNC == 1) -#include -using namespace nvcuda::experimental; -#endif - template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; @@ -124,11 +121,10 @@ DI void updateReducedVal( template __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, const DataT* x, @@ -142,7 +138,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, int* mutex, ReduceOpT redOp, KVPReduceOpT pairRedOp, - CoreLambda core_op, + OpT distance_op, FinalLambda fin_op) { extern __shared__ char smem[]; @@ -163,24 +159,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (Sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto acc_ij = acc[i][j]; - acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; - } - } - } - // intra thread reduce const auto acccolid = threadIdx.x % P::AccThCols; const auto accrowid = threadIdx.x / P::AccThCols; @@ -229,18 +207,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, }; IdxT lda = k, ldb = k, ldd = n; - PairwiseDistances + row_major, + write_out> obj(x, y, m, @@ -251,9 +229,9 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, ldd, xn, yn, - nullptr, + nullptr, // Output pointer smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -289,9 +267,6 @@ void fusedL2NNImpl(OutT* min, constexpr auto maxVal = std::numeric_limits::max(); typedef KeyValuePair KVPair; - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel @@ -300,59 +275,25 @@ void fusedL2NNImpl(OutT* min, } constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } else { - auto fusedL2NN = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedL2NNkernel; + + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index a051bdf4cd..583476ede6 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -26,16 +26,12 @@ namespace detail { /** * @brief Device class for L1, L2 and cosine distance metrics. - * @tparam useNorms whether norms are needed * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) + * @tparam OpT A distance operation, e.g., cosine_distance_op. * @tparam EpilogueLambda applies an elementwise function to compute final values. Its signature is: template void epilogue_lambda @@ -53,19 +49,17 @@ namespace detail { * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine * @param[output] pD output matrix * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. - * @param core_op the core accumulation operation lambda + * @param distance_op the distance operation, e.g. cosine_distance_op * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ -template > struct PairwiseDistances : public BaseClass { + // Get accumulation type from distance_op + using AccT = typename OpT::AccT; + private: typedef Policy P; const DataT* xn; @@ -80,7 +77,7 @@ struct PairwiseDistances : public BaseClass { const DataT* const yBase; OutT* dOutput; char* smem; - CoreLambda core_op; + OpT distance_op; EpilogueLambda epilog_op; FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; @@ -106,7 +103,7 @@ struct PairwiseDistances : public BaseClass { const DataT* _yn, OutT* _dOutput, char* _smem, - CoreLambda _core_op, + OpT _distance_op, EpilogueLambda _epilog_op, FinalLambda _fin_op, rowEpilogueLambda _rowEpilog_op) @@ -116,7 +113,7 @@ struct PairwiseDistances : public BaseClass { yBase(_y), dOutput(_dOutput), smem(_smem), - core_op(_core_op), + distance_op(_distance_op), epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), @@ -156,15 +153,25 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); // Epilog: - if (useNorms) { + if (distance_op.use_norms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { store_output(tile_idx_m, tile_idx_n); } @@ -209,7 +216,7 @@ struct PairwiseDistances : public BaseClass { for (int j = 0; j < P::AccColsPerTh; ++j) { #pragma unroll for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); + distance_op.core(acc[i][j], this->regx[i][v], this->regy[j][v]); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 410dfa1080..b298391ef2 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -43,36 +43,20 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( extern __shared__ char smem[]; - using AccT = typename OpT::AccT; - - // Wrap operator back into lambdas. This is temporary and should be removed. - // See: https://github.com/rapidsai/raft/issues/1323 - auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - distance_op.core(acc, x, y); - }; - auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); - }; - + // The epilog is already provided by distance_op. Do not provide additional + // epilogs. + auto epilog_op = raft::void_op(); // No support for row_epilog_op. auto row_epilog_op = raft::void_op(); // Always write output constexpr bool write_out = true; constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances #include +#include +#include #include #include @@ -183,13 +185,11 @@ DI void updateSortedWarpQ( } } -template Pair; @@ -223,295 +223,275 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; - auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( - IdxT gridStrideY) { - if (gridDim.x == 1) { return; } - - Pair* shDumpKV = nullptr; - if (useNorms) { - shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } - - const int lid = threadIdx.x % warpSize; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - - // 0 -> consumer done consuming the buffer. - // -1 -> consumer started consuming the buffer - // -2 -> producer done filling the buffer - // 1 -> prod acquired to fill the buffer - if (blockIdx.x == 0) { - auto cta_processed = 0; - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - __syncwarp(); - - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - - while (cta_processed < gridDim.x - 1) { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) - ; - } - __threadfence(); - __syncthreads(); + auto rowEpilog_lambda = + [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { + if (gridDim.x == 1) { return; } + + int smem_offset = distance_op.template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + const int lid = threadIdx.x % warpSize; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + + // 0 -> consumer done consuming the buffer. + // -1 -> consumer started consuming the buffer + // -2 -> producer done filling the buffer + // 1 -> prod acquired to fill the buffer + if (blockIdx.x == 0) { + auto cta_processed = 0; + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + __syncwarp(); + + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + + while (cta_processed < gridDim.x - 1) { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - otherKV.value = out_dists[rowId * numOfNN + idx]; - otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } } } } - } - __threadfence(); - __syncthreads(); + __threadfence(); + __syncthreads(); - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } - __threadfence(); + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); // Perform merging of otherKV with topk's across warp. #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); } } + cta_processed++; } - cta_processed++; - } #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(0xffffffff, needSort); - if (needSort) { heapArr[i]->reduce(); } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(0xffffffff, needSort); + if (needSort) { heapArr[i]->reduce(); } + } } - } - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } else { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) - ; - } - __threadfence(); - __syncthreads(); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } else { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - for (int idx = lid; idx < numOfNN; idx += warpSize) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; - out_dists[rowId * numOfNN + idx] = KVPair.value; - out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + for (int idx = lid; idx < numOfNN; idx += warpSize) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + out_dists[rowId * numOfNN + idx] = KVPair.value; + out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + } } } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } - __threadfence(); - } - }; + __threadfence(); + __syncthreads(); - // epilogue operation lambda for final value calculation - auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - if (useNorms) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); } - } + }; - Pair* shDumpKV = nullptr; - if (useNorms) { - constexpr size_t shmemSize = - Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - shDumpKV = (Pair*)(&smem[shmemSize]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } + // epilogue operation lambda for final value calculation + auto epilog_lambda = + [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + int smem_offset = distance_op.template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); - - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + } } - } - if (gridStrideX > blockIdx.x * Policy::Nblk) { + if (gridStrideX > blockIdx.x * Policy::Nblk) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + } } + anyWarpTopKs += numValsWarpTopK[i]; } - anyWarpTopKs += numValsWarpTopK[i]; } - } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair* allWarpTopKs = (Pair*)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair* allWarpTopKs = (Pair*)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { #pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { numValsWarpTopK[i] += n; } + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { numValsWarpTopK[i] += n; } + } } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; - } - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; - numValsWarpTopK[i]++; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; + numValsWarpTopK[i]++; + } } } } + __syncwarp(); + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); + updateSortedWarpQ( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } - __syncwarp(); - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); - updateSortedWarpQ( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + } } } } - } - } else { + } else { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); - } - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { heapArr[i]->reduce(); } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + } } } - } - if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - }; + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + }; - raft::distance::detail::PairwiseDistances + write_out> obj(x, y, m, @@ -522,9 +502,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x ldd, _xn, _yn, - nullptr, + nullptr, // output ptr, can be null as write_out == false. smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -563,38 +543,32 @@ void fusedL2UnexpKnnImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = x - y; - acc += diff * diff; - }; - typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; if (numOfNN <= 32) { @@ -605,8 +579,10 @@ void fusedL2UnexpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - dim3 grid = raft::distance::detail::launchConfigGenerator( + const auto sharedMemSize = + distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); + + dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); if (grid.x > 1) { @@ -629,9 +605,8 @@ void fusedL2UnexpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, (int*)workspace, out_dists, @@ -754,36 +729,33 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(workspace != nullptr, "workspace is null"); dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; if (numOfNN <= 32) { @@ -794,9 +766,8 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + - ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + - (KPolicy::Mblk * numOfNN * sizeof(Pair)); + const auto sharedMemSize = + distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2ExpKnnRowMajor); int32_t* mutexes = nullptr; @@ -836,9 +807,8 @@ void fusedL2ExpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, mutexes, out_dists, diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index af67214193..adb73cb9b2 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -182,22 +182,20 @@ class FusedL2NNTest : public ::testing::TestWithParam> { int m = params.m; int n = params.n; int k = params.k; - MinAndDistanceReduceOp redOp; - fusedL2NN, int>( - out, - x.data(), - y.data(), - xn.data(), - yn.data(), - m, - n, - k, - (void*)workspace.data(), - redOp, - raft::distance::KVPMinReduce(), - Sqrt, - true, - stream); + + const bool init_out_buffer = true; + fusedL2NNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + Sqrt, + init_out_buffer, + stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } }; From f54e7a41482765aca175d3dd60a0c018a0a131c4 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 15 Mar 2023 16:40:10 +0100 Subject: [PATCH 74/93] Add tuning benchmark for pairwise distances --- cpp/bench/CMakeLists.txt | 5 + cpp/bench/distance/tune_pairwise/bench.cu | 146 ++++++++++++++++++ cpp/bench/distance/tune_pairwise/kernel.cu | 85 ++++++++++ cpp/bench/distance/tune_pairwise/kernel.cuh | 51 ++++++ .../distance/detail/distance_ops/canberra.cuh | 1 + .../detail/distance_ops/jensen_shannon.cuh | 1 + .../detail/distance_ops/kl_divergence.cuh | 1 + .../distance/detail/distance_ops/lp_unexp.cuh | 1 + cpp/include/raft/util/device_loads_stores.cuh | 3 +- 9 files changed, 293 insertions(+), 1 deletion(-) create mode 100644 cpp/bench/distance/tune_pairwise/bench.cu create mode 100644 cpp/bench/distance/tune_pairwise/kernel.cu create mode 100644 cpp/bench/distance/tune_pairwise/kernel.cuh diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index e2324de654..d053d6c18f 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -73,6 +73,11 @@ if(BUILD_BENCH) OPTIONAL DIST NN ) + ConfigureBench( + NAME TUNE_DISTANCE PATH bench/distance/tune_pairwise/kernel.cu + bench/distance/tune_pairwise/bench.cu bench/main.cpp + ) + ConfigureBench( NAME DISTANCE_BENCH diff --git a/cpp/bench/distance/tune_pairwise/bench.cu b/cpp/bench/distance/tune_pairwise/bench.cu new file mode 100644 index 0000000000..2cc0be63ec --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/bench.cu @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Tuning benchmarks. +// +// Goals: +// +// 1. Fast compile times to maintain iteration speed. +// 2. Create benchmarks that can inform the design of the kernels. +// +// Non-goals: +// +// 1. Measure every distance operation. Instead measures just one distance +// operation at the same time. +// 2. Be useful for finding performance regressions. This is handled by the +// normal benchmarks. +// +// So far, both goals are partly achieved. +// +// RE (1), COMPILE TIMES: kernel.cu is fast to compile. This file is not. +// When the internals of a pairwise distance kernel is changed, this file is not +// recompiled. +// +// RE 2, benchmarks with intent: this file contains a benchmark to check the +// maximal throughput of a kernel. Measuring other things, like performance on +// skinny or wide matrices is not yet implemented. + +#include "kernel.cuh" // launch_kernel +#include // std::min +#include // RAFT_BENCH_REGISTER +#include // pairwise_matrix_params +#include // rmm::device_uvector +#include // std::vector + +namespace raft::bench::distance::tune { + +// Max throughput benchmark. +// +// Goal: Measure the maximum distances/sec that can be computed. +// +// To achieve this, we make sure that: +// +// - Input data size is a multiple of the block tile size. +// +// - Perfect distribution of work between SMs, i.e. the number of block tiles is +// a large multiple (num_waves) of the number of blocks (#SMs * occupancy). +// +// - Multiple iterations over Kblk are executed (num_k_iters). +struct throughput_param { + int num_waves; + int occupancy; + int num_k_iters; +}; + +const std::vector throughput_params{ + // 32 waves, requested occupancy of 4, and 32 k iterations typically achieves + // maximum throughput. No need to pick higher values. + {32, 4, 32}, +}; + +struct throughput_bench : public fixture { + const throughput_param p; + + throughput_bench(const throughput_param& p_) : p(p_) {} + + void run_benchmark(::benchmark::State& state) override + { + // Get block size: + int block_m, block_n, block_k; + get_block_size(block_m, block_n, block_k); + + // Determine number of blocks that will be launched. This informs the size + // of the inputs as well as the grid size. + const int num_sms = raft::getMultiProcessorCount(); + const int max_occupancy = get_max_occupancy(distance_op); + const int occupancy = std::min(p.occupancy, max_occupancy); + const int num_blocks = occupancy * num_sms; + dim3 grid(num_blocks); + + // Create input sizes that are a multiple of the block tile size. + size_t m = block_m; + size_t n = block_n * p.num_waves * num_blocks; + size_t k = block_k * p.num_k_iters; + + // DataT, OutT, IdxT, etc, are defined in tuned_kernel.cuh + rmm::device_uvector x_vec(m * k, stream); + rmm::device_uvector y_vec(n * k, stream); + rmm::device_uvector x_norm_vec(m, stream); + rmm::device_uvector y_norm_vec(n, stream); + rmm::device_uvector out_vec(m * n, stream); + + auto x = x_vec.data(); + auto y = y_vec.data(); + auto x_norm = x_norm_vec.data(); + auto y_norm = y_norm_vec.data(); + auto out = out_vec.data(); + FinOpT fin_op{}; + + auto make_params = raft::distance::detail::make_params; + pairwise_matrix_params kparams = + row_major ? make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, row_major) + : make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, row_major); + + // Run benchmark + loop_on_state(state, [&]() { launch_kernel(distance_op, kparams, grid, stream); }); + + // Report metrics. We don't report flop/s because we do not know for each + // distance operation how many flops it costs. For L2_unexp and l1, we can + // double this number to get the flop/s. For l2 expanded, dist/s should + // equal flop/s (modulo the sqrt and subtracting from the norm). + size_t num_dists = m * n * k; + size_t read_elts = n * k + m * k; + size_t write_elts = m * n; + + state.counters["m"] = benchmark::Counter(m); + state.counters["n"] = benchmark::Counter(n); + state.counters["k"] = benchmark::Counter(k); + state.counters["occupancy"] = benchmark::Counter(occupancy); + state.counters["# waves"] = benchmark::Counter(p.num_waves); + state.counters["# k iters"] = benchmark::Counter(p.num_k_iters); + + state.counters["dist/s"] = benchmark::Counter( + num_dists, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000); + + state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + } +}; + +RAFT_BENCH_REGISTER(throughput_bench, "", throughput_params); + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/distance/tune_pairwise/kernel.cu b/cpp/bench/distance/tune_pairwise/kernel.cu new file mode 100644 index 0000000000..3b46ff89ea --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/kernel.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel.cuh" +#include // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 +#include // raft::arch::SM_compute_arch + +namespace raft::bench::distance::tune { + +constexpr int vec_len = 1; +using Policy = typename raft::linalg::Policy4x4::Policy; +constexpr auto sm_compat_range = + raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()); + +void launch_kernel(OpT distance_op, pairwise_matrix_params params, dim3 grid, cudaStream_t stream) +{ + dim3 block(Policy::Nthreads); + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_size = distance_op.template shared_mem_size(); + + // Obtain function pointer to kernel + auto kernel = raft::distance::detail::pairwise_matrix_kernel; + + kernel<<>>(distance_op, params); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +void get_block_size(int& m, int& n, int& k) +{ + m = Policy::Mblk; + n = Policy::Nblk; + k = Policy::Kblk; +} + +void* get_kernel_ptr() +{ + auto kernel = raft::distance::detail::pairwise_matrix_kernel; + + return reinterpret_cast(kernel); +} + +int get_max_occupancy(OpT distance_op) +{ + void* kernel_ptr = get_kernel_ptr(); + int max_occupancy; + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_size = distance_op.template shared_mem_size(); + + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, kernel_ptr, Policy::Nthreads, smem_size)); + + return max_occupancy; +} + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/distance/tune_pairwise/kernel.cuh b/cpp/bench/distance/tune_pairwise/kernel.cuh new file mode 100644 index 0000000000..b444c5a87a --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/kernel.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // lp_unexp_distance_op +#include // pairwise_matrix_params + +namespace raft::bench::distance::tune { + +// Launch one specific kernel with the following template parameters +constexpr bool row_major = true; +using DataT = float; +using AccT = float; +using OutT = DataT; +using IdxT = int; + +// Distance op +// C++17 inline variable. Used by both tuned_kernel.cu and tune_pairwise.cu +// See: https://open-std.org/JTC1/SC22/WG21/docs/papers/2016/p0386r0.pdf +using OpT = raft::distance::detail::ops::lp_unexp_distance_op; +constexpr float metric_arg = 2.0; +inline const OpT distance_op{metric_arg}; +using FinOpT = raft::identity_op; + +using pairwise_matrix_params = + raft::distance::detail::pairwise_matrix_params; + +// Launches kernel +void launch_kernel(OpT, pairwise_matrix_params, dim3, cudaStream_t); + +// Describes the block size that is decided by the policy +void get_block_size(int& m, int& n, int& k); + +void* get_kernel_ptr(); +int get_max_occupancy(OpT); + +} // namespace raft::bench::distance::tune diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 2215ded8e2..0664eb98ee 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -16,6 +16,7 @@ #pragma once +#include // raft::abs #include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index d82dfe8463..fd2e0f4a3e 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include // raft::log #include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index 8f3260a799..705f83ecfc 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include // raft::log #include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index fa1048b753..1c40adf905 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -15,6 +15,7 @@ */ #pragma once +#include // raft::pow, raft::abs #include // DI namespace raft::distance::detail::ops { diff --git a/cpp/include/raft/util/device_loads_stores.cuh b/cpp/include/raft/util/device_loads_stores.cuh index 4344201fa4..c9bda26b81 100644 --- a/cpp/include/raft/util/device_loads_stores.cuh +++ b/cpp/include/raft/util/device_loads_stores.cuh @@ -16,7 +16,8 @@ #pragma once -#include +#include // uintX_t +#include // DI namespace raft { From 35a2ad4373f0b4963385e24826610157476019f4 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 15 Mar 2023 16:53:30 +0100 Subject: [PATCH 75/93] Limit loop unrolling for expensive distance ops We have four distance operations with expensive inner loops. They are already limited to veclen == 1. In this commit, we are also limiting the loop unrolling in accumulate(). In general unroll 1 is best. The only exception is canberra, which can be sped up with unlimited unrolling. Results below are for veclen = 1 on H100 with locked 2619Mhz memory and 1980 MHz SM clock. For comparison, I have added L1 results as well. | distance_op | veclen | unroll 1 | unroll 2 | unlimited | |----------------+--------+----------+----------+-----------| | lp_unexp | 1 | 248 G/s | 248 G/s | 235 G/s | | canberra | 1 | 371 | 367 | 447 | | canberra | 4 | 377 | 369 | 378 | | jensen shannon | 1 | 512 G/s | 512 | 449 | | kl divergence | 1 | 659 | 391 | 265 | |----------------+--------+----------+----------+-----------| | l1 | 1 | 8.9 T/s | 8.6 T/s | 9.95 T/s | | l1 | 4 | 11.1 T/s | 11.5 T/s | 11.7 T/s | Compile time impact: pairwise_test 0.5 seconds build.ninja 4.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu.o 6.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu.o 7.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu.o 7.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu.o 7.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu.o 8.2 seconds akeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu.o 8.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_double_double_double_int.cu.o 9.0 seconds akeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu.o 9.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu.o 9.2 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu.o 10.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu.o 10.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu.o 10.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu.o 10.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu.o 12.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu.o 13.0 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu.o 13.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu.o 13.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_float_float_float_int.cu.o 13.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu.o 14.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu.o 14.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu.o 14.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu.o 23.4 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu.o 24.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu.o 31.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu.o 34.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu.o 35.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/fused_l2_nn_double_int64.cu.o 37.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/fused_l2_nn_double_int.cu.o 38.8 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu.o 40.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu.o 41.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu.o 42.1 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu.o 42.5 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/fused_l2_nn_float_int.cu.o 43.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_hellinger.cu.o 44.4 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_russell_rao.cu.o 44.6 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_hamming.cu.o 44.7 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/fused_l2_nn_float_int64.cu.o 45.0 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_inner_product.cu.o 45.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_cos.cu.o 45.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu.o 46.2 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_kl_divergence.cu.o 46.4 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_unexp.cu.o 46.7 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l1.cu.o 47.1 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_lp_unexp.cu.o 47.3 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_exp.cu.o 47.4 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_canberra.cu.o 47.6 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_correlation.cu.o 47.6 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_jensen_shannon.cu.o 48.3 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l_inf.cu.o 48.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu.o 48.9 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_l2_sqrt_exp.cu.o 49.3 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu.o 50.6 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu.o 51.9 seconds CMakeFiles/pairwise_test.dir/src/distance/distance/pairwise_distance.cu.o 54.5 seconds CMakeFiles/pairwise_test.dir/test/distance/gram.cu.o 56.8 seconds CMakeFiles/pairwise_test.dir/test/distance/fused_l2_nn.cu.o 67.5 seconds CMakeFiles/pairwise_test.dir/test/distance/dist_adj.cu.o 123.3 seconds --- .../detail/pairwise_distance_base.cuh | 31 ++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 583476ede6..c6b09be31e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -205,24 +205,41 @@ struct PairwiseDistances : public BaseClass { } } - DI void accumulate() + DI void accumulate_reg_tile(DataT (®_x)[P::AccRowsPerTh][P::Veclen], + DataT (®_y)[P::AccColsPerTh][P::Veclen]) { #pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); + for (int v = 0; v < P::Veclen; ++v) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - distance_op.core(acc[i][j], this->regx[i][v], this->regy[j][v]); - } + distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); } } } } + DI void accumulate() + { + // We have a separate ldsXY and accumulate_reg_tile outside the loop body, + // so that these separated calls can be interspersed with preceding and + // following instructions, thereby hiding latency. + this->ldsXY(0); + + // If expensive inner loop, do not unroll loop. + constexpr int num_iterations = P::Kblk / P::Veclen - 1; + constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 1 : num_iterations; +#pragma unroll unroll_count + for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) { + accumulate_reg_tile(this->regx, this->regy); + this->ldsXY(ki); + } + + // Accumulate last loaded tile. + accumulate_reg_tile(this->regx, this->regy); + } + DI void load_norms(IdxT tile_idx_m, IdxT tile_idx_n, DataT (®xn)[P::AccRowsPerTh], From 9bd7a8392057539c6327edf6515fba588f6f8de8 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 16 Mar 2023 16:41:32 +0100 Subject: [PATCH 76/93] Fix column major errors on SM80 --- .../detail/pairwise_matrix/dispatch.cuh | 12 ++++++-- .../detail/pairwise_matrix/params.cuh | 30 +++++-------------- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index b5bed6e53d..5239b1744a 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -121,10 +121,16 @@ void pairwise_matrix_dispatch(OpT distance_op, bool is_row_major) { // Create kernel parameter struct. Flip x and y if column major. - pairwise_matrix_params params = - is_row_major ? make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, is_row_major) - : make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, is_row_major); + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { + params = params.flip_x_and_y(); + } pairwise_matrix_instantiation_point(distance_op, params, stream); } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh index d7fc2b28c3..f10a12d41c 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh @@ -32,30 +32,14 @@ struct pairwise_matrix_params { OutT* out; FinOpT fin_op; bool is_row_major; -}; -template -pairwise_matrix_params make_params(IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - bool is_row_major) -{ - // Determine leading dimensions. - IdxT ldx, ldy, ld_out; - if (is_row_major) { - ldx = k, ldy = k, ld_out = n; - } else { - ldx = m, ldy = n, ld_out = m; + // + [[nodiscard]] pairwise_matrix_params flip_x_and_y() + { + // Flip m, n; ldx, ldy; x, y; x_norm, y_norm. + return pairwise_matrix_params { + n, m, k, ldy, ldx, ld_out, y, x, y_norm, x_norm, out, fin_op, is_row_major}; } - - return pairwise_matrix_params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; -} +}; } // namespace raft::distance::detail From c9ab1b8b78b1fd53409984740a63e5cd3f3bd4e5 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 16 Mar 2023 19:23:12 +0100 Subject: [PATCH 77/93] Fix col major errors on SM80 --- .../detail/pairwise_matrix/dispatch.cuh | 11 +++++-- .../detail/pairwise_matrix/params.cuh | 30 +++++-------------- 2 files changed, 15 insertions(+), 26 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 651490b6be..492f8600e9 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -46,9 +46,14 @@ void pairwise_matrix_dispatch(OpT distance_op, bool is_row_major) { // Create kernel parameter struct. Flip x and y if column major. - pairwise_matrix_params params = - is_row_major ? make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, is_row_major) - : make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, is_row_major); + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params = params.flip_x_and_y(); } // On CUDA 12: // - always execute normal kernel diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh index d7fc2b28c3..dbc47d5aeb 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh @@ -32,30 +32,14 @@ struct pairwise_matrix_params { OutT* out; FinOpT fin_op; bool is_row_major; -}; -template -pairwise_matrix_params make_params(IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - bool is_row_major) -{ - // Determine leading dimensions. - IdxT ldx, ldy, ld_out; - if (is_row_major) { - ldx = k, ldy = k, ld_out = n; - } else { - ldx = m, ldy = n, ld_out = m; + // + [[nodiscard]] pairwise_matrix_params flip_x_and_y() + { + // Flip m, n; ldx, ldy; x, y; x_norm, y_norm. + return pairwise_matrix_params{ + n, m, k, ldy, ldx, ld_out, y, x, y_norm, x_norm, out, fin_op, is_row_major}; } - - return pairwise_matrix_params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; -} +}; } // namespace raft::distance::detail From 5d1f6c2b3deb3fd536a8208cdea34ca3771a1773 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 16 Mar 2023 20:16:03 +0100 Subject: [PATCH 78/93] Use raft::util::arch namespace --- cpp/bench/distance/tune_pairwise/kernel.cu | 11 +++++----- .../detail/pairwise_matrix/dispatch.cuh | 18 ++++++++--------- .../detail/pairwise_matrix/kernel_sm60.cuh | 4 ++-- .../detail/00_write_template.py | 13 ------------ cpp/include/raft/util/arch.cuh | 20 ++++++++++--------- .../detail/00_write_template.py | 2 -- .../canberra_double_double_double_int.cu | 6 ++---- .../detail/canberra_float_float_float_int.cu | 6 ++---- .../correlation_double_double_double_int.cu | 6 ++---- .../correlation_float_float_float_int.cu | 6 ++---- .../detail/cosine_double_double_double_int.cu | 6 ++---- .../detail/cosine_float_float_float_int.cu | 6 ++---- ...ing_unexpanded_double_double_double_int.cu | 6 ++---- ...amming_unexpanded_float_float_float_int.cu | 6 ++---- ...inger_expanded_double_double_double_int.cu | 6 ++---- ...ellinger_expanded_float_float_float_int.cu | 6 ++---- ...jensen_shannon_double_double_double_int.cu | 6 ++---- .../jensen_shannon_float_float_float_int.cu | 6 ++---- .../kl_divergence_double_double_double_int.cu | 6 ++---- .../kl_divergence_float_float_float_int.cu | 6 ++---- .../detail/l1_double_double_double_int.cu | 6 ++---- .../detail/l1_float_float_float_int.cu | 6 ++---- .../l2_expanded_double_double_double_int.cu | 6 ++---- .../l2_expanded_float_float_float_int.cu | 6 ++---- .../l2_unexpanded_double_double_double_int.cu | 6 ++---- .../l2_unexpanded_float_float_float_int.cu | 6 ++---- .../detail/l_inf_double_double_double_int.cu | 6 ++---- .../detail/l_inf_float_float_float_int.cu | 6 ++---- .../lp_unexpanded_double_double_double_int.cu | 6 ++---- .../lp_unexpanded_float_float_float_int.cu | 6 ++---- .../russel_rao_double_double_double_int.cu | 6 ++---- .../russel_rao_float_float_float_int.cu | 6 ++---- 32 files changed, 80 insertions(+), 144 deletions(-) diff --git a/cpp/bench/distance/tune_pairwise/kernel.cu b/cpp/bench/distance/tune_pairwise/kernel.cu index 3b46ff89ea..7511e5dc7d 100644 --- a/cpp/bench/distance/tune_pairwise/kernel.cu +++ b/cpp/bench/distance/tune_pairwise/kernel.cu @@ -17,14 +17,15 @@ #include "kernel.cuh" #include // pairwise_matrix_sm60_wrapper #include // raft::linalg::Policy4x4 -#include // raft::arch::SM_compute_arch +#include // raft::util::arch::SM_compute_arch namespace raft::bench::distance::tune { -constexpr int vec_len = 1; -using Policy = typename raft::linalg::Policy4x4::Policy; -constexpr auto sm_compat_range = - raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()); +namespace arch = raft::util::arch; + +constexpr int vec_len = 1; +using Policy = typename raft::linalg::Policy4x4::Policy; +constexpr auto sm_compat_range = arch:: ::SM_range(arch:: ::SM_min(), arch:: ::SM_future()); void launch_kernel(OpT distance_op, pairwise_matrix_params params, dim3 grid, cudaStream_t stream) { diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 5239b1744a..4bde32e31a 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -34,6 +34,7 @@ #include // ops::has_cutlass_op #include // dispatch_sm60 #include // pairwise_matrix_params +#include // raft::util::arch::SM_* // NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. // Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). @@ -70,17 +71,18 @@ void pairwise_matrix_instantiation_point(OpT distance_op, // On CUDA 11 and below: // - execute CUTLASS-based kernel on SM_80 and above // - execute normal kernel below SM_80 + namespace arch = raft::util::arch; constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); if constexpr (is_ctk_12 || cutlass_op_unavailable) { // Always execute legacy kernels on CUDA 12 - auto any_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()); + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); } else { - auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); // Get pointer to SM60 kernel to determine the runtime architecture of the // current system. Other methods to determine the architecture (that do not @@ -88,7 +90,7 @@ void pairwise_matrix_instantiation_point(OpT distance_op, // https://github.com/NVIDIA/cub/issues/545 auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = raft::arch::kernel_runtime_arch(kernel_ptr); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. @@ -121,16 +123,14 @@ void pairwise_matrix_dispatch(OpT distance_op, bool is_row_major) { // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; IdxT ld_out = is_row_major ? n : m; pairwise_matrix_params params{ m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - if (!params.is_row_major) { - params = params.flip_x_and_y(); - } + if (!params.is_row_major) { params = params.flip_x_and_y(); } pairwise_matrix_instantiation_point(distance_op, params, stream); } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index b298391ef2..9952d6e641 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -19,7 +19,7 @@ #include // raft::void_op #include // PairwiseDistances #include // pairwise_matrix_params -#include // raft::arch::SM_compute_arch +#include // raft::util::arch::SM_compute_arch namespace raft::distance::detail { @@ -36,7 +36,7 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( { // Early exit to minimize the size of the kernel when it is not supposed to be compiled. constexpr SM_compat_t sm_compat_range{}; - if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) { + if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { assert(false); return; } diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py index 861264e3a0..f0b6d0ed5e 100644 --- a/cpp/include/raft/distance/specializations/detail/00_write_template.py +++ b/cpp/include/raft/distance/specializations/detail/00_write_template.py @@ -66,70 +66,57 @@ dict( path_prefix="canberra", OpT="ops::canberra_distance_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="correlation", OpT="ops::correlation_distance_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="cosine", OpT="ops::cosine_distance_op", # cosine uses CUTLASS for SM80+ - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="hamming_unexpanded", OpT="ops::hamming_distance_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="hellinger_expanded", OpT="ops::hellinger_distance_op", - SM_compat_t="raft::arch::SM_range", ), # inner product is handled by cublas. dict( path_prefix="jensen_shannon", OpT="ops::jensen_shannon_distance_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="kl_divergence", OpT="ops::kl_divergence_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="l1", OpT="ops::l1_distance_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="l2_expanded", OpT="ops::l2_exp_distance_op", # L2 expanded uses CUTLASS for SM80+ - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="l2_unexpanded", OpT="ops::l2_unexp_distance_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="l_inf", OpT="ops::l_inf_distance_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="lp_unexpanded", OpT="ops::lp_unexp_distance_op", - SM_compat_t="raft::arch::SM_range", ), dict( path_prefix="russel_rao", OpT="ops::russel_rao_distance_op", - SM_compat_t="raft::arch::SM_range", ), ] diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 8c48b87269..740c2ff971 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -15,25 +15,27 @@ */ #pragma once -namespace raft::arch { +#include // RAFT_CUDA_TRY -/* raft::arch provides the following facilities: +namespace raft::util::arch { + +/* raft::util::arch provides the following facilities: * - * - raft::arch::SM_XX : hardcoded compile-time constants for various compute - * architectures. The values raft::arch::SM_min and raft::arch::SM_future + * - raft::util::arch::SM_XX : hardcoded compile-time constants for various compute + * architectures. The values raft::util::arch::SM_min and raft::util::arch::SM_future * represent architectures that are always smaller and larger (respectively) * than any architecture that can be encountered in practice. * - * - raft::arch::SM_compute_arch : a compile-time value for the *current* + * - raft::util::arch::SM_compute_arch : a compile-time value for the *current* * compute architecture that a kernel is compiled with. It can only be used * inside kernels with a template argument. * - * - raft::arch::kernel_runtime_arch : a function that computes at *run-time* + * - raft::util::arch::kernel_runtime_arch : a function that computes at *run-time* * which version of a kernel will launch (i.e., it will return the compute * architecture of the version of the kernel that will be launched by the * driver). * - * - raft::arch::SM_range : a compile-time value to represent an open interval + * - raft::util::arch::SM_range : a compile-time value to represent an open interval * of compute architectures. This can be used to check if the current * compile-time architecture is in a specified compatibility range. */ @@ -119,7 +121,7 @@ struct SM_runtime { inline SM_runtime kernel_runtime_arch(void* kernel) { cudaFuncAttributes attributes; - cudaFuncGetAttributes(&attributes, kernel); + RAFT_CUDA_TRY(cudaFuncGetAttributes(&attributes, kernel)); return SM_runtime(10 * attributes.ptxVersion); } @@ -143,4 +145,4 @@ struct SM_range { } }; -} // namespace raft::arch +} // namespace raft::util::arch diff --git a/cpp/src/distance/distance/specializations/detail/00_write_template.py b/cpp/src/distance/distance/specializations/detail/00_write_template.py index 81b8731546..3f2f853569 100644 --- a/cpp/src/distance/distance/specializations/detail/00_write_template.py +++ b/cpp/src/distance/distance/specializations/detail/00_write_template.py @@ -20,10 +20,8 @@ #include // raft::identity_op #include // ops::* - #include // pairwise_matrix_instantiation_point INCLUDE_SM_HEADERS -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu index 71ce79ad28..037d218178 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu index 84c1cfe4e2..0ed8ea7bb0 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu index d684273826..0c11f0621e 100644 --- a/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu index c83bb2b204..396e158554 100644 --- a/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu index 202ee96ee5..e9afb6f563 100644 --- a/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu @@ -14,13 +14,11 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu index 6b221aa2b5..1033c491d6 100644 --- a/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu @@ -14,13 +14,11 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu index a1a3ebc601..195115914d 100644 --- a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu index 8d596db93b..a74c6c404e 100644 --- a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu index cd1b37de7e..bac1dd7bd0 100644 --- a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu index b67121f6af..77c113b1a9 100644 --- a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index 738a9406be..188e52c152 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu index 1685494010..b0afbf7bb2 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu index c3a77c7a8f..f06ae85414 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu index 75c17fdb10..00d5a5ee5b 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu index 516384c967..5c235316da 100644 --- a/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu index a3535a75a6..fb293ca83d 100644 --- a/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu index 474c031e01..2c02f0224f 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu @@ -14,13 +14,11 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu index 334a367453..85e25a25ca 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu @@ -14,13 +14,11 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu index 41a70341d0..5b4d995d14 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu index ac27e35d01..a63c3f0bb8 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu index 4e06d0264a..831167523f 100644 --- a/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu index c19a8e6016..02e667cbe3 100644 --- a/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu index c3c8d2b96f..ebd71065ec 100644 --- a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu index ec8317d9d4..b94a81fdce 100644 --- a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu index d842cebd44..6f952fcc37 100644 --- a/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { diff --git a/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu index 179599f549..3223ce33a7 100644 --- a/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu @@ -14,12 +14,10 @@ * limitations under the License. */ -#include // raft::identity_op -#include // ops::* - +#include // raft::identity_op +#include // ops::* #include // pairwise_matrix_instantiation_point #include -#include // raft::arch::SM_compat_range namespace raft::distance::detail { From 30c33912a79a0d37668bc4dc715cb230f891e25f Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 17 Mar 2023 10:03:57 +0100 Subject: [PATCH 79/93] Fix build failure The fin_op can be non-trivially copyable, causing problems. --- .../raft/distance/detail/pairwise_matrix/dispatch.cuh | 2 +- .../raft/distance/detail/pairwise_matrix/params.cuh | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 492f8600e9..8524ce6fdf 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -53,7 +53,7 @@ void pairwise_matrix_dispatch(OpT distance_op, pairwise_matrix_params params{ m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - if (!params.is_row_major) { params = params.flip_x_and_y(); } + if (!params.is_row_major) { params.flip_x_and_y(); } // On CUDA 12: // - always execute normal kernel diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh index dbc47d5aeb..5962432dfd 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh @@ -33,12 +33,14 @@ struct pairwise_matrix_params { FinOpT fin_op; bool is_row_major; - // - [[nodiscard]] pairwise_matrix_params flip_x_and_y() + /// @brief: Flips the x and y input and corrresponding sizes + void flip_x_and_y() { // Flip m, n; ldx, ldy; x, y; x_norm, y_norm. - return pairwise_matrix_params{ - n, m, k, ldy, ldx, ld_out, y, x, y_norm, x_norm, out, fin_op, is_row_major}; + std::swap(m, n); + std::swap(ldx, ldy); + std::swap(x, y); + std::swap(x_norm, y_norm); } }; From 9eaf9b5058d8e03d96d6ec4a22214ba5601213a3 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Fri, 17 Mar 2023 11:51:07 +0100 Subject: [PATCH 80/93] Fix spelling --- cpp/include/raft/distance/detail/pairwise_matrix/params.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh index 5962432dfd..005b95afe9 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh @@ -33,7 +33,7 @@ struct pairwise_matrix_params { FinOpT fin_op; bool is_row_major; - /// @brief: Flips the x and y input and corrresponding sizes + /// @brief: Flips the x and y input and corresponding sizes void flip_x_and_y() { // Flip m, n; ldx, ldy; x, y; x_norm, y_norm. From 2b3b20330109e2bf7a59baf5e08fb4e75054627e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Sat, 18 Mar 2023 12:41:40 +0100 Subject: [PATCH 81/93] Fix pairwise tune benchmark --- cpp/bench/distance/tune_pairwise/kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/bench/distance/tune_pairwise/kernel.cu b/cpp/bench/distance/tune_pairwise/kernel.cu index 7511e5dc7d..fb0e230ca5 100644 --- a/cpp/bench/distance/tune_pairwise/kernel.cu +++ b/cpp/bench/distance/tune_pairwise/kernel.cu @@ -25,7 +25,7 @@ namespace arch = raft::util::arch; constexpr int vec_len = 1; using Policy = typename raft::linalg::Policy4x4::Policy; -constexpr auto sm_compat_range = arch:: ::SM_range(arch:: ::SM_min(), arch:: ::SM_future()); +constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future()); void launch_kernel(OpT distance_op, pairwise_matrix_params params, dim3 grid, cudaStream_t stream) { From 3b686a9f48a86ec50a0e4bc20c82dae2652a1b96 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 20 Mar 2023 09:33:05 +0100 Subject: [PATCH 82/93] Fix compilation error pairwise tuning bench --- cpp/bench/distance/tune_pairwise/bench.cu | 25 ++++++++++++++--------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/cpp/bench/distance/tune_pairwise/bench.cu b/cpp/bench/distance/tune_pairwise/bench.cu index 2cc0be63ec..02e2ca9432 100644 --- a/cpp/bench/distance/tune_pairwise/bench.cu +++ b/cpp/bench/distance/tune_pairwise/bench.cu @@ -109,21 +109,25 @@ struct throughput_bench : public fixture { auto out = out_vec.data(); FinOpT fin_op{}; - auto make_params = raft::distance::detail::make_params; - pairwise_matrix_params kparams = - row_major ? make_params(m, n, k, x, y, x_norm, y_norm, out, fin_op, row_major) - : make_params(n, m, k, y, x, y_norm, x_norm, out, fin_op, row_major); + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = row_major ? k : m; + IdxT ldy = row_major ? k : n; + IdxT ld_out = row_major ? n : m; + + // Template parameters of pairwise_matrix_params are defined in kernel.cuh + pairwise_matrix_params kparams{ + IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major}; // Run benchmark loop_on_state(state, [&]() { launch_kernel(distance_op, kparams, grid, stream); }); // Report metrics. We don't report flop/s because we do not know for each // distance operation how many flops it costs. For L2_unexp and l1, we can - // double this number to get the flop/s. For l2 expanded, dist/s should + // double this number to get the flop/s. For l2 expanded, core_ops/s should // equal flop/s (modulo the sqrt and subtracting from the norm). - size_t num_dists = m * n * k; - size_t read_elts = n * k + m * k; - size_t write_elts = m * n; + size_t num_core_ops = m * n * k; + size_t read_elts = n * k + m * k; + size_t write_elts = m * n; state.counters["m"] = benchmark::Counter(m); state.counters["n"] = benchmark::Counter(n); @@ -132,8 +136,9 @@ struct throughput_bench : public fixture { state.counters["# waves"] = benchmark::Counter(p.num_waves); state.counters["# k iters"] = benchmark::Counter(p.num_k_iters); - state.counters["dist/s"] = benchmark::Counter( - num_dists, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000); + state.counters["core_ops/s"] = benchmark::Counter(num_core_ops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT), benchmark::Counter::kIsIterationInvariantRate, From e5eb7721030846753ec0c054540088751997aa77 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 20 Mar 2023 10:47:57 +0100 Subject: [PATCH 83/93] Implement reviewer feedback --- .../distance/detail/distance_ops/cosine.cuh | 7 ++-- .../distance/detail/distance_ops/l2_exp.cuh | 5 ++- .../detail/pairwise_distance_cutlass_base.cuh | 29 ++++++++-------- .../pairwise_matrix/dispatch_layout.cuh | 12 +++---- .../detail/pairwise_matrix/dispatch_sm60.cuh | 18 +++++++--- .../detail/pairwise_matrix/dispatch_sm80.cuh | 10 ++++-- .../detail/00_write_template.py | 2 +- .../specializations/detail/canberra.cuh | 34 +++++++++---------- .../specializations/detail/correlation.cuh | 34 +++++++++---------- .../specializations/detail/cosine.cuh | 21 ++++++------ .../detail/hamming_unexpanded.cuh | 34 +++++++++---------- .../detail/hellinger_expanded.cuh | 34 +++++++++---------- .../specializations/detail/jensen_shannon.cuh | 34 +++++++++---------- .../specializations/detail/kl_divergence.cuh | 8 ++--- .../distance/specializations/detail/l1.cuh | 8 ++--- .../specializations/detail/l2_expanded.cuh | 21 ++++++------ .../specializations/detail/l2_unexpanded.cuh | 34 +++++++++---------- .../distance/specializations/detail/l_inf.cuh | 21 ++++++------ .../specializations/detail/lp_unexpanded.cuh | 34 +++++++++---------- .../specializations/detail/russel_rao.cuh | 34 +++++++++---------- cpp/include/raft/util/arch.cuh | 3 -- cpp/test/distance/distance_base.cuh | 4 +-- 22 files changed, 219 insertions(+), 222 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index c103cf6121..9eb84932c5 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -26,7 +26,7 @@ struct cosine_cutlass_op { __device__ cosine_cutlass_op() noexcept {} __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { - return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); + return static_cast(1.0) - static_cast(accVal / (aNorm * bNorm)); } __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; @@ -76,7 +76,10 @@ struct cosine_distance_op { } } - cosine_cutlass_op get_cutlass_op() { return cosine_cutlass_op(); } + constexpr cosine_cutlass_op get_cutlass_op() + { + return cosine_cutlass_op(); + } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index cb7702396a..84da07a586 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -102,7 +102,10 @@ struct l2_exp_distance_op { } } - l2_exp_cutlass_op get_cutlass_op() { return l2_exp_cutlass_op(sqrt); } + constexpr l2_exp_cutlass_op get_cutlass_op() + { + return l2_exp_cutlass_op(sqrt); + } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index c5fdd28117..efcd5d9389 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -64,21 +64,20 @@ template -typename std::enable_if::value>::type cutlassDistanceKernel( - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - OpT distance_op, - cudaStream_t stream) +std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + OpT distance_op, + cudaStream_t stream) { static_assert(!(std::is_same::value), "OutType bool is not supported use uint8_t instead"); diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh index 08f5155cf4..f2b0e59822 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -100,15 +100,15 @@ auto dispatch_layout(bool row_major, int vec_len, F&& f) { if (row_major) { switch (vec_len) { - case 4: return f(std::bool_constant(), vec_len_constant<4>()); - case 2: return f(std::bool_constant(), vec_len_constant<2>()); - default: return f(std::bool_constant(), vec_len_constant<1>()); + case 4: return f(std::true_type(), vec_len_constant<4>()); + case 2: return f(std::true_type(), vec_len_constant<2>()); + default: return f(std::true_type(), vec_len_constant<1>()); } } else { switch (vec_len) { - case 4: return f(std::bool_constant(), vec_len_constant<4>()); - case 2: return f(std::bool_constant(), vec_len_constant<2>()); - default: return f(std::bool_constant(), vec_len_constant<1>()); + case 4: return f(std::false_type(), vec_len_constant<4>()); + case 2: return f(std::false_type(), vec_len_constant<2>()); + default: return f(std::false_type(), vec_len_constant<1>()); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh index cb0fd59da2..2080fbe9cd 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh @@ -35,7 +35,11 @@ pairwise_matrix_sm60_wrapper pairwise_matrix_sm6 { int vec_len = determine_vec_len(params); - return dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // f takes compile-time constants row_major and vec_len aligned and returns + // the corresponding kernel wrapper. The wrapper contains the launch + // parameters of the kernel: a pointer to the kernel function, grid size, + // block size, and shared memory size. + auto f = [&](auto row_major, auto vec_len_aligned) { // row_major and vec_len are std::integral_constants of type bool and int // respectively. @@ -46,15 +50,19 @@ pairwise_matrix_sm60_wrapper pairwise_matrix_sm6 // Prevent double, vec_len=4 combination (this is not supported) constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; + using RowPolicy = typename raft::linalg::Policy4x4::Policy; + using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; + using Policy = typename std::conditional::type; auto wrapper = make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); return wrapper; - }); + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + return dispatch_layout(params.is_row_major, vec_len, f); } template , - int, - float, - float, - decltype(raft::identity_op())>( - ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::canberra_distance_op, + int, + float, + float, + raft::identity_op>(ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::canberra_distance_op, + int, + double, + double, + raft::identity_op>(ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh index 2aec977be4..f019f678df 100644 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ b/cpp/include/raft/distance/specializations/detail/correlation.cuh @@ -20,23 +20,21 @@ namespace raft::distance::detail { -extern template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::correlation_distance_op, + int, + float, + float, + raft::identity_op>(ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::correlation_distance_op, + int, + double, + double, + raft::identity_op>(ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh index 92317f0de6..dcde4ec286 100644 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ b/cpp/include/raft/distance/specializations/detail/cosine.cuh @@ -24,18 +24,17 @@ extern template void pairwise_matrix_instantiation_point( + raft::identity_op>( ops::cosine_distance_op, - pairwise_matrix_params, + pairwise_matrix_params, cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::cosine_distance_op, + int, + double, + double, + raft::identity_op>(ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh index be06070514..1d6964fbce 100644 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh @@ -20,23 +20,21 @@ namespace raft::distance::detail { -extern template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::hamming_distance_op, + int, + float, + float, + raft::identity_op>(ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::hamming_distance_op, + int, + double, + double, + raft::identity_op>(ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh index b7d9dac1a1..f96a06f919 100644 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh @@ -20,23 +20,21 @@ namespace raft::distance::detail { -extern template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::hellinger_distance_op, + int, + float, + float, + raft::identity_op>(ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::hellinger_distance_op, + int, + double, + double, + raft::identity_op>(ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index b51cc32b62..0b58646582 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -20,23 +20,21 @@ namespace raft::distance::detail { -extern template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::jensen_shannon_distance_op, + int, + float, + float, + raft::identity_op>(ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::jensen_shannon_distance_op, + int, + double, + double, + raft::identity_op>(ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh index 5e1a125dea..5c164e0fd4 100644 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh @@ -24,17 +24,17 @@ extern template void pairwise_matrix_instantiation_point( + raft::identity_op>( ops::kl_divergence_op, - pairwise_matrix_params, + pairwise_matrix_params, cudaStream_t); extern template void pairwise_matrix_instantiation_point, int, double, double, - decltype(raft::identity_op())>( + raft::identity_op>( ops::kl_divergence_op, - pairwise_matrix_params, + pairwise_matrix_params, cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh index c44953bf02..870627d909 100644 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ b/cpp/include/raft/distance/specializations/detail/l1.cuh @@ -24,17 +24,17 @@ extern template void pairwise_matrix_instantiation_point( + raft::identity_op>( ops::l1_distance_op, - pairwise_matrix_params, + pairwise_matrix_params, cudaStream_t); extern template void pairwise_matrix_instantiation_point, int, double, double, - decltype(raft::identity_op())>( + raft::identity_op>( ops::l1_distance_op, - pairwise_matrix_params, + pairwise_matrix_params, cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh index 5e427af021..ee3207bcce 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh @@ -24,18 +24,17 @@ extern template void pairwise_matrix_instantiation_point( + raft::identity_op>( ops::l2_exp_distance_op, - pairwise_matrix_params, + pairwise_matrix_params, cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::l2_exp_distance_op, + int, + double, + double, + raft::identity_op>(ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh index 840760c4db..1fbf57632b 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh @@ -20,23 +20,21 @@ namespace raft::distance::detail { -extern template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::l2_unexp_distance_op, + int, + float, + float, + raft::identity_op>(ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::l2_unexp_distance_op, + int, + double, + double, + raft::identity_op>(ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh index b10d1b8098..388d3bf439 100644 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ b/cpp/include/raft/distance/specializations/detail/l_inf.cuh @@ -24,18 +24,17 @@ extern template void pairwise_matrix_instantiation_point( + raft::identity_op>( ops::l_inf_distance_op, - pairwise_matrix_params, + pairwise_matrix_params, cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::l_inf_distance_op, + int, + double, + double, + raft::identity_op>(ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh index e7632ead6c..d8e86ce6f2 100644 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh @@ -20,23 +20,21 @@ namespace raft::distance::detail { -extern template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::lp_unexp_distance_op, + int, + float, + float, + raft::identity_op>(ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::lp_unexp_distance_op, + int, + double, + double, + raft::identity_op>(ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh index 0c6f4c993e..4803fb8ab0 100644 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh @@ -20,23 +20,21 @@ namespace raft::distance::detail { -extern template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::russel_rao_distance_op, + int, + float, + float, + raft::identity_op>(ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); -extern template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); +extern template void pairwise_matrix_instantiation_point< + ops::russel_rao_distance_op, + int, + double, + double, + raft::identity_op>(ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); } // namespace raft::distance::detail diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 740c2ff971..dc35b10063 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -48,9 +48,6 @@ struct SM_generic { public: __host__ __device__ constexpr int value() const { return n; } }; - -// A dummy kernel that is used to determine the runtime architecture. -__global__ inline void dummy_runtime_kernel() {} } // namespace detail // A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index ae8230984a..30ec9ddfd8 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -570,8 +570,8 @@ class BigMatrixDistanceTest : public ::testing::Test { raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); - bool row_major = true; - float metric_arg = 0.0f; + constexpr bool row_major = true; + constexpr float metric_arg = 0.0f; #if defined RAFT_DISTANCE_COMPILED raft::runtime::distance::pairwise_distance( handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); From a5a362916e4831dbb120b6bb67069ffa269275b0 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 20 Mar 2023 12:11:47 +0100 Subject: [PATCH 84/93] Fix merge The merge duplicated two functions. Fixed here. --- .../detail/pairwise_matrix/kernel_sm60.cuh | 71 ------------------- 1 file changed, 71 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index a4b596bb1e..9952d6e641 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -152,75 +152,4 @@ pairwise_matrix_sm60_wrapper make_pairwise_matri grid, block, smem_size, kernel}; } -// The type of a pointer to the pairwise matrix kernel. The following template -// arguments are type-erased: -// -// - The kernel policy -// - row_major -// - SM_compat_t -template -using pairwise_matrix_kernel_t = void (*)(OpT, pairwise_matrix_params); - -// A wrapper for the pairwise matrix kernel launch. Includes kernel launch -// parameters. -template -struct pairwise_matrix_sm60_wrapper { - dim3 grid; - dim3 block; - int smem_size; - pairwise_matrix_kernel_t kernel_ptr; - - void launch(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) - { - kernel_ptr<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); - } -}; - -/** @brief: Create kernel launch wrapper for pairwise matrix kernel - * - * This can be used to type-erase the kernel execution policy, row_major, and SM - * compatibility range. - * - * @tparam Policy: Kernel execution policy - * @tparam row_major: Indicates whether input matrices are row major - * @tparam OpT: Type of distance operation - * @tparam IdxT: Index type - * @tparam DataT: Data type - * @tparam OutT: Output data type - * @tparam FinOpT: Final operation type - * @tparam SM_compat_t: Type of the SM architecture compatibility - * - * @param distance_op: Distance operation - * @param params: Parameters - * @param sm_compat_range: Which SM architectures to compile for. - */ -template -pairwise_matrix_sm60_wrapper make_pairwise_matrix_sm60_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - dim3 block(Policy::Nthreads); - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = distance_op.template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - return pairwise_matrix_sm60_wrapper{ - grid, block, smem_size, kernel}; -} - }; // namespace raft::distance::detail From c2970ba2d9913c049732164b0613cb75b469aee9 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 20 Mar 2023 17:53:41 +0100 Subject: [PATCH 85/93] tune_distance: Enable changing distance op without recompile --- cpp/bench/distance/tune_pairwise/bench.cu | 4 ++-- cpp/bench/distance/tune_pairwise/kernel.cu | 18 ++++++++++++------ cpp/bench/distance/tune_pairwise/kernel.cuh | 11 ++--------- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/cpp/bench/distance/tune_pairwise/bench.cu b/cpp/bench/distance/tune_pairwise/bench.cu index 02e2ca9432..87159ab1b1 100644 --- a/cpp/bench/distance/tune_pairwise/bench.cu +++ b/cpp/bench/distance/tune_pairwise/bench.cu @@ -85,7 +85,7 @@ struct throughput_bench : public fixture { // Determine number of blocks that will be launched. This informs the size // of the inputs as well as the grid size. const int num_sms = raft::getMultiProcessorCount(); - const int max_occupancy = get_max_occupancy(distance_op); + const int max_occupancy = get_max_occupancy(); const int occupancy = std::min(p.occupancy, max_occupancy); const int num_blocks = occupancy * num_sms; dim3 grid(num_blocks); @@ -119,7 +119,7 @@ struct throughput_bench : public fixture { IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major}; // Run benchmark - loop_on_state(state, [&]() { launch_kernel(distance_op, kparams, grid, stream); }); + loop_on_state(state, [&]() { launch_kernel(kparams, grid, stream); }); // Report metrics. We don't report flop/s because we do not know for each // distance operation how many flops it costs. For L2_unexp and l1, we can diff --git a/cpp/bench/distance/tune_pairwise/kernel.cu b/cpp/bench/distance/tune_pairwise/kernel.cu index fb0e230ca5..18efdeae60 100644 --- a/cpp/bench/distance/tune_pairwise/kernel.cu +++ b/cpp/bench/distance/tune_pairwise/kernel.cu @@ -21,13 +21,20 @@ namespace raft::bench::distance::tune { -namespace arch = raft::util::arch; +// Distance op +using OpT = raft::distance::detail::ops::lp_unexp_distance_op; +constexpr float metric_arg = 2.0; +OpT distance_op{metric_arg}; -constexpr int vec_len = 1; -using Policy = typename raft::linalg::Policy4x4::Policy; +// Kernel policy +constexpr int vec_len = 1; +using Policy = typename raft::linalg::Policy4x4::Policy; + +// Architecture +namespace arch = raft::util::arch; constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future()); -void launch_kernel(OpT distance_op, pairwise_matrix_params params, dim3 grid, cudaStream_t stream) +void launch_kernel(pairwise_matrix_params params, dim3 grid, cudaStream_t stream) { dim3 block(Policy::Nthreads); // Use .template to disambiguate (See: @@ -65,11 +72,10 @@ void* get_kernel_ptr() DataT, OutT, FinOpT>; - return reinterpret_cast(kernel); } -int get_max_occupancy(OpT distance_op) +int get_max_occupancy() { void* kernel_ptr = get_kernel_ptr(); int max_occupancy; diff --git a/cpp/bench/distance/tune_pairwise/kernel.cuh b/cpp/bench/distance/tune_pairwise/kernel.cuh index b444c5a87a..5da54a343c 100644 --- a/cpp/bench/distance/tune_pairwise/kernel.cuh +++ b/cpp/bench/distance/tune_pairwise/kernel.cuh @@ -28,24 +28,17 @@ using AccT = float; using OutT = DataT; using IdxT = int; -// Distance op -// C++17 inline variable. Used by both tuned_kernel.cu and tune_pairwise.cu -// See: https://open-std.org/JTC1/SC22/WG21/docs/papers/2016/p0386r0.pdf -using OpT = raft::distance::detail::ops::lp_unexp_distance_op; -constexpr float metric_arg = 2.0; -inline const OpT distance_op{metric_arg}; using FinOpT = raft::identity_op; using pairwise_matrix_params = raft::distance::detail::pairwise_matrix_params; // Launches kernel -void launch_kernel(OpT, pairwise_matrix_params, dim3, cudaStream_t); +void launch_kernel(pairwise_matrix_params, dim3, cudaStream_t); // Describes the block size that is decided by the policy void get_block_size(int& m, int& n, int& k); -void* get_kernel_ptr(); -int get_max_occupancy(OpT); +int get_max_occupancy(); } // namespace raft::bench::distance::tune From 6c0d944d26ebb7351fa49e9d810dd4a5d03a1c1b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 20 Mar 2023 19:09:31 +0100 Subject: [PATCH 86/93] Use std::declval --- cpp/include/raft/distance/detail/distance_ops/cutlass.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh index d3eb90467b..7a4fe0ce83 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh @@ -16,7 +16,8 @@ #pragma once -#include +#include // std::false_type +#include // std::declval namespace raft::distance::detail::ops { @@ -34,7 +35,8 @@ struct has_cutlass_op : std::false_type { // Specialization recognizes types that do support CUTLASS template -struct has_cutlass_op> : std::true_type { +struct has_cutlass_op().get_cutlass_op())>> + : std::true_type { }; } // namespace raft::distance::detail::ops From 1df17be90d41806ddb16eddcc8a363cbb3f63db0 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Tue, 21 Mar 2023 13:46:47 +0100 Subject: [PATCH 87/93] distance_ops: Use static shared_mem_size --- cpp/bench/distance/tune_pairwise/kernel.cu | 8 ++------ .../raft/distance/detail/distance_ops/canberra.cuh | 2 +- .../raft/distance/detail/distance_ops/correlation.cuh | 2 +- cpp/include/raft/distance/detail/distance_ops/cosine.cuh | 4 ++-- cpp/include/raft/distance/detail/distance_ops/hamming.cuh | 2 +- .../raft/distance/detail/distance_ops/hellinger.cuh | 2 +- .../raft/distance/detail/distance_ops/jensen_shannon.cuh | 2 +- .../raft/distance/detail/distance_ops/kl_divergence.cuh | 2 +- cpp/include/raft/distance/detail/distance_ops/l1.cuh | 2 +- cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh | 6 +++--- .../raft/distance/detail/distance_ops/l2_unexp.cuh | 2 +- cpp/include/raft/distance/detail/distance_ops/l_inf.cuh | 2 +- .../raft/distance/detail/distance_ops/lp_unexp.cuh | 2 +- .../raft/distance/detail/distance_ops/russel_rao.cuh | 2 +- .../raft/distance/detail/distance_ops/template.cuh | 8 ++++++-- .../raft/distance/detail/pairwise_matrix/kernel_sm60.cuh | 4 ++-- cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh | 8 ++++++-- 17 files changed, 32 insertions(+), 28 deletions(-) diff --git a/cpp/bench/distance/tune_pairwise/kernel.cu b/cpp/bench/distance/tune_pairwise/kernel.cu index 18efdeae60..3112e1ea9a 100644 --- a/cpp/bench/distance/tune_pairwise/kernel.cu +++ b/cpp/bench/distance/tune_pairwise/kernel.cu @@ -37,9 +37,7 @@ constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future( void launch_kernel(pairwise_matrix_params params, dim3 grid, cudaStream_t stream) { dim3 block(Policy::Nthreads); - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = distance_op.template shared_mem_size(); + int smem_size = OpT::shared_mem_size(); // Obtain function pointer to kernel auto kernel = raft::distance::detail::pairwise_matrix_kernel(); + int smem_size = OpT::shared_mem_size(); RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_occupancy, kernel_ptr, Policy::Nthreads, smem_size)); diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 0664eb98ee..eaf37b7e9c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -43,7 +43,7 @@ struct canberra_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 8cbca6ef75..4fc4bb8297 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -61,7 +61,7 @@ struct correlation_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index 9eb84932c5..0883136c9f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -53,7 +53,7 @@ struct cosine_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } @@ -76,7 +76,7 @@ struct cosine_distance_op { } } - constexpr cosine_cutlass_op get_cutlass_op() + constexpr cosine_cutlass_op get_cutlass_op() const { return cosine_cutlass_op(); } diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 2495233dee..475b8892e9 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -45,7 +45,7 @@ struct hamming_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index 0b01a0e967..0489b45854 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -42,7 +42,7 @@ struct hellinger_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index fd2e0f4a3e..e46c63734c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -45,7 +45,7 @@ struct jensen_shannon_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index 705f83ecfc..d083c5ddcc 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -50,7 +50,7 @@ struct kl_divergence_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 5330be4f0c..7e86fd3603 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -41,7 +41,7 @@ struct l1_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 84da07a586..95577fd311 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -54,7 +54,7 @@ struct l2_exp_distance_op { using AccT = AccType; using IdxT = IdxType; - bool sqrt; + const bool sqrt; l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} @@ -67,7 +67,7 @@ struct l2_exp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } @@ -102,7 +102,7 @@ struct l2_exp_distance_op { } } - constexpr l2_exp_cutlass_op get_cutlass_op() + constexpr l2_exp_cutlass_op get_cutlass_op() const { return l2_exp_cutlass_op(sqrt); } diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index f8105462a1..62c212ee8f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -46,7 +46,7 @@ struct l2_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index 108c0cd8ef..88853a3083 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -42,7 +42,7 @@ struct l_inf_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index 1c40adf905..290f4af1b4 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -46,7 +46,7 @@ struct lp_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index 745251771f..63dbf350d1 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -47,7 +47,7 @@ struct russel_rao_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index e4aa281776..4320068361 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -42,8 +42,8 @@ struct template_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template - constexpr size_t shared_mem_size() + template + static constexpr size_t shared_mem_size() { return Policy::SmemSize + TODO; } @@ -59,6 +59,10 @@ struct template_distance_op { { TODO; } + + // If exist, returns a cutlass op that performs the same operation. + // See cosine and l2_exp distance ops for an example. + constexpr l2_exp_cutlass_op get_cutlass_op() const { TODO; } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 9952d6e641..2d0a98862e 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -140,9 +140,9 @@ pairwise_matrix_sm60_wrapper make_pairwise_matri SM_compat_t sm_compat_range) { dim3 block(Policy::Nthreads); - // Use .template to disambiguate (See: + // Use ::template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = distance_op.template shared_mem_size(); + int smem_size = OpT::template shared_mem_size(); // Obtain function pointer to kernel auto kernel = pairwise_matrix_kernel; diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index c249d64af3..4a571c1447 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -226,7 +226,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { if (gridDim.x == 1) { return; } - int smem_offset = distance_op.template shared_mem_size(); + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); Pair* shDumpKV = (Pair*)(&smem[smem_offset]); const int lid = threadIdx.x % warpSize; @@ -345,7 +347,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x DataT * regyn, IdxT gridStrideX, IdxT gridStrideY) { - int smem_offset = distance_op.template shared_mem_size(); + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); Pair* shDumpKV = (Pair*)(&smem[smem_offset]); constexpr uint32_t mask = 0xffffffffu; From 6f4e77df25ad622b8a88b7a9f7d2228a9c72527c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 21 Mar 2023 16:24:11 -0400 Subject: [PATCH 88/93] Pinning dask temporarily because a recent commit broke things --- conda/environments/all_cuda-118_arch-x86_64.yaml | 4 ++-- dependencies.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 47af29d9d2..9c62593de1 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -19,8 +19,8 @@ dependencies: - cxx-compiler - cython>=0.29,<0.30 - dask-cuda=23.04 -- dask>=2023.1.1 -- distributed>=2023.1.1 +- dask<=2023.3.1 +- distributed>=2023.3.1 - doxygen>=1.8.20 - faiss-proc=*=cuda - gcc_linux-64=11.* diff --git a/dependencies.yaml b/dependencies.yaml index 93893d07af..dc726a7d2c 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -173,8 +173,8 @@ dependencies: common: - output_types: [conda] packages: - - dask>=2023.1.1 - - distributed>=2023.1.1 + - dask<=2023.3.1 + - distributed>=2023.3.1 - ucx>=1.13.0 - ucx-py=0.31.* - ucx-proc=*=gpu From 7958a3289f81e5040dd385bf0fe21ace06a2c847 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 21 Mar 2023 20:50:39 -0400 Subject: [PATCH 89/93] Updating raft-dask recipe for now. Not yet able to fix the issue w/ the suggestions from dask team --- conda/recipes/raft-dask/meta.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml index b387f0f47c..a76ff851b1 100644 --- a/conda/recipes/raft-dask/meta.yaml +++ b/conda/recipes/raft-dask/meta.yaml @@ -46,7 +46,7 @@ requirements: run: - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} - cuda-python >=11.7.1,<12.0 - - dask >=2023.1.1 + - dask <=2023.3.1 - dask-cuda ={{ minor_version }} - distributed >=2023.1.1 - joblib >=0.11 From e2f1aa26e4bfc337073958e2ccbd52d6d3b0ed9c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 22 Mar 2023 05:52:12 -0400 Subject: [PATCH 90/93] Pinning dask for wheel --- python/raft-dask/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index 2fe6522f57..e93e22014f 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "numba>=0.49", "joblib>=0.11", "dask-cuda==23.4.*", - "dask>=2023.1.1", + "dask<=2023.3.1", "ucx-py==0.31.*", "distributed>=2023.1.1", "pylibraft==23.4.*", From d57bca4b81f3bc654b6fc3bde1aadd2110aeaf6d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 22 Mar 2023 18:06:56 -0400 Subject: [PATCH 91/93] Revert "Pinning dask for wheel" This reverts commit e2f1aa26e4bfc337073958e2ccbd52d6d3b0ed9c. --- python/raft-dask/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index e93e22014f..2fe6522f57 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "numba>=0.49", "joblib>=0.11", "dask-cuda==23.4.*", - "dask<=2023.3.1", + "dask>=2023.1.1", "ucx-py==0.31.*", "distributed>=2023.1.1", "pylibraft==23.4.*", From b9383e128c25cc92539c59a01b384109e25b854f Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 22 Mar 2023 18:07:02 -0400 Subject: [PATCH 92/93] Revert "Updating raft-dask recipe for now. Not yet able to fix the issue w/ the suggestions from dask team" This reverts commit 7958a3289f81e5040dd385bf0fe21ace06a2c847. --- conda/recipes/raft-dask/meta.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml index a76ff851b1..b387f0f47c 100644 --- a/conda/recipes/raft-dask/meta.yaml +++ b/conda/recipes/raft-dask/meta.yaml @@ -46,7 +46,7 @@ requirements: run: - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} - cuda-python >=11.7.1,<12.0 - - dask <=2023.3.1 + - dask >=2023.1.1 - dask-cuda ={{ minor_version }} - distributed >=2023.1.1 - joblib >=0.11 From 4d16e5abb06cba8644b632b9600ea5a038c88444 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 22 Mar 2023 18:07:16 -0400 Subject: [PATCH 93/93] Revert "Pinning dask temporarily because a recent commit broke things" This reverts commit 6f4e77df25ad622b8a88b7a9f7d2228a9c72527c. --- conda/environments/all_cuda-118_arch-x86_64.yaml | 4 ++-- dependencies.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 01bfe9bff7..39f1fef4d5 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -19,8 +19,8 @@ dependencies: - cxx-compiler - cython>=0.29,<0.30 - dask-cuda=23.04 -- dask<=2023.3.1 -- distributed>=2023.3.1 +- dask>=2023.1.1 +- distributed>=2023.1.1 - doxygen>=1.8.20 - gcc_linux-64=11.* - graphviz diff --git a/dependencies.yaml b/dependencies.yaml index e9b817c923..9fbf26bcd1 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -173,8 +173,8 @@ dependencies: common: - output_types: [conda] packages: - - dask<=2023.3.1 - - distributed>=2023.3.1 + - dask>=2023.1.1 + - distributed>=2023.1.1 - ucx>=1.13.0 - ucx-py=0.31.* - ucx-proc=*=gpu