From 3a4ec66f46b38314fbe990f3a12f83c83e75c9b2 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 12 May 2021 22:03:33 +0530 Subject: [PATCH 01/29] Refactor fusedL2NN to use pairwiseDistance class. invert block y/x dir usage in all contraction based kernels so that n is along x dir and m is along y dir blocks --- cpp/include/raft/distance/cosine.cuh | 4 +- cpp/include/raft/distance/euclidean.cuh | 9 +- cpp/include/raft/distance/fused_l2_nn.cuh | 298 ++++-------------- cpp/include/raft/distance/l1.cuh | 4 +- .../raft/distance/pairwise_distance_base.cuh | 30 +- cpp/include/raft/linalg/contractions.cuh | 8 +- 6 files changed, 100 insertions(+), 253 deletions(-) diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/cosine.cuh index 5a212ce64c..a1793777f2 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/cosine.cuh @@ -61,8 +61,8 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(m, KPolicy::Mblk), - raft::ceildiv(n, KPolicy::Nblk)); + dim3 grid(raft::ceildiv(n, KPolicy::Nblk), + raft::ceildiv(m, KPolicy::Mblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/euclidean.cuh index f3f946ad7b..76a721c202 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/euclidean.cuh @@ -60,8 +60,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(m, KPolicy::Mblk), - raft::ceildiv(n, KPolicy::Nblk)); + dim3 grid(raft::ceildiv(n, KPolicy::Nblk), + raft::ceildiv(m, KPolicy::Mblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -229,8 +229,9 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(m, KPolicy::Mblk), - raft::ceildiv(n, KPolicy::Nblk)); + + dim3 grid(raft::ceildiv(n, KPolicy::Nblk), + raft::ceildiv(m, KPolicy::Mblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index 000d856841..2dd32d93bc 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -21,6 +21,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -68,120 +69,52 @@ struct MinReduceOp { DI void init(DataT* out, DataT maxVal) { *out = maxVal; } }; -template > -struct FusedL2NN : public BaseClass { - private: - typedef Policy P; - - const DataT* xn; - const DataT* yn; - OutT* min; - int* mutex; - - DataT *sxNorm, *syNorm; - cub::KeyValuePair* sRed; - - DataT maxVal; - - DataT acc[P::AccRowsPerTh][P::AccColsPerTh]; - - ReduceOpT redOp; - KVPReduceOpT pairRedOp; - -#if (ENABLE_MEMCPY_ASYNC == 1) - DataT zeros[P::Veclen]; - nvcuda::experimental::pipeline pipe; -#endif - - static const DataT Two = (DataT)2.0; - static constexpr size_t SizeAndAlign = P::Veclen * sizeof(DataT); - - public: - DI FusedL2NN(OutT* _min, const DataT* _x, const DataT* _y, const DataT* _xn, - const DataT* _yn, IdxT _m, IdxT _n, IdxT _k, char* _smem, - DataT _mv, int* _mut, ReduceOpT op, KVPReduceOpT pair_op) - : BaseClass(_x, _y, _m, _n, _k, _smem), - xn(_xn), - yn(_yn), - min(_min), - mutex(_mut), - sxNorm((DataT*)_smem), - syNorm(&(sxNorm[P::Mblk])), - sRed((cub::KeyValuePair*)_smem), - maxVal(_mv), - redOp(op), - pairRedOp(pair_op) { -#if (ENABLE_MEMCPY_ASYNC == 1) -#pragma unroll - for (int i = 0; i < P::Veclen; ++i) { - zeros[i] = BaseClass::Zero; - } -#endif - } - - DI void run() { - prolog(); - loop(); - __syncthreads(); // so that we can safely reuse smem - epilog(); - } - - private: - DI void prolog() { - this->ldgXY(0); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero; - } - } - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; +template +__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { + redOp.init(min + tid, maxVal); } +} - DI void loop() { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; - } - accumulate(); // last iteration - } +template +void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, + ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, + bool initOutBuffer, cudaStream_t stream) { + typedef typename linalg::Policy4x4::Policy P; + dim3 grid(raft::ceildiv(n, P::Nblk), + raft::ceildiv(m, P::Mblk)); + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef cub::KeyValuePair KVPair; + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT &acc, DataT & x, DataT & y) { + acc += x * y; + }; + + int *mutex = workspace; + // epilogue operation lambda for final value calculation + auto epilog_lambda = [sqrt, min, mutex, m, n, redOp, pairRedOp] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, DataT * regyn) { + extern __shared__ char smem[]; + KVPair *sRed = (KVPair*)smem; + + ReduceOpT red_op(redOp); + KVPReduceOpT pairRed_op(pairRedOp); - DI void epilog() { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = blockIdx.x * P::Mblk + i; - sxNorm[i] = idx < this->m ? xn[idx] : maxVal; - } - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = blockIdx.y * P::Nblk + i; - syNorm[i] = idx < this->n ? yn[idx] : maxVal; - } - __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + this->accrowid]; - } -#pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + this->acccolid]; - } #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - Two * acc[i][j]; + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; } } - if (Sqrt) { + if (sqrt) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -190,175 +123,82 @@ struct FusedL2NN : public BaseClass { } } } + // reduce - cub::KeyValuePair val[P::AccRowsPerTh]; + KVPair val[P::AccRowsPerTh]; auto lid = raft::laneId(); + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { val[i] = {-1, maxVal}; #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = this->acccolid + j * P::AccThCols + blockIdx.y * P::Nblk; - cub::KeyValuePair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < this->n) + auto tmpkey = acccolid + j * P::AccThCols + blockIdx.x * P::Nblk; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) val[i] = - pairRedOp(this->accrowid + i * P::AccThRows + blockIdx.x * P::Mblk, + pairRed_op(accrowid + i * P::AccThRows + blockIdx.y * P::Mblk, tmp, val[i]); } - __syncthreads(); #pragma unroll for (int j = P::AccThCols / 2; j > 0; j >>= 1) { auto tmpkey = raft::shfl(val[i].key, lid + j); auto tmpvalue = raft::shfl(val[i].value, lid + j); - cub::KeyValuePair tmp = {tmpkey, tmpvalue}; + KVPair tmp = {tmpkey, tmpvalue}; val[i] = - pairRedOp(this->accrowid + i * P::AccThRows + blockIdx.x * P::Mblk, + pairRed_op(accrowid + i * P::AccThRows + blockIdx.y * P::Mblk, tmp, val[i]); } } + __syncthreads(); if (lid % P::AccThCols == 0) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - sRed[i * P::AccThCols + this->accrowid] = val[i]; + sRed[i * P::AccThCols + accrowid] = val[i]; } } __syncthreads(); - updateResults(); - } - - /* - * todo: From Volta onwards see if "coalesced" atomicCAS approach as - * written below helps improve perf - * ``` - * auto tid = threadIdx.x; - * auto rid = IdxT(blockIdx.x) * P::Mblk + tid; - * if (rid < m) { - * auto val = sRed[i]; - * while (atomicCAS(mutex + rid, 0, 1) == 1) - * ; - * __threadfence(); - * redOp(rid, min + rid, val); - * __threadfence(); - * atomicCAS(mutex + rid, 1, 0); - * } - * ``` - */ - DI void updateResults() { // for now have first lane from each warp update a unique output row. This // will resolve hang issues with pre-Volta architectures auto nWarps = blockDim.x / raft::WarpSize; - auto lid = raft::laneId(); - auto ridx = IdxT(blockIdx.x) * P::Mblk; + auto ridx = IdxT(blockIdx.y) * P::Mblk; if (lid == 0) { for (int i = threadIdx.x / raft::WarpSize; i < P::Mblk; i += nWarps) { auto rid = ridx + i; - if (rid < this->m) { + if (rid < m) { auto val = sRed[i]; while (atomicCAS(mutex + rid, 0, 1) == 1) ; __threadfence(); - redOp(rid, min + rid, val); + red_op(rid, min + rid, val); __threadfence(); atomicCAS(mutex + rid, 1, 0); } } } - } - - DI void accumulate() { -#pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - acc[i][j] += this->regx[i][v] * this->regy[j][v]; - } - } - } - } - } - -#if (ENABLE_MEMCPY_ASYNC == 1) - DI void ldgXY(IdxT kidx) { - auto koffset = kidx + this->scolid; - auto offset = - this->pageWr * P::SmemPage + this->srowid * P::SmemStride + this->scolid; - auto* saddrx = this->sx + offset; - for (int i = 0; i < P::LdgPerThX; ++i) { - auto* sax = saddrx + i * P::LdgRowsX * P::SmemStride; - auto* gax = this->x + i * P::LdgRowsX * this->k + koffset; - auto inside = - koffset < this->k && (this->xrowid + i * P::LdgRowsX) < this->m; - __pipeline_memcpy_async(sax, inside ? gax : nullptr, SizeAndAlign, - inside ? 0 : SizeAndAlign); - } - auto* saddry = this->sy + offset; - for (int i = 0; i < P::LdgPerThY; ++i) { - auto* say = saddry + i * P::LdgRowsY * P::SmemStride; - auto* gay = this->y + i * P::LdgRowsY * this->k + koffset; - auto inside = - koffset < this->k && (this->yrowid + i * P::LdgRowsY) < this->n; - __pipeline_memcpy_async(say, inside ? gay : nullptr, SizeAndAlign, - inside ? 0 : SizeAndAlign); - } - pipe.commit(); - } + }; - DI void stsXY() { pipe.wait_prior<0>(); } -#endif // ENABLE_MEMCPY_ASYNC -}; // struct FusedL2NN - -template -__global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2NNkernel( - OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, - IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, - KVPReduceOpT pairRedOp) { - extern __shared__ char smem[]; - FusedL2NN obj( - min, x, y, xn, yn, m, n, k, smem, maxVal, mutex, redOp, pairRedOp); - obj.run(); -} - -template -__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { - redOp.init(min + tid, maxVal); - } -} - -template -void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, - const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, - ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, - bool initOutBuffer, cudaStream_t stream) { - typedef typename linalg::Policy4x4::Policy Policy; - dim3 grid(raft::ceildiv(m, Policy::Mblk), - raft::ceildiv(n, Policy::Nblk)); - dim3 blk(Policy::Nthreads); - auto nblks = raft::ceildiv(m, Policy::Nthreads); - auto maxVal = std::numeric_limits::max(); CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel - <<>>(min, m, maxVal, redOp); + <<>>(min, m, maxVal, redOp); CUDA_CHECK(cudaGetLastError()); } - if (sqrt) { - fusedL2NNkernel - <<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp); - } else { - fusedL2NNkernel - <<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp); - } + + IdxT lda = k, ldb = k, ldd = n; + + auto fin_op = [] __device__(DataT d_val, int g_d_idx) { + return d_val; + }; + + pairwiseDistanceMatKernel + <<>>(x, y, xn, yn, m, n, k, lda, + ldb, ldd, nullptr, core_lambda, + epilog_lambda, fin_op); + CUDA_CHECK(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/l1.cuh index ce4fbb33e3..c6232ae169 100644 --- a/cpp/include/raft/distance/l1.cuh +++ b/cpp/include/raft/distance/l1.cuh @@ -53,8 +53,8 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(m, KPolicy::Mblk), - raft::ceildiv(n, KPolicy::Nblk)); + dim3 grid(raft::ceildiv(n, KPolicy::Nblk), + raft::ceildiv(m, KPolicy::Mblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 4e1605b887..1f49a5bdf6 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -56,6 +56,7 @@ namespace distance { template > struct PairwiseDistances : public BaseClass { @@ -147,13 +148,15 @@ struct PairwiseDistances : public BaseClass { // Load x & y norms required by this threadblock in shmem buffer for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = blockIdx.x * P::Mblk + i; + auto idx = blockIdx.y * P::Mblk + i; sxNorm[i] = idx < this->m ? xn[idx] : 0; } + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = blockIdx.y * P::Nblk + i; + auto idx = blockIdx.x * P::Nblk + i; syNorm[i] = idx < this->n ? yn[idx] : 0; } + __syncthreads(); DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; @@ -171,16 +174,18 @@ struct PairwiseDistances : public BaseClass { epilog_op(acc, nullptr, nullptr); } - IdxT startx = blockIdx.x * P::Mblk + this->accrowid; - IdxT starty = blockIdx.y * P::Nblk + this->acccolid; + if (writeOut) { + IdxT starty = blockIdx.y * P::Mblk + this->accrowid; + IdxT startx = blockIdx.x * P::Nblk + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = startx + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = starty + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - dOutput[rowId * this->n + colId] = fin_op(acc[i][j], 0); + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + dOutput[rowId * this->n + colId] = fin_op(acc[i][j], 0); + } } } } @@ -219,7 +224,8 @@ struct PairwiseDistances : public BaseClass { */ template + typename EpilogueLambda, typename FinalLambda, bool isRowMajor = true, + bool writeOut = true> __global__ __launch_bounds__( Policy::Nthreads, 2) void pairwiseDistanceMatKernel(const DataT* x, const DataT* y, @@ -231,7 +237,7 @@ __global__ __launch_bounds__( extern __shared__ char smem[]; PairwiseDistances + EpilogueLambda, FinalLambda, isRowMajor, writeOut> obj(x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, epilog_op, fin_op); obj.run(); diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh index 86d608ea87..c590abb142 100644 --- a/cpp/include/raft/linalg/contractions.cuh +++ b/cpp/include/raft/linalg/contractions.cuh @@ -293,13 +293,13 @@ struct Contractions_NT { pageWr(0), pageRd(0) { if (isRowMajor) { - xrowid = IdxT(blockIdx.x) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.y) * P::Nblk + srowid; + xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; + yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; x = _x + xrowid * lda; y = _y + yrowid * ldb; } else { - xrowid = IdxT(blockIdx.x) * P::Mblk; - yrowid = IdxT(blockIdx.y) * P::Nblk; + xrowid = IdxT(blockIdx.y) * P::Mblk; + yrowid = IdxT(blockIdx.x) * P::Nblk; x = _x + xrowid + srowid * lda; y = _y + yrowid + srowid * ldb; } From 76f9a72b4be77fa2433bc0509002b2c388c422a2 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 17 May 2021 21:18:08 +0530 Subject: [PATCH 02/29] -- add grid stride support to pairwise distance based cosine, l2, l1 kernels. --add launch config generator function to launch optimal grid size kernel for these pairwise dist kernels --- cpp/include/raft/distance/cosine.cuh | 3 +- cpp/include/raft/distance/euclidean.cuh | 7 +- cpp/include/raft/distance/l1.cuh | 3 +- .../raft/distance/pairwise_distance_base.cuh | 87 ++++++++++++++++--- 4 files changed, 78 insertions(+), 22 deletions(-) diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/cosine.cuh index a1793777f2..a214e5cddf 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/cosine.cuh @@ -61,8 +61,7 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(n, KPolicy::Nblk), - raft::ceildiv(m, KPolicy::Mblk)); + dim3 grid = launchConfigGenerator(m, n); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/euclidean.cuh index 76a721c202..69c1b1015c 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/euclidean.cuh @@ -60,8 +60,7 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(n, KPolicy::Nblk), - raft::ceildiv(m, KPolicy::Mblk)); + dim3 grid = launchConfigGenerator(m, n); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -230,8 +229,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(n, KPolicy::Nblk), - raft::ceildiv(m, KPolicy::Mblk)); + dim3 grid = launchConfigGenerator(m, n); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -256,6 +254,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, }; if (isRowMajor) { + pairwiseDistanceMatKernel diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/l1.cuh index c6232ae169..b6439da5a9 100644 --- a/cpp/include/raft/distance/l1.cuh +++ b/cpp/include/raft/distance/l1.cuh @@ -53,8 +53,7 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - dim3 grid(raft::ceildiv(n, KPolicy::Nblk), - raft::ceildiv(m, KPolicy::Mblk)); + dim3 grid = launchConfigGenerator(m, n); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 1f49a5bdf6..4875dbed52 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -16,6 +16,8 @@ #pragma once #include #include +#include +#include namespace raft { namespace distance { @@ -64,14 +66,12 @@ struct PairwiseDistances : public BaseClass { typedef Policy P; const DataT* xn; const DataT* yn; - DataT* sxNorm; - DataT* syNorm; + const DataT *const yBase; OutT* dOutput; char* smem; CoreLambda core_op; EpilogueLambda epilog_op; FinalLambda fin_op; - AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; public: @@ -82,10 +82,9 @@ struct PairwiseDistances : public BaseClass { char* _smem, CoreLambda _core_op, EpilogueLambda _epilog_op, FinalLambda _fin_op) : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), - sxNorm((DataT*)_smem), - syNorm(&(sxNorm[P::Mblk])), xn(_xn), yn(_yn), + yBase(_y), dOutput(_dOutput), smem(_smem), core_op(_core_op), @@ -93,13 +92,54 @@ struct PairwiseDistances : public BaseClass { fin_op(_fin_op) {} DI void run() { - prolog(); - loop(); - epilog(); + + for(auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; + gridStrideY += P::Mblk * gridDim.y) { + for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; + gridStrideX += P::Nblk * gridDim.x) { + prolog(gridStrideX, gridStrideY); + loop(); + epilog(gridStrideX, gridStrideY); + } + } } private: - DI void prolog() { + DI void updateIndicesY() { + const auto stride = P::Nblk * gridDim.x; + if (isRowMajor) { + this->y += stride * this->ldb; + } else { + this->y += stride; + } + this->yrowid += stride; + this->pageWr = 0; + this->pageRd = 0; + } + + DI void updateIndicesXY() { + const auto stride = P::Mblk * gridDim.y; + if (isRowMajor) { + this->x += stride * this->lda; + this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; + this->y = yBase + this->yrowid * this->ldb; + } else { + this->x += stride; + this->yrowid = IdxT(blockIdx.x) * P::Nblk; + this->y = yBase + this->yrowid + this->srowid * this->ldb; + } + this->xrowid += stride; + this->pageWr = 0; + this->pageRd = 0; + } + + DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { + if (gridStrideX > blockIdx.x * P::Nblk) { + updateIndicesY(); + } else if (gridStrideY > blockIdx.y * P::Mblk) { + updateIndicesXY(); + } + this->ldgXY(0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { @@ -142,18 +182,21 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog() { + DI void epilog(IdxT gridStrideX, IdxT gridStrideY) { if (useNorms) { __syncthreads(); // so that we can safely reuse smem + DataT* sxNorm = (DataT*)smem; + DataT* syNorm = (&sxNorm[P::Mblk]); + // Load x & y norms required by this threadblock in shmem buffer for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = blockIdx.y * P::Mblk + i; + auto idx = gridStrideY + i; sxNorm[i] = idx < this->m ? xn[idx] : 0; } for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = blockIdx.x * P::Nblk + i; + auto idx = gridStrideX + i; syNorm[i] = idx < this->n ? yn[idx] : 0; } @@ -175,8 +218,9 @@ struct PairwiseDistances : public BaseClass { } if (writeOut) { - IdxT starty = blockIdx.y * P::Mblk + this->accrowid; - IdxT startx = blockIdx.x * P::Nblk + this->acccolid; + IdxT starty = gridStrideY + this->accrowid; + IdxT startx = gridStrideX + this->acccolid; + #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { auto rowId = starty + i * P::AccThRows; @@ -243,5 +287,20 @@ __global__ __launch_bounds__( obj.run(); } +template +dim3 launchConfigGenerator(IdxT m, IdxT n) { + const auto numSMs = raft::getMultiProcessorCount(); + // multiply by 2 as per launch bounds for pairwise dist kernels. + int minGridSize = numSMs * 2; + dim3 grid; + int yChunks = raft::ceildiv(m, P::Mblk); + int xChunks = raft::ceildiv(n, P::Nblk); + grid.y = yChunks > minGridSize ? minGridSize : yChunks; + grid.x = (minGridSize - grid.y) <= 0 ? 1 : xChunks; + grid.x = grid.x > (minGridSize + 1 - grid.y) ? (minGridSize + 1 - grid.y) : grid.x; + return grid; +} + + }; // namespace distance }; // namespace raft \ No newline at end of file From af890856742afc3a2172328a43adc2bb1319df72 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 19 May 2021 18:04:05 +0530 Subject: [PATCH 03/29] --Add grid stride based fusedL2NN kernel, this gives approx 1.67x speed up over previous version. -- improve logic of the grid launch config generator for x-dir blocks --- cpp/include/raft/distance/cosine.cuh | 3 +- cpp/include/raft/distance/euclidean.cuh | 7 +- cpp/include/raft/distance/fused_l2_nn.cuh | 191 +++++++++++------- cpp/include/raft/distance/l1.cuh | 3 +- .../raft/distance/pairwise_distance_base.cuh | 38 +++- 5 files changed, 158 insertions(+), 84 deletions(-) diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/cosine.cuh index a214e5cddf..6f776591b6 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/cosine.cuh @@ -72,7 +72,8 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn) { + DataT * regxn, DataT * regyn, + IdxT gridStrideX, IdxT gridStrideY) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/euclidean.cuh index 69c1b1015c..3f2aabb0fd 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/euclidean.cuh @@ -71,7 +71,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn) { + DataT * regxn, DataT * regyn, + IdxT gridStrideX, IdxT gridStrideY) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll @@ -241,7 +242,8 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn) { + DataT * regxn, DataT * regyn, + IdxT gridStrideX, IdxT gridStrideY) { if (sqrt) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { @@ -254,7 +256,6 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, }; if (isRowMajor) { - pairwiseDistanceMatKernel diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index 2dd32d93bc..2d91268d52 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -77,34 +77,64 @@ __global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { } } -template -void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, - const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, - ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, - bool initOutBuffer, cudaStream_t stream) { - typedef typename linalg::Policy4x4::Policy P; - dim3 grid(raft::ceildiv(n, P::Nblk), - raft::ceildiv(m, P::Mblk)); - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef cub::KeyValuePair KVPair; +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal(int *mutex, OutT *min, KVPair *val, + ReduceOpT red_op, IdxT m, IdxT gridStrideY) { + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT &acc, DataT & x, DataT & y) { - acc += x * y; - }; +#pragma unroll + for (int j = 0; j < (raft::WarpSize/P::AccThCols); j++) { + if (lid == 0) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + j + i * P::AccThRows ; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + if (j < (raft::WarpSize/P::AccThCols) - 1) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto tmpkey = raft::shfl(val[i].key, (j + 1) * P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols); + val[i] = {tmpkey, tmpvalue}; + } + } + } +} + +template +__global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( + OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, + IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, + KVPReduceOpT pairRedOp, CoreLambda core_op, FinalLambda fin_op) { + extern __shared__ char smem[]; + + typedef cub::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } - int *mutex = workspace; // epilogue operation lambda for final value calculation - auto epilog_lambda = [sqrt, min, mutex, m, n, redOp, pairRedOp] __device__( + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, DataT * regyn) { - extern __shared__ char smem[]; - KVPair *sRed = (KVPair*)smem; - - ReduceOpT red_op(redOp); + DataT * regxn, DataT * regyn, + IdxT gridStrideX, IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); #pragma unroll @@ -114,7 +144,7 @@ void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; } } - if (sqrt) { + if (Sqrt) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -124,61 +154,81 @@ void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, } } - // reduce - KVPair val[P::AccRowsPerTh]; - auto lid = raft::laneId(); + // intra thread reduce const auto acccolid = threadIdx.x % P::AccThCols; const auto accrowid = threadIdx.x / P::AccThCols; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + blockIdx.x * P::Nblk; + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) - val[i] = - pairRed_op(accrowid + i * P::AccThRows + blockIdx.y * P::Mblk, - tmp, val[i]); + if (tmpkey < n){ + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } } + } + }; + + auto rowEpilog_lambda = [m, mutex, min, pairRedOp, redOp, &val, maxVal] + __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = P::AccThCols / 2; j > 0; j >>= 1) { auto tmpkey = raft::shfl(val[i].key, lid + j); auto tmpvalue = raft::shfl(val[i].value, lid + j); KVPair tmp = {tmpkey, tmpvalue}; - val[i] = - pairRed_op(accrowid + i * P::AccThRows + blockIdx.y * P::Mblk, - tmp, val[i]); + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); } } - __syncthreads(); - if (lid % P::AccThCols == 0) { + + updateReducedVal(mutex, min, val, + red_op, m, gridStrideY); + + // reset the val array. #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - sRed[i * P::AccThCols + accrowid] = val[i]; - } - } - __syncthreads(); - // for now have first lane from each warp update a unique output row. This - // will resolve hang issues with pre-Volta architectures - auto nWarps = blockDim.x / raft::WarpSize; - auto ridx = IdxT(blockIdx.y) * P::Mblk; - if (lid == 0) { - for (int i = threadIdx.x / raft::WarpSize; i < P::Mblk; i += nWarps) { - auto rid = ridx + i; - if (rid < m) { - auto val = sRed[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, val); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; } }; + IdxT lda = k, ldb = k, ldd = n; + PairwiseDistances + obj(x, y, m, n, k, lda, ldb, ldd, xn, yn, nullptr, smem, core_op, + epilog_lambda, fin_op, rowEpilog_lambda); + obj.run(); +} + +template +void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, + ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, + bool initOutBuffer, cudaStream_t stream) { + typedef typename linalg::Policy4x4::Policy P; + + dim3 grid = launchConfigGenerator(m, n); + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef cub::KeyValuePair KVPair; + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT &acc, DataT & x, DataT & y) { + acc += x * y; + }; + CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel @@ -186,18 +236,21 @@ void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, CUDA_CHECK(cudaGetLastError()); } - IdxT lda = k, ldb = k, ldd = n; - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, nullptr, core_lambda, - epilog_lambda, fin_op); + if (sqrt) { + fusedL2NNkernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, + core_lambda, fin_op); + } else { + fusedL2NNkernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, + core_lambda, fin_op); + } CUDA_CHECK(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/l1.cuh index b6439da5a9..e25eba2345 100644 --- a/cpp/include/raft/distance/l1.cuh +++ b/cpp/include/raft/distance/l1.cuh @@ -65,7 +65,8 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn) { return; }; + DataT * regxn, DataT * regyn, + IdxT gridStrideX, IdxT gridStrideY) { return; }; if (isRowMajor) { pairwiseDistanceMatKernel> @@ -72,6 +74,8 @@ struct PairwiseDistances : public BaseClass { CoreLambda core_op; EpilogueLambda epilog_op; FinalLambda fin_op; + rowEpilogueLambda rowEpilog_op; + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; public: @@ -80,7 +84,8 @@ struct PairwiseDistances : public BaseClass { IdxT _k, IdxT _lda, IdxT _ldb, IdxT _ldd, const DataT* _xn, const DataT* _yn, OutT* _dOutput, char* _smem, CoreLambda _core_op, - EpilogueLambda _epilog_op, FinalLambda _fin_op) + EpilogueLambda _epilog_op, FinalLambda _fin_op, + rowEpilogueLambda _rowEpilog_op) : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), xn(_xn), yn(_yn), @@ -89,7 +94,8 @@ struct PairwiseDistances : public BaseClass { smem(_smem), core_op(_core_op), epilog_op(_epilog_op), - fin_op(_fin_op) {} + fin_op(_fin_op), + rowEpilog_op(_rowEpilog_op) {} DI void run() { @@ -101,6 +107,7 @@ struct PairwiseDistances : public BaseClass { loop(); epilog(gridStrideX, gridStrideY); } + rowEpilog_op(gridStrideY); } } @@ -212,9 +219,9 @@ struct PairwiseDistances : public BaseClass { regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } - epilog_op(acc, regxn, regyn); + epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); } else { - epilog_op(acc, nullptr, nullptr); + epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); } if (writeOut) { @@ -266,10 +273,12 @@ struct PairwiseDistances : public BaseClass { * @param epilog_op the epilogue lambda * @param fin_op the final gemm epilogue lambda */ + template + typename EpilogueLambda, + typename FinalLambda, + bool isRowMajor = true, bool writeOut = true> __global__ __launch_bounds__( Policy::Nthreads, 2) void pairwiseDistanceMatKernel(const DataT* x, const DataT* y, @@ -279,11 +288,13 @@ __global__ __launch_bounds__( EpilogueLambda epilog_op, FinalLambda fin_op) { extern __shared__ char smem[]; + auto rowEpilog = [] __device__ (IdxT starty) { return; }; PairwiseDistances + EpilogueLambda, FinalLambda, decltype(rowEpilog), + isRowMajor, writeOut> obj(x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, - epilog_op, fin_op); + epilog_op, fin_op, rowEpilog); obj.run(); } @@ -297,7 +308,14 @@ dim3 launchConfigGenerator(IdxT m, IdxT n) { int xChunks = raft::ceildiv(n, P::Nblk); grid.y = yChunks > minGridSize ? minGridSize : yChunks; grid.x = (minGridSize - grid.y) <= 0 ? 1 : xChunks; - grid.x = grid.x > (minGridSize + 1 - grid.y) ? (minGridSize + 1 - grid.y) : grid.x; + if (grid.x != 1) { + int i = 1; + while(grid.y * i < minGridSize) { + i++; + } + grid.x = i >= xChunks ? xChunks : i; + } + return grid; } From 9c71c4ac753bb4852fb1f21500aba9bfac931404 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 19 May 2021 19:29:37 +0530 Subject: [PATCH 04/29] Add note on reason to use thread 0 from each warp to write final reduced val for pre-volta arch --- cpp/include/raft/distance/fused_l2_nn.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index 2d91268d52..8a3ac19a0e 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -86,6 +86,8 @@ DI void updateReducedVal(int *mutex, OutT *min, KVPair *val, const auto lid = threadIdx.x % raft::WarpSize; const auto accrowid = threadIdx.x / P::AccThCols; + // for now have first lane from each warp update a unique output row. This + // will resolve hang issues with pre-Volta architectures #pragma unroll for (int j = 0; j < (raft::WarpSize/P::AccThCols); j++) { if (lid == 0) { From 4d76b57676c2de8098065d9e98b1ff347e67c7dc Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 19 May 2021 20:41:15 +0530 Subject: [PATCH 05/29] fix clangformat and copyright year --- cpp/include/raft/distance/cosine.cuh | 4 +- cpp/include/raft/distance/euclidean.cuh | 8 +-- cpp/include/raft/distance/fused_l2_nn.cuh | 67 ++++++++++--------- cpp/include/raft/distance/l1.cuh | 4 +- .../raft/distance/pairwise_distance_base.cuh | 25 +++---- 5 files changed, 53 insertions(+), 55 deletions(-) diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/cosine.cuh index 6f776591b6..cc5bffd1f4 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/cosine.cuh @@ -72,8 +72,8 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn, - IdxT gridStrideX, IdxT gridStrideY) { + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/euclidean.cuh index 3f2aabb0fd..bdf2d262d5 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/euclidean.cuh @@ -71,8 +71,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn, - IdxT gridStrideX, IdxT gridStrideY) { + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll @@ -242,8 +242,8 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn, - IdxT gridStrideX, IdxT gridStrideY) { + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { if (sqrt) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index 8a3ac19a0e..da40f896e0 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include namespace raft { namespace distance { @@ -80,20 +80,20 @@ __global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { // TODO: specialize this function for MinAndDistanceReduceOp // with atomicCAS of 64 bit which will eliminate mutex and shfls template -DI void updateReducedVal(int *mutex, OutT *min, KVPair *val, - ReduceOpT red_op, IdxT m, IdxT gridStrideY) { + typename ReduceOpT> +DI void updateReducedVal(int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, + IdxT m, IdxT gridStrideY) { const auto lid = threadIdx.x % raft::WarpSize; const auto accrowid = threadIdx.x / P::AccThCols; // for now have first lane from each warp update a unique output row. This // will resolve hang issues with pre-Volta architectures #pragma unroll - for (int j = 0; j < (raft::WarpSize/P::AccThCols); j++) { + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { if (lid == 0) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + j + i * P::AccThRows ; + auto rid = gridStrideY + accrowid + j + i * P::AccThRows; if (rid < m) { auto value = val[i]; while (atomicCAS(mutex + rid, 0, 1) == 1) @@ -105,7 +105,7 @@ DI void updateReducedVal(int *mutex, OutT *min, KVPair *val, } } } - if (j < (raft::WarpSize/P::AccThCols) - 1) { + if (j < (raft::WarpSize / P::AccThCols) - 1) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { auto tmpkey = raft::shfl(val[i].key, (j + 1) * P::AccThCols); @@ -116,9 +116,9 @@ DI void updateReducedVal(int *mutex, OutT *min, KVPair *val, } } -template +template __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, @@ -135,8 +135,8 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( // epilogue operation lambda for final value calculation auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, DataT * regyn, - IdxT gridStrideX, IdxT gridStrideY) { + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); #pragma unroll @@ -163,17 +163,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n){ - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + if (tmpkey < n) { + val[i] = + pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); } } } }; - auto rowEpilog_lambda = [m, mutex, min, pairRedOp, redOp, &val, maxVal] - __device__(IdxT gridStrideY) { + auto rowEpilog_lambda = [m, mutex, min, pairRedOp, redOp, &val, + maxVal] __device__(IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); ReduceOpT red_op(redOp); @@ -188,12 +189,13 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( auto tmpkey = raft::shfl(val[i].key, lid + j); auto tmpvalue = raft::shfl(val[i].value, lid + j); KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + val[i] = + pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); } } - updateReducedVal(mutex, min, val, - red_op, m, gridStrideY); + updateReducedVal(mutex, min, val, red_op, + m, gridStrideY); // reset the val array. #pragma unroll @@ -204,9 +206,8 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( IdxT lda = k, ldb = k, ldd = n; PairwiseDistances + decltype(epilog_lambda), FinalLambda, + decltype(rowEpilog_lambda), true, false> obj(x, y, m, n, k, lda, ldb, ldd, xn, yn, nullptr, smem, core_op, epilog_lambda, fin_op, rowEpilog_lambda); obj.run(); @@ -227,7 +228,7 @@ void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, typedef cub::KeyValuePair KVPair; // Accumulation operation lambda - auto core_lambda = [] __device__(DataT &acc, DataT & x, DataT & y) { + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; @@ -238,20 +239,20 @@ void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, CUDA_CHECK(cudaGetLastError()); } - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { - return d_val; - }; + auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; if (sqrt) { fusedL2NNkernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, - core_lambda, fin_op); + decltype(core_lambda), decltype(fin_op)> + <<>>(min, x, y, xn, yn, m, n, k, maxVal, + workspace, redOp, pairRedOp, + core_lambda, fin_op); } else { fusedL2NNkernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, - core_lambda, fin_op); + decltype(core_lambda), decltype(fin_op)> + <<>>(min, x, y, xn, yn, m, n, k, maxVal, + workspace, redOp, pairRedOp, + core_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/l1.cuh index e25eba2345..734a17606d 100644 --- a/cpp/include/raft/distance/l1.cuh +++ b/cpp/include/raft/distance/l1.cuh @@ -65,8 +65,8 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn, - IdxT gridStrideX, IdxT gridStrideY) { return; }; + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { return; }; if (isRowMajor) { pairwiseDistanceMatKernel -#include #include #include +#include +#include namespace raft { namespace distance { @@ -68,7 +68,7 @@ struct PairwiseDistances : public BaseClass { typedef Policy P; const DataT* xn; const DataT* yn; - const DataT *const yBase; + const DataT* const yBase; OutT* dOutput; char* smem; CoreLambda core_op; @@ -98,11 +98,10 @@ struct PairwiseDistances : public BaseClass { rowEpilog_op(_rowEpilog_op) {} DI void run() { - - for(auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; - gridStrideY += P::Mblk * gridDim.y) { + for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; + gridStrideY += P::Mblk * gridDim.y) { for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { + gridStrideX += P::Nblk * gridDim.x) { prolog(gridStrideX, gridStrideY); loop(); epilog(gridStrideX, gridStrideY); @@ -276,9 +275,8 @@ struct PairwiseDistances : public BaseClass { template + typename EpilogueLambda, typename FinalLambda, bool isRowMajor = true, + bool writeOut = true> __global__ __launch_bounds__( Policy::Nthreads, 2) void pairwiseDistanceMatKernel(const DataT* x, const DataT* y, @@ -288,7 +286,7 @@ __global__ __launch_bounds__( EpilogueLambda epilog_op, FinalLambda fin_op) { extern __shared__ char smem[]; - auto rowEpilog = [] __device__ (IdxT starty) { return; }; + auto rowEpilog = [] __device__(IdxT starty) { return; }; PairwiseDistances +template dim3 launchConfigGenerator(IdxT m, IdxT n) { const auto numSMs = raft::getMultiProcessorCount(); // multiply by 2 as per launch bounds for pairwise dist kernels. @@ -310,7 +308,7 @@ dim3 launchConfigGenerator(IdxT m, IdxT n) { grid.x = (minGridSize - grid.y) <= 0 ? 1 : xChunks; if (grid.x != 1) { int i = 1; - while(grid.y * i < minGridSize) { + while (grid.y * i < minGridSize) { i++; } grid.x = i >= xChunks ? xChunks : i; @@ -319,6 +317,5 @@ dim3 launchConfigGenerator(IdxT m, IdxT n) { return grid; } - }; // namespace distance }; // namespace raft \ No newline at end of file From 4ada29e3144f9bab9f67710ab0b23c2fc1c62f21 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 20 May 2021 17:07:48 +0530 Subject: [PATCH 06/29] --Add additional Mblk + Nblk shmem for storing norms, and reuse xNorm for subsequent gridStrideX variations. this overall improves perf of fusedL2NN to 1.85x over previous version. --Also remove checking keys only check values in fusedL2nn test case, as it may happen a row has multiple keys with same min val --- cpp/include/raft/distance/cosine.cuh | 14 ++++++++------ cpp/include/raft/distance/euclidean.cuh | 14 ++++++++------ cpp/include/raft/distance/fused_l2_nn.cuh | 13 +++++++------ .../raft/distance/pairwise_distance_base.cuh | 12 ++++++------ cpp/test/distance/fused_l2_nn.cu | 2 -- 5 files changed, 29 insertions(+), 26 deletions(-) diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/cosine.cuh index cc5bffd1f4..94d8fa07b5 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/cosine.cuh @@ -83,20 +83,22 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, } }; + size_t shmemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + <<>>(x, y, xn, yn, m, n, k, lda, ldb, ldd, + dOutput, core_lambda, epilog_lambda, + fin_op); } else { pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + <<>>(x, y, xn, yn, m, n, k, lda, ldb, ldd, + dOutput, core_lambda, epilog_lambda, + fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/euclidean.cuh index bdf2d262d5..70526a8797 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/euclidean.cuh @@ -91,20 +91,22 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, } }; + size_t shmemSize = + KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + <<>>(x, y, xn, yn, m, n, k, lda, ldb, ldd, + dOutput, core_lambda, epilog_lambda, + fin_op); } else { pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, - ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + <<>>(x, y, xn, yn, m, n, k, lda, ldb, ldd, + dOutput, core_lambda, epilog_lambda, + fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index da40f896e0..d85540497b 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -241,18 +241,19 @@ void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; + size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { fusedL2NNkernel - <<>>(min, x, y, xn, yn, m, n, k, maxVal, - workspace, redOp, pairRedOp, - core_lambda, fin_op); + <<>>(min, x, y, xn, yn, m, n, k, maxVal, + workspace, redOp, pairRedOp, + core_lambda, fin_op); } else { fusedL2NNkernel - <<>>(min, x, y, xn, yn, m, n, k, maxVal, - workspace, redOp, pairRedOp, - core_lambda, fin_op); + <<>>(min, x, y, xn, yn, m, n, k, maxVal, + workspace, redOp, pairRedOp, + core_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 605fb9e428..3c6431a999 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -190,15 +190,15 @@ struct PairwiseDistances : public BaseClass { DI void epilog(IdxT gridStrideX, IdxT gridStrideY) { if (useNorms) { - __syncthreads(); // so that we can safely reuse smem - - DataT* sxNorm = (DataT*)smem; + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); DataT* syNorm = (&sxNorm[P::Mblk]); // Load x & y norms required by this threadblock in shmem buffer - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; + if (gridStrideX == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = gridStrideY + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; + } } for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index d4e39a0b5e..4573a070b6 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -164,7 +164,6 @@ struct CompareApproxAbsKVP { typedef typename cub::KeyValuePair KVP; CompareApproxAbsKVP(T eps_) : eps(eps_) {} bool operator()(const KVP &a, const KVP &b) const { - if (a.key != b.key) return false; T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); T m = std::max(raft::abs(a.value), raft::abs(b.value)); T ratio = m >= eps ? diff / m : diff; @@ -179,7 +178,6 @@ template struct CompareExactKVP { typedef typename cub::KeyValuePair KVP; bool operator()(const KVP &a, const KVP &b) const { - if (a.key != b.key) return false; if (a.value != b.value) return false; return true; } From 2e804c2c099eb63517e2c506b997301d6f9d5280 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 24 May 2021 18:57:05 +0530 Subject: [PATCH 07/29] Use cudaOccupancyMaxActiveBlocksPerSM instead of hard-coded launch bound in launchConfigGenerator. --Use constexpr in shmemSize. --- cpp/include/raft/distance/cosine.cuh | 31 +++++---- cpp/include/raft/distance/euclidean.cuh | 68 +++++++++++-------- cpp/include/raft/distance/fused_l2_nn.cuh | 29 ++++---- cpp/include/raft/distance/l1.cuh | 32 +++++---- .../raft/distance/pairwise_distance_base.cuh | 11 +-- 5 files changed, 101 insertions(+), 70 deletions(-) diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/cosine.cuh index 94d8fa07b5..ed9bd28b7f 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/cosine.cuh @@ -61,7 +61,6 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; - dim3 grid = launchConfigGenerator(m, n); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -83,22 +82,26 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, } }; - size_t shmemSize = + constexpr size_t shmemSize = KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, ldb, ldd, - dOutput, core_lambda, epilog_lambda, - fin_op); + auto cosineRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); + cosineRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } else { - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, ldb, ldd, - dOutput, core_lambda, epilog_lambda, - fin_op); + auto cosineColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); + cosineColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/euclidean.cuh index 70526a8797..484da0e5bf 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/euclidean.cuh @@ -60,7 +60,6 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; - dim3 grid = launchConfigGenerator(m, n); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -91,22 +90,29 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, } }; - size_t shmemSize = + constexpr size_t shmemSize = KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, ldb, ldd, - dOutput, core_lambda, epilog_lambda, - fin_op); + auto euclideanExpRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); + + euclideanExpRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } else { - pairwiseDistanceMatKernel - <<>>(x, y, xn, yn, m, n, k, lda, ldb, ldd, - dOutput, core_lambda, epilog_lambda, - fin_op); + auto euclideanExpColMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); + euclideanExpColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } CUDA_CHECK(cudaGetLastError()); @@ -232,7 +238,6 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - dim3 grid = launchConfigGenerator(m, n); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -258,19 +263,28 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, }; if (isRowMajor) { - pairwiseDistanceMatKernel - <<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto euclideanUnExpRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + euclideanUnExpRowMajor); + + euclideanUnExpRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { - pairwiseDistanceMatKernel - <<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto euclideanUnExpColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + euclideanUnExpColMajor); + + euclideanUnExpColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index d85540497b..b96a536e38 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -221,7 +221,6 @@ void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, bool initOutBuffer, cudaStream_t stream) { typedef typename linalg::Policy4x4::Policy P; - dim3 grid = launchConfigGenerator(m, n); dim3 blk(P::Nthreads); auto nblks = raft::ceildiv(m, P::Nthreads); constexpr auto maxVal = std::numeric_limits::max(); @@ -241,19 +240,25 @@ void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + constexpr size_t shmemSize = + P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { - fusedL2NNkernel - <<>>(min, x, y, xn, yn, m, n, k, maxVal, - workspace, redOp, pairRedOp, - core_lambda, fin_op); + auto fusedL2NNSqrt = + fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); + + fusedL2NNSqrt<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, + core_lambda, fin_op); } else { - fusedL2NNkernel - <<>>(min, x, y, xn, yn, m, n, k, maxVal, - workspace, redOp, pairRedOp, - core_lambda, fin_op); + auto fusedL2NN = + fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); + fusedL2NN<<>>(min, x, y, xn, yn, m, n, k, + maxVal, workspace, redOp, + pairRedOp, core_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/l1.cuh index 734a17606d..6ab084f041 100644 --- a/cpp/include/raft/distance/l1.cuh +++ b/cpp/include/raft/distance/l1.cuh @@ -53,7 +53,6 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - dim3 grid = launchConfigGenerator(m, n); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -69,19 +68,26 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, IdxT gridStrideY) { return; }; if (isRowMajor) { - pairwiseDistanceMatKernel - <<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto l1RowMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, l1RowMajor); + + l1RowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } else { - pairwiseDistanceMatKernel - <<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + auto l1ColMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, l1ColMajor); + l1ColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 3c6431a999..aa8c93a0b9 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -296,12 +296,15 @@ __global__ __launch_bounds__( obj.run(); } -template -dim3 launchConfigGenerator(IdxT m, IdxT n) { +template +dim3 launchConfigGenerator(IdxT m, IdxT n, size_t sMemSize, T func) { const auto numSMs = raft::getMultiProcessorCount(); - // multiply by 2 as per launch bounds for pairwise dist kernels. - int minGridSize = numSMs * 2; + int numBlocksPerSm = 0; dim3 grid; + + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &numBlocksPerSm, func, P::Nthreads, sMemSize)); + int minGridSize = numSMs * numBlocksPerSm; int yChunks = raft::ceildiv(m, P::Mblk); int xChunks = raft::ceildiv(n, P::Nblk); grid.y = yChunks > minGridSize ? minGridSize : yChunks; From 69b316df52d337d6bd2480ba7544e298a5a22565 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 1 Jun 2021 14:27:23 +0530 Subject: [PATCH 08/29] initialize regx and regy during each prolog call --- cpp/include/raft/distance/pairwise_distance_base.cuh | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index aa8c93a0b9..25d499fa88 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -18,6 +18,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -147,13 +148,22 @@ struct PairwiseDistances : public BaseClass { } this->ldgXY(0); + typedef TxN_t VecType; + VecType zeros; + zeros.fill(BaseClass::Zero); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { + zeros.store(&(this->regx[i][0]), 0); #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { acc[i][j] = BaseClass::Zero; } } +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + zeros.store(&(this->regy[j][0]), 0); + } + this->stsXY(); __syncthreads(); this->pageWr ^= 1; From 6a64b7a10a5af5d4819d7f399a7d95df58141120 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 1 Jun 2021 19:23:34 +0530 Subject: [PATCH 09/29] Add chebyshev distance metric support --- cpp/include/raft/distance/chebyshev.cuh | 156 ++++++++++++++++++++++++ cpp/include/raft/distance/distance.cuh | 18 +++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_chebyshev.cu | 68 +++++++++++ cpp/test/distance/distance_base.cuh | 19 ++- 5 files changed, 256 insertions(+), 6 deletions(-) create mode 100644 cpp/include/raft/distance/chebyshev.cuh create mode 100644 cpp/test/distance/dist_chebyshev.cu diff --git a/cpp/include/raft/distance/chebyshev.cuh b/cpp/include/raft/distance/chebyshev.cuh new file mode 100644 index 0000000000..fe50501708 --- /dev/null +++ b/cpp/include/raft/distance/chebyshev.cuh @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft { +namespace distance { + +/** + * @brief the Chebyshev distance matrix calculation implementer + * It computes the following equation: cij = max(cij, op(ai-bj)) + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major, + false for column major + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] pD output matrix + * @param fin_op the final gemm epilogue lambda + */ +template +static void chebyshevImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + const auto diff = raft::L1Op()(x - y); + acc = raft::myMax(acc, diff); + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { return; }; + + if (isRowMajor) { + auto chebyshevRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + chebyshevRowMajor); + + chebyshevRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { + auto chebyshevColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + chebyshevColMajor); + chebyshevColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void chebyshev(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + chebyshevImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + chebyshevImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else { + chebyshevImpl( + x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + } +} + +/** + * @brief the chebyshev distance matrix calculation + * It computes the following equation: cij = max(cij, op(ai-bj)) + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * @param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output matrix + * @param fin_op the final element-wise epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + */ +template +void chebyshevImpl(int m, int n, int k, const InType *pA, const InType *pB, + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + chebyshevOutType; + Index_ lda, ldb, ldd; + chebyshevOutType *pDcast = reinterpret_cast(pD); + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + chebyshev( + m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + } else { + lda = n, ldb = m, ldd = m; + chebyshev( + n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + } +} +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 579b3bb446..0db141798d 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -116,6 +117,19 @@ struct DistanceImpl +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor) { + raft::distance::chebyshevImpl(m, n, k, x, y, dist, fin_op, stream, + isRowMajor); + } +}; + } // anonymous namespace /** @@ -288,6 +302,10 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, raft::distance::DistanceType::L2SqrtUnexpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; + case raft::distance::DistanceType::Linf: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6496ac26c6..0044431c09 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -19,6 +19,7 @@ add_executable(test_raft test/cudart_utils.cpp test/cluster_solvers.cu test/distance/dist_adj.cu + test/distance/dist_chebyshev.cu test/distance/dist_cos.cu test/distance/dist_euc_exp.cu test/distance/dist_euc_unexp.cu diff --git a/cpp/test/distance/dist_chebyshev.cu b/cpp/test/distance/dist_chebyshev.cu new file mode 100644 index 0000000000..6a2b02863a --- /dev/null +++ b/cpp/test/distance/dist_chebyshev.cu @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceLinf + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceLinf DistanceLinfF; +TEST_P(DistanceLinfF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceLinf DistanceLinfD; +TEST_P(DistanceLinfD, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index d6f06c186a..7ac7be10d9 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -47,9 +47,11 @@ __global__ void naiveDistanceKernel(DataType *dist, const DataType *x, } template -__global__ void naiveL1DistanceKernel(DataType *dist, const DataType *x, - const DataType *y, int m, int n, int k, - bool isRowMajor) { +__global__ void naiveL1_LinfDistanceKernel(DataType *dist, const DataType *x, + const DataType *y, int m, int n, + int k, + raft::distance::DistanceType type, + bool isRowMajor) { int midx = threadIdx.x + blockIdx.x * blockDim.x; int nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) { @@ -63,7 +65,11 @@ __global__ void naiveL1DistanceKernel(DataType *dist, const DataType *x, auto a = x[xidx]; auto b = y[yidx]; auto diff = (a > b) ? (a - b) : (b - a); - acc += diff; + if (type == raft::distance::DistanceType::Linf) { + acc = raft::myMax(acc, diff); + } else { + acc += diff; + } } int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; @@ -109,9 +115,10 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); switch (type) { + case raft::distance::DistanceType::Linf: case raft::distance::DistanceType::L1: - naiveL1DistanceKernel - <<>>(dist, x, y, m, n, k, isRowMajor); + naiveL1_LinfDistanceKernel + <<>>(dist, x, y, m, n, k, type, isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: case raft::distance::DistanceType::L2Unexpanded: From 9a30a872b63a807faafce2acf0faca123c2ebbbc Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 1 Jun 2021 21:25:21 +0530 Subject: [PATCH 10/29] initialize ldgX, ldgY in prolog --- .../raft/distance/pairwise_distance_base.cuh | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 25d499fa88..503397bac9 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -147,13 +147,23 @@ struct PairwiseDistances : public BaseClass { updateIndicesXY(); } - this->ldgXY(0); typedef TxN_t VecType; VecType zeros; zeros.fill(BaseClass::Zero); +#pragma unroll + for (int j = 0; j < P::LdgPerThX; ++j) { + zeros.store(&this->ldgDataX[j][0], 0); + } +#pragma unroll + for (int j = 0; j < P::LdgPerThY; ++j) { + zeros.store(&this->ldgDataY[j][0], 0); + } + + this->ldgXY(0); + #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - zeros.store(&(this->regx[i][0]), 0); + zeros.store(&this->regx[i][0], 0); #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { acc[i][j] = BaseClass::Zero; @@ -161,7 +171,7 @@ struct PairwiseDistances : public BaseClass { } #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - zeros.store(&(this->regy[j][0]), 0); + zeros.store(&this->regy[j][0], 0); } this->stsXY(); From 21577a4a1d1156f8ba3ddb44aa6f3f1694f5a41c Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 2 Jun 2021 18:36:14 +0530 Subject: [PATCH 11/29] Add hellinger distance metric support --- cpp/include/raft/distance/distance.cuh | 19 +++ cpp/include/raft/distance/hellinger.cuh | 175 ++++++++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_hellinger.cu | 69 ++++++++++ cpp/test/distance/distance_base.cuh | 35 +++++ 5 files changed, 299 insertions(+) create mode 100644 cpp/include/raft/distance/hellinger.cuh create mode 100644 cpp/test/distance/dist_hellinger.cu diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 0db141798d..9c302811d9 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -130,6 +131,19 @@ struct DistanceImpl +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor) { + raft::distance::hellingerImpl(m, n, k, x, y, dist, fin_op, stream, + isRowMajor); + } +}; + } // anonymous namespace /** @@ -306,6 +320,11 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, pairwise_distance_impl( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; + case raft::distance::HellingerExpanded: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh new file mode 100644 index 0000000000..a8d49cbe07 --- /dev/null +++ b/cpp/include/raft/distance/hellinger.cuh @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft { +namespace distance { + +/** + * @brief the Hellinger distance matrix using the expanded form: + * It computes the following equation: + cij = sqrt(1 - sum(sqrt(x_k) * sqrt(y_k))) + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major, + false for column major + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] pD output matrix + * @param fin_op the final gemm epilogue lambda + */ +template +static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + const auto rectifierX = (!signbit(x)); + const auto rectifierY = (!signbit(y)); + const auto product = + raft::mySqrt(rectifierX * x) * raft::mySqrt(rectifierY * y); + acc += product; + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { +#pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + const auto finalVal = (1 - acc[i][j]); + const auto rectifier = (!signbit(finalVal)); + ; + acc[i][j] = raft::mySqrt(rectifier * finalVal); + } + } + }; + + if (isRowMajor) { + auto hellingerRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + hellingerRowMajor); + + hellingerRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { + auto hellingerColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + hellingerColMajor); + hellingerColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void hellinger(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + hellingerImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + hellingerImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else { + hellingerImpl( + x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + } +} + +/** + * @brief the Hellinger distance matrix calculation + * It computes the following equation: + sqrt(1 - sum(sqrt(x_k) * sqrt(y_k))) + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * @param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output matrix + * @param fin_op the final element-wise epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + */ +template +void hellingerImpl(int m, int n, int k, const InType *pA, const InType *pB, + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + hellingerOutType; + Index_ lda, ldb, ldd; + hellingerOutType *pDcast = reinterpret_cast(pD); + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + hellinger( + m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + + } else { + lda = n, ldb = m, ldd = m; + hellinger( + n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + } +} +} // namespace distance +} // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0044431c09..9d08dcc7ec 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -23,6 +23,7 @@ add_executable(test_raft test/distance/dist_cos.cu test/distance/dist_euc_exp.cu test/distance/dist_euc_unexp.cu + test/distance/dist_hellinger.cu test/distance/dist_l1.cu test/distance/fused_l2_nn.cu test/eigen_solvers.cu diff --git a/cpp/test/distance/dist_hellinger.cu b/cpp/test/distance/dist_hellinger.cu new file mode 100644 index 0000000000..39dc7aaeff --- /dev/null +++ b/cpp/test/distance/dist_hellinger.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceHellingerExp + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHellingerExp DistanceHellingerExpF; +TEST_P(DistanceHellingerExpF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHellingerExp DistanceHellingerExpD; +TEST_P(DistanceHellingerExpD, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 7ac7be10d9..8067984c70 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -107,6 +107,37 @@ __global__ void naiveCosineDistanceKernel(DataType *dist, const DataType *x, (DataType)1.0 - acc_ab / (raft::mySqrt(acc_a) * raft::mySqrt(acc_b)); } +template +__global__ void naiveHellingerDistanceKernel(DataType *dist, const DataType *x, + const DataType *y, int m, int n, + int k, bool isRowMajor) { + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { + return; + } + + DataType acc_ab = DataType(0); + + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + auto rectifierX = (!signbit(a)); + auto rectifierY = (!signbit(b)); + acc_ab += raft::mySqrt(rectifierX * a) * raft::mySqrt(rectifierY * b); + } + + int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + acc_ab = 1 - acc_ab; + auto rectifier = (!signbit(acc_ab)); + dist[outidx] = raft::mySqrt(rectifier * acc_ab); +} + template void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, raft::distance::DistanceType type, @@ -131,6 +162,10 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, naiveCosineDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::HellingerExpanded: + naiveHellingerDistanceKernel<<>>(dist, x, y, m, n, k, + isRowMajor); + break; default: FAIL() << "should be here\n"; } From 9c4d5a0ecf6710540283a78e4b1bd5fa57e9572e Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 3 Jun 2021 20:12:00 +0530 Subject: [PATCH 12/29] add syncthreads post epilog calc for non-norm distance metrics to make sure next grid stride doesn't pollute shmem before completion of this calculation --- cpp/include/raft/distance/pairwise_distance_base.cuh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 503397bac9..69e0ad0bf1 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -241,6 +241,10 @@ struct PairwiseDistances : public BaseClass { epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); } else { epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + // This sync is needed for making sure next grid stride of + // non-norm based metrics doesn't make shmem dirty until current + // this iteration is complete. + __syncthreads(); } if (writeOut) { From 4fb00e6b040578805d477986dbf8073feded9598 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 4 Jun 2021 12:12:43 +0530 Subject: [PATCH 13/29] remove syncthreads in epilog and instead use ping-pong buffers in next iteration of grid stride --- .../raft/distance/pairwise_distance_base.cuh | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 69e0ad0bf1..5d216120dc 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -120,8 +120,11 @@ struct PairwiseDistances : public BaseClass { this->y += stride; } this->yrowid += stride; - this->pageWr = 0; - this->pageRd = 0; + // This is needed for making sure next grid stride of + // non-norm based metrics uses previously accumulated buffer so + // it doesn't make shmem dirty until previous iteration + // is complete. + this->pageRd = this->pageWr; } DI void updateIndicesXY() { @@ -136,8 +139,7 @@ struct PairwiseDistances : public BaseClass { this->y = yBase + this->yrowid + this->srowid * this->ldb; } this->xrowid += stride; - this->pageWr = 0; - this->pageRd = 0; + this->pageRd = this->pageWr; } DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { @@ -241,10 +243,6 @@ struct PairwiseDistances : public BaseClass { epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); } else { epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); - // This sync is needed for making sure next grid stride of - // non-norm based metrics doesn't make shmem dirty until current - // this iteration is complete. - __syncthreads(); } if (writeOut) { From 5346232047f5025f298274dbf9cb98c92445fa82 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 4 Jun 2021 20:38:42 +0530 Subject: [PATCH 14/29] Add minkowski distance metric --- cpp/include/raft/distance/distance.cuh | 58 +++++--- cpp/include/raft/distance/hellinger.cuh | 5 +- cpp/include/raft/distance/minkowski.cuh | 170 ++++++++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_minkowski.cu | 69 ++++++++++ cpp/test/distance/distance_base.cuh | 61 +++++++-- 6 files changed, 329 insertions(+), 35 deletions(-) create mode 100644 cpp/include/raft/distance/minkowski.cuh create mode 100644 cpp/test/distance/dist_minkowski.cu diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 9c302811d9..90452d5d94 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include namespace raft { @@ -36,7 +37,7 @@ template { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, InType metric_arg) { raft::distance::euclideanAlgo1(m, n, k, x, y, dist, false, (AccType *)workspace, worksize, @@ -59,7 +60,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, InType metric_arg) { raft::distance::euclideanAlgo1(m, n, k, x, y, dist, true, (AccType *)workspace, worksize, @@ -73,7 +74,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, InType metric_arg) { raft::distance::cosineAlgo1( m, n, k, x, y, dist, (AccType *)workspace, worksize, fin_op, stream, isRowMajor); @@ -86,7 +87,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, InType metric_arg) { raft::distance::euclideanAlgo2(m, n, k, x, y, dist, false, fin_op, stream, isRowMajor); @@ -99,7 +100,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, InType metric_arg) { raft::distance::euclideanAlgo2(m, n, k, x, y, dist, true, fin_op, stream, isRowMajor); @@ -112,7 +113,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, InType metric_arg) { raft::distance::l1Impl( m, n, k, x, y, dist, fin_op, stream, isRowMajor); } @@ -124,7 +125,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, InType metric_arg) { raft::distance::chebyshevImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); @@ -137,13 +138,25 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, InType metric_arg) { raft::distance::hellingerImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); } }; +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg) { + raft::distance::minkowski( + m, n, k, x, y, dist, fin_op, stream, isRowMajor, metric_arg); + } +}; + } // anonymous namespace /** @@ -206,11 +219,12 @@ template void distance(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, - FinalLambda fin_op, cudaStream_t stream, bool isRowMajor = true) { + FinalLambda fin_op, cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { DistanceImpl distImpl; distImpl.run(x, y, dist, m, n, k, workspace, worksize, fin_op, stream, - isRowMajor); + isRowMajor, metric_arg); CUDA_CHECK(cudaPeekAtLastError()); } @@ -239,13 +253,14 @@ template void distance(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor = true) { + cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { auto default_fin_op = [] __device__(AccType d_val, Index_ g_d_idx) { return d_val; }; distance(x, y, dist, m, n, k, workspace, worksize, default_fin_op, - stream, isRowMajor); + stream, isRowMajor, metric_arg); CUDA_CHECK(cudaPeekAtLastError()); } @@ -272,12 +287,14 @@ template void pairwise_distance_impl(const Type *x, const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, raft::mr::device::buffer &workspace, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, + Type metric_arg = 2.0f) { auto worksize = getWorkspaceSize(x, y, m, n, k); workspace.resize(worksize, stream); - distance( - x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor); + distance(x, y, dist, m, n, k, + workspace.data(), worksize, + stream, isRowMajor, metric_arg); } template @@ -285,7 +302,7 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, raft::mr::device::buffer &workspace, raft::distance::DistanceType metric, cudaStream_t stream, - bool isRowMajor = true) { + bool isRowMajor = true, Type metric_arg = 2.0f) { switch (metric) { case raft::distance::DistanceType::L2Expanded: pairwise_distance_impl( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; - case raft::distance::HellingerExpanded: + case raft::distance::DistanceType::HellingerExpanded: pairwise_distance_impl( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; + case raft::distance::DistanceType::LpUnexpanded: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh index a8d49cbe07..068599d5a2 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/hellinger.cuh @@ -59,10 +59,7 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - const auto rectifierX = (!signbit(x)); - const auto rectifierY = (!signbit(y)); - const auto product = - raft::mySqrt(rectifierX * x) * raft::mySqrt(rectifierY * y); + const auto product = raft::mySqrt(x) * raft::mySqrt(y); acc += product; }; diff --git a/cpp/include/raft/distance/minkowski.cuh b/cpp/include/raft/distance/minkowski.cuh new file mode 100644 index 0000000000..4f975b1c26 --- /dev/null +++ b/cpp/include/raft/distance/minkowski.cuh @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft { +namespace distance { + +/** + * @brief the unexpanded Minkowski distance matrix calculation + * It computes the following equation: cij = sum(|x - y|^p)^(1/p) + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam FinalLambda final lambda called on final distance value + * + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] pD output matrix + * @param fin_op the final gemm epilogue lambda + * @param stream cuda stream where to launch work + * @param[in] the value of `p` for Minkowski (l-p) distances. + */ +template +void minkowskiUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, + IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream, DataT p) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [p] __device__(AccT & acc, DataT & x, DataT & y) { + const auto diff = raft::L1Op()(x - y); + acc += raft::myPow(diff, p); + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [p] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { + const auto one_over_p = 1.0f / p; +#pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = raft::myPow(acc[i][j], one_over_p); + } + } + }; + + if (isRowMajor) { + auto minkowskiUnExpRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + minkowskiUnExpRowMajor); + + minkowskiUnExpRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + + } else { + auto minkowskiUnExpColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + minkowskiUnExpColMajor); + + minkowskiUnExpColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void minkowskiUnExp(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream, DataT metric_arg) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + minkowskiUnExpImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, + fin_op, stream, metric_arg); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + minkowskiUnExpImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, + fin_op, stream, metric_arg); + } else { + minkowskiUnExpImpl( + x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream, metric_arg); + } +} + +/** + * @brief the unexpanded minkowski distance matrix calculation + * It computes the following equation: cij = sum(|x - y|^p)^(1/p) + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * @param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output matrix + * @param fin_op the final gemm epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + * @param metric_arg the value of `p` for Minkowski (l-p) distances. + */ +template +void minkowski(Index_ m, Index_ n, Index_ k, const InType *pA, const InType *pB, + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType metric_arg) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + LpUnexpOutType; + LpUnexpOutType *pDcast = reinterpret_cast(pD); + Index_ lda, ldb, ldd; + + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + minkowskiUnExp( + m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream, metric_arg); + } else { + lda = n, ldb = m, ldd = m; + minkowskiUnExp( + n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream, metric_arg); + } +} + +}; // end namespace distance +}; // end namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 9d08dcc7ec..ebb0d41293 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -25,6 +25,7 @@ add_executable(test_raft test/distance/dist_euc_unexp.cu test/distance/dist_hellinger.cu test/distance/dist_l1.cu + test/distance/dist_minkowski.cu test/distance/fused_l2_nn.cu test/eigen_solvers.cu test/handle.cpp diff --git a/cpp/test/distance/dist_minkowski.cu b/cpp/test/distance/dist_minkowski.cu new file mode 100644 index 0000000000..42b8e294ac --- /dev/null +++ b/cpp/test/distance/dist_minkowski.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceLpUnexp + : public DistanceTest { +}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL, 4.0f}, + {0.001f, 1024, 32, 1024, true, 1234ULL, 3.0f}, + {0.001f, 32, 1024, 1024, true, 1234ULL, 4.0f}, + {0.003f, 1024, 1024, 1024, true, 1234ULL, 3.0f}, + {0.001f, 1024, 1024, 32, false, 1234ULL, 4.0f}, + {0.001f, 1024, 32, 1024, false, 1234ULL, 3.0f}, + {0.001f, 32, 1024, 1024, false, 1234ULL, 4.0f}, + {0.003f, 1024, 1024, 1024, false, 1234ULL, 3.0f}, +}; +typedef DistanceLpUnexp DistanceLpUnexpF; +TEST_P(DistanceLpUnexpF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL, 4.0}, + {0.001, 1024, 32, 1024, true, 1234ULL, 3.0}, + {0.001, 32, 1024, 1024, true, 1234ULL, 4.0}, + {0.003, 1024, 1024, 1024, true, 1234ULL, 3.0}, + {0.001, 1024, 1024, 32, false, 1234ULL, 4.0}, + {0.001, 1024, 32, 1024, false, 1234ULL, 3.0}, + {0.001, 32, 1024, 1024, false, 1234ULL, 4.0}, + {0.003, 1024, 1024, 1024, false, 1234ULL, 3.0}, +}; +typedef DistanceLpUnexp DistanceLpUnexpD; +TEST_P(DistanceLpUnexpD, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 8067984c70..f57f52d6f0 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -124,10 +124,7 @@ __global__ void naiveHellingerDistanceKernel(DataType *dist, const DataType *x, int yidx = isRowMajor ? i + nidx * k : i * n + nidx; auto a = x[xidx]; auto b = y[yidx]; - // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - auto rectifierX = (!signbit(a)); - auto rectifierY = (!signbit(b)); - acc_ab += raft::mySqrt(rectifierX * a) * raft::mySqrt(rectifierY * b); + acc_ab += raft::mySqrt(a) * raft::mySqrt(b); } int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; @@ -138,10 +135,32 @@ __global__ void naiveHellingerDistanceKernel(DataType *dist, const DataType *x, dist[outidx] = raft::mySqrt(rectifier * acc_ab); } +template +__global__ void naiveLpUnexpDistanceKernel(DataType *dist, const DataType *x, + const DataType *y, int m, int n, + int k, bool isRowMajor, DataType p) { + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + DataType acc = DataType(0); + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + auto diff = raft::L1Op()(a - b); + acc += raft::myPow(diff, p); + } + auto one_over_p = 1 / p; + acc = raft::myPow(acc, one_over_p); + int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + template void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, raft::distance::DistanceType type, - bool isRowMajor) { + bool isRowMajor, DataType metric_arg = 2.0f) { static const dim3 TPB(16, 32, 1); dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); @@ -163,8 +182,12 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, <<>>(dist, x, y, m, n, k, isRowMajor); break; case raft::distance::DistanceType::HellingerExpanded: - naiveHellingerDistanceKernel<<>>(dist, x, y, m, n, k, - isRowMajor); + naiveHellingerDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; + case raft::distance::DistanceType::LpUnexpanded: + naiveLpUnexpDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor, metric_arg); break; default: FAIL() << "should be here\n"; @@ -178,6 +201,7 @@ struct DistanceInputs { int m, n, k; bool isRowMajor; unsigned long long int seed; + DataType metric_arg = 2.0f; }; template @@ -190,13 +214,15 @@ template void distanceLauncher(DataType *x, DataType *y, DataType *dist, DataType *dist2, int m, int n, int k, DistanceInputs ¶ms, DataType threshold, char *workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor) { + cudaStream_t stream, bool isRowMajor, + DataType metric_arg = 2.0f) { auto fin_op = [dist2, threshold] __device__(DataType d_val, int g_d_idx) { dist2[g_d_idx] = (d_val < threshold) ? 0.f : d_val; return d_val; }; raft::distance::distance( - x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor); + x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, + metric_arg); } template @@ -208,6 +234,7 @@ class DistanceTest : public ::testing::TestWithParam> { int m = params.m; int n = params.n; int k = params.k; + DataType metric_arg = params.metric_arg; bool isRowMajor = params.isRowMajor; cudaStream_t stream; CUDA_CHECK(cudaStreamCreate(&stream)); @@ -216,9 +243,17 @@ class DistanceTest : public ::testing::TestWithParam> { raft::allocate(dist_ref, m * n); raft::allocate(dist, m * n); raft::allocate(dist2, m * n); - r.uniform(x, m * k, DataType(-1.0), DataType(1.0), stream); - r.uniform(y, n * k, DataType(-1.0), DataType(1.0), stream); - naiveDistance(dist_ref, x, y, m, n, k, distanceType, isRowMajor); + if (distanceType == raft::distance::DistanceType::HellingerExpanded) { + // Hellinger works only on positive numbers as it applies sqrt on inputs + r.uniform(x, m * k, DataType(0.0), DataType(2.0), stream); + r.uniform(y, n * k, DataType(0.0), DataType(2.0), stream); + } else { + r.uniform(x, m * k, DataType(-1.0), DataType(1.0), stream); + r.uniform(y, n * k, DataType(-1.0), DataType(1.0), stream); + } + + naiveDistance(dist_ref, x, y, m, n, k, distanceType, isRowMajor, + metric_arg); char *workspace = nullptr; size_t worksize = raft::distance::getWorkspaceSize> { DataType threshold = -10000.f; distanceLauncher(x, y, dist, dist2, m, n, k, params, threshold, workspace, worksize, - stream, isRowMajor); + stream, isRowMajor, metric_arg); CUDA_CHECK(cudaStreamDestroy(stream)); CUDA_CHECK(cudaFree(workspace)); } From b5b3c51298c614a650d8b9be8a7694a669fec353 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 4 Jun 2021 21:17:43 +0530 Subject: [PATCH 15/29] use ping-pong buffers for safely grid striding --- cpp/include/raft/distance/pairwise_distance_base.cuh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 5d216120dc..d5a434f2fa 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -120,11 +120,6 @@ struct PairwiseDistances : public BaseClass { this->y += stride; } this->yrowid += stride; - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. - this->pageRd = this->pageWr; } DI void updateIndicesXY() { @@ -139,7 +134,6 @@ struct PairwiseDistances : public BaseClass { this->y = yBase + this->yrowid + this->srowid * this->ldb; } this->xrowid += stride; - this->pageRd = this->pageWr; } DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { @@ -191,6 +185,11 @@ struct PairwiseDistances : public BaseClass { this->pageRd ^= 1; } accumulate(); // last iteration + // This is needed for making sure next grid stride of + // non-norm based metrics uses previously accumulated buffer so + // it doesn't make shmem dirty until previous iteration + // is complete. + this->pageRd ^= 1; } DI void accumulate() { From 2fd7f4c74ece9d626c495f63b86d9c83d5c3e3f7 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 7 Jun 2021 14:06:46 +0530 Subject: [PATCH 16/29] Add canberra distance metric support --- cpp/include/raft/distance/canberra.cuh | 159 +++++++++++++++++++++++++ cpp/include/raft/distance/distance.cuh | 18 +++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_canberra.cu | 68 +++++++++++ cpp/test/distance/distance_base.cuh | 16 ++- 5 files changed, 256 insertions(+), 6 deletions(-) create mode 100644 cpp/include/raft/distance/canberra.cuh create mode 100644 cpp/test/distance/dist_canberra.cu diff --git a/cpp/include/raft/distance/canberra.cuh b/cpp/include/raft/distance/canberra.cuh new file mode 100644 index 0000000000..ba789377d9 --- /dev/null +++ b/cpp/include/raft/distance/canberra.cuh @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft { +namespace distance { + +/** + * @brief the canberra distance matrix calculation implementer + * It computes the following equation: cij = max(cij, op(ai-bj)) + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major, + false for column major + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] pD output matrix + * @param fin_op the final gemm epilogue lambda + */ +template +static void canberraImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, + IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + const auto diff = raft::L1Op()(x - y); + const auto add = raft::myAbs(x) + raft::myAbs(y); + // deal with potential for 0 in denominator by + // forcing 1/0 instead + acc += ((add != 0) * diff / (add + (add == 0))); + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { return; }; + + if (isRowMajor) { + auto canberraRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, canberraRowMajor); + + canberraRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { + auto canberraColMajor = + pairwiseDistanceMatKernel; + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, canberraColMajor); + canberraColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void canberra(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, OutT *dOutput, FinalLambda fin_op, + cudaStream_t stream) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + canberraImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + canberraImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else { + canberraImpl( + x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + } +} + +/** + * @brief the canberra distance matrix calculation + * It computes the following equation: cij = max(cij, op(ai-bj)) + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * @param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output matrix + * @param fin_op the final element-wise epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + */ +template +void canberraImpl(int m, int n, int k, const InType *pA, const InType *pB, + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + canberraOutType; + Index_ lda, ldb, ldd; + canberraOutType *pDcast = reinterpret_cast(pD); + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + canberra( + m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + } else { + lda = n, ldb = m, ldd = m; + canberra( + n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + } +} +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 90452d5d94..854ee4551a 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -157,6 +158,18 @@ struct DistanceImpl +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg) { + raft::distance::canberraImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + } // anonymous namespace /** @@ -347,6 +360,11 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, raft::distance::DistanceType::LpUnexpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); break; + case raft::distance::DistanceType::Canberra: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ebb0d41293..489d71208c 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -19,6 +19,7 @@ add_executable(test_raft test/cudart_utils.cpp test/cluster_solvers.cu test/distance/dist_adj.cu + test/distance/dist_canberra.cu test/distance/dist_chebyshev.cu test/distance/dist_cos.cu test/distance/dist_euc_exp.cu diff --git a/cpp/test/distance/dist_canberra.cu b/cpp/test/distance/dist_canberra.cu new file mode 100644 index 0000000000..10bc4d1899 --- /dev/null +++ b/cpp/test/distance/dist_canberra.cu @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceCanberra + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCanberra DistanceCanberraF; +TEST_P(DistanceCanberraF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCanberra DistanceCanberraD; +TEST_P(DistanceCanberraD, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index f57f52d6f0..dfee818449 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -47,11 +47,9 @@ __global__ void naiveDistanceKernel(DataType *dist, const DataType *x, } template -__global__ void naiveL1_LinfDistanceKernel(DataType *dist, const DataType *x, - const DataType *y, int m, int n, - int k, - raft::distance::DistanceType type, - bool isRowMajor) { +__global__ void naiveL1_Linf_CanberraDistanceKernel( + DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, + raft::distance::DistanceType type, bool isRowMajor) { int midx = threadIdx.x + blockIdx.x * blockDim.x; int nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) { @@ -67,6 +65,11 @@ __global__ void naiveL1_LinfDistanceKernel(DataType *dist, const DataType *x, auto diff = (a > b) ? (a - b) : (b - a); if (type == raft::distance::DistanceType::Linf) { acc = raft::myMax(acc, diff); + } else if (type == raft::distance::DistanceType::Canberra) { + const auto add = raft::myAbs(a) + raft::myAbs(b); + // deal with potential for 0 in denominator by + // forcing 1/0 instead + acc += ((add != 0) * diff / (add + (add == 0))); } else { acc += diff; } @@ -165,9 +168,10 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, dim3 nblks(raft::ceildiv(m, (int)TPB.x), raft::ceildiv(n, (int)TPB.y), 1); switch (type) { + case raft::distance::DistanceType::Canberra: case raft::distance::DistanceType::Linf: case raft::distance::DistanceType::L1: - naiveL1_LinfDistanceKernel + naiveL1_Linf_CanberraDistanceKernel <<>>(dist, x, y, m, n, k, type, isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: From 0f2c03d17d507bd091be4a7dc560c9690ac8dece Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 7 Jun 2021 18:42:30 +0530 Subject: [PATCH 17/29] fix build failure of mst and knn test by adding cuda stream arg to rmm::device_buffer --- cpp/test/mst.cu | 18 +++++++++--------- cpp/test/spatial/knn.cu | 6 ++++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cpp/test/mst.cu b/cpp/test/mst.cu index 215c6f6548..b0007c73fd 100644 --- a/cpp/test/mst.cu +++ b/cpp/test/mst.cu @@ -198,15 +198,15 @@ class MSTTest MSTTestInput>::GetParam(); iterations = mst_input.iterations; - csr_d.offsets = - rmm::device_buffer(mst_input.csr_h.offsets.data(), - mst_input.csr_h.offsets.size() * sizeof(edge_t)); - csr_d.indices = - rmm::device_buffer(mst_input.csr_h.indices.data(), - mst_input.csr_h.indices.size() * sizeof(vertex_t)); - csr_d.weights = - rmm::device_buffer(mst_input.csr_h.weights.data(), - mst_input.csr_h.weights.size() * sizeof(weight_t)); + csr_d.offsets = rmm::device_buffer( + mst_input.csr_h.offsets.data(), + mst_input.csr_h.offsets.size() * sizeof(edge_t), handle.get_stream()); + csr_d.indices = rmm::device_buffer( + mst_input.csr_h.indices.data(), + mst_input.csr_h.indices.size() * sizeof(vertex_t), handle.get_stream()); + csr_d.weights = rmm::device_buffer( + mst_input.csr_h.weights.data(), + mst_input.csr_h.weights.size() * sizeof(weight_t), handle.get_stream()); } void TearDown() override {} diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index cfd4ecc9d1..2b1ef89f7a 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -107,11 +107,13 @@ class KNNTest : public ::testing::TestWithParam { } } rmm::device_buffer input_d = rmm::device_buffer( - row_major_input.data(), row_major_input.size() * sizeof(float)); + row_major_input.data(), row_major_input.size() * sizeof(float), + handle_.get_stream()); float *input_ptr = static_cast(input_d.data()); rmm::device_buffer labels_d = rmm::device_buffer( - params_.labels.data(), params_.labels.size() * sizeof(int)); + params_.labels.data(), params_.labels.size() * sizeof(int), + handle_.get_stream()); int *labels_ptr = static_cast(labels_d.data()); raft::allocate(input_, rows_ * cols_, true); From 484b0820f0cc3ce3c33b1a0ebb90a505513aa7ea Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 7 Jun 2021 18:59:34 +0530 Subject: [PATCH 18/29] temp commit for test rerun --- cpp/test/mst.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/mst.cu b/cpp/test/mst.cu index b0007c73fd..d7aa76500b 100644 --- a/cpp/test/mst.cu +++ b/cpp/test/mst.cu @@ -58,7 +58,7 @@ weight_t prims(CSRHost &csr_h) { auto n_vertices = csr_h.offsets.size() - 1; bool active_vertex[n_vertices]; - // bool mst_set[csr_h.n_edges]; + // bool mst_set[csr_h.n_edges]; weight_t curr_edge[n_vertices]; for (auto i = 0; i < n_vertices; i++) { From 04f656f97ee82b75046501615fa2927781880a6c Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 7 Jun 2021 19:14:01 +0530 Subject: [PATCH 19/29] use ucx-py version 0.21 to temp resolve ci build failures --- ci/gpu/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index b4eef5ebd0..1707b72f44 100644 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -51,7 +51,7 @@ gpuci_conda_retry install -c conda-forge -c rapidsai -c rapidsai-nightly -c nvid "rmm=${MINOR_VERSION}" \ "dask-cudf=${MINOR_VERSION}" \ "dask-cuda=${MINOR_VERSION}" \ - "ucx-py=${MINOR_VERSION}" \ + "ucx-py=0.21" \ "rapids-build-env=${MINOR_VERSION}.*" \ "rapids-notebook-env=${MINOR_VERSION}.*" \ "rapids-doc-env=${MINOR_VERSION}.*" From 1c65ab756c1e2ff66b68d6e0f814fe2b36881269 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 11 Jun 2021 12:53:54 +0530 Subject: [PATCH 20/29] remove redundant metric_arg parameter from canberra function launch --- cpp/include/raft/distance/distance.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 854ee4551a..92605f5367 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -363,7 +363,7 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, case raft::distance::DistanceType::Canberra: pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); + x, y, dist, m, n, k, workspace, stream, isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); From 8323a3c0aaea03c8b2c11fee9678ae21008c374b Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 15 Jun 2021 19:16:29 +0530 Subject: [PATCH 21/29] reduce sqrt in hellinger my merging prod in sqrt --- cpp/include/raft/distance/hellinger.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh index 068599d5a2..6cb95223a4 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/hellinger.cuh @@ -58,8 +58,7 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - const auto product = raft::mySqrt(x) * raft::mySqrt(y); + const auto product = raft::mySqrt(x * y); acc += product; }; @@ -75,7 +74,6 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative const auto finalVal = (1 - acc[i][j]); const auto rectifier = (!signbit(finalVal)); - ; acc[i][j] = raft::mySqrt(rectifier * finalVal); } } From a33546db81aebbd9a67c4876fd0773bc16165ad2 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 15 Jun 2021 20:45:25 +0530 Subject: [PATCH 22/29] rename minkowksi to be similar to other functions, fix documentation of hellinger formula --- cpp/include/raft/distance/distance.cuh | 2 +- cpp/include/raft/distance/hellinger.cuh | 4 ++-- cpp/include/raft/distance/minkowski.cuh | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 92605f5367..83c07ece48 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -153,7 +153,7 @@ struct DistanceImpl( + raft::distance::minkowskiImpl( m, n, k, x, y, dist, fin_op, stream, isRowMajor, metric_arg); } }; diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh index 6cb95223a4..372d91fbf6 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/hellinger.cuh @@ -23,7 +23,7 @@ namespace distance { /** * @brief the Hellinger distance matrix using the expanded form: * It computes the following equation: - cij = sqrt(1 - sum(sqrt(x_k) * sqrt(y_k))) + cij = sqrt(1 - sum(sqrt(x_k * y_k))) * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) @@ -129,7 +129,7 @@ void hellinger(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, /** * @brief the Hellinger distance matrix calculation * It computes the following equation: - sqrt(1 - sum(sqrt(x_k) * sqrt(y_k))) + sqrt(1 - sum(sqrt(x_k * y_k)) * @tparam InType input data-type (for A and B matrices) * @tparam AccType accumulation data-type * @tparam OutType output data-type (for C and D matrices) diff --git a/cpp/include/raft/distance/minkowski.cuh b/cpp/include/raft/distance/minkowski.cuh index 4f975b1c26..fb4f5ee5f8 100644 --- a/cpp/include/raft/distance/minkowski.cuh +++ b/cpp/include/raft/distance/minkowski.cuh @@ -146,7 +146,7 @@ void minkowskiUnExp(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, */ template -void minkowski(Index_ m, Index_ n, Index_ k, const InType *pA, const InType *pB, +void minkowskiImpl(Index_ m, Index_ n, Index_ k, const InType *pA, const InType *pB, OutType *pD, FinalLambda fin_op, cudaStream_t stream, bool isRowMajor, InType metric_arg) { typedef std::is_same is_bool; From 76730744ed072d6579792678cf9a4b10e4122b85 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 15 Jun 2021 20:51:54 +0530 Subject: [PATCH 23/29] fix clang format issue --- cpp/include/raft/distance/distance.cuh | 5 +++-- cpp/include/raft/distance/minkowski.cuh | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 83c07ece48..1b39a6ec18 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -153,8 +153,9 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor, metric_arg); + raft::distance::minkowskiImpl(m, n, k, x, y, dist, fin_op, stream, + isRowMajor, metric_arg); } }; diff --git a/cpp/include/raft/distance/minkowski.cuh b/cpp/include/raft/distance/minkowski.cuh index fb4f5ee5f8..f41e6fe931 100644 --- a/cpp/include/raft/distance/minkowski.cuh +++ b/cpp/include/raft/distance/minkowski.cuh @@ -146,9 +146,9 @@ void minkowskiUnExp(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, */ template -void minkowskiImpl(Index_ m, Index_ n, Index_ k, const InType *pA, const InType *pB, - OutType *pD, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType metric_arg) { +void minkowskiImpl(Index_ m, Index_ n, Index_ k, const InType *pA, + const InType *pB, OutType *pD, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg) { typedef std::is_same is_bool; typedef typename std::conditional::type LpUnexpOutType; From 476ed990c3d74ed3ee884fc3c60f18bb1198e594 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 15 Jun 2021 21:03:09 +0530 Subject: [PATCH 24/29] fix hellinger inputs to be only in range of 0 to 1 as hellinger is expected to work on pdf --- cpp/test/distance/distance_base.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index dfee818449..fc7b064205 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -248,9 +248,9 @@ class DistanceTest : public ::testing::TestWithParam> { raft::allocate(dist, m * n); raft::allocate(dist2, m * n); if (distanceType == raft::distance::DistanceType::HellingerExpanded) { - // Hellinger works only on positive numbers as it applies sqrt on inputs - r.uniform(x, m * k, DataType(0.0), DataType(2.0), stream); - r.uniform(y, n * k, DataType(0.0), DataType(2.0), stream); + // Hellinger works only on positive numbers + r.uniform(x, m * k, DataType(0.0), DataType(1.0), stream); + r.uniform(y, n * k, DataType(0.0), DataType(1.0), stream); } else { r.uniform(x, m * k, DataType(-1.0), DataType(1.0), stream); r.uniform(y, n * k, DataType(-1.0), DataType(1.0), stream); From 59e78e811e811a544b0d2ff05311ccaeb93383d7 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 16 Jun 2021 17:23:48 +0530 Subject: [PATCH 25/29] fix doc issues in all dist functions, also fix copyright year to be only 2021 --- cpp/include/raft/distance/canberra.cuh | 30 ++++++++++++----------- cpp/include/raft/distance/chebyshev.cuh | 32 +++++++++++++------------ cpp/include/raft/distance/hellinger.cuh | 12 ++++++---- cpp/include/raft/distance/minkowski.cuh | 32 +++++++++++++------------ 4 files changed, 57 insertions(+), 49 deletions(-) diff --git a/cpp/include/raft/distance/canberra.cuh b/cpp/include/raft/distance/canberra.cuh index ba789377d9..b87c295eb0 100644 --- a/cpp/include/raft/distance/canberra.cuh +++ b/cpp/include/raft/distance/canberra.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,20 +27,22 @@ namespace distance { * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type - + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. details in contractions.cuh * @tparam FinalLambda final lambda called on final distance value * @tparam isRowMajor true if input/output is row major, false for column major * @param[in] x input matrix * @param[in] y input matrix * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B + * @param[in] n number of rows of B and cols of C/D + * @param[in] k number of cols of A and B * @param[in] lda leading dimension of A * @param[in] ldb leading dimension of B * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix + * @param[output] dOutput output matrix * @param fin_op the final gemm epilogue lambda + * @param stream cuda stream to launch work */ template @@ -125,15 +127,15 @@ void canberra(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, * @tparam OutType output data-type (for C and D matrices) * @tparam FinalLambda user-defined epilogue lamba * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major + * @param[in] m number of rows of A and C/D + * @param[in] n number of rows of B and cols of C/D + * @param[in] k number of cols of A and B + * @param[in] pA input matrix + * @param[in] pB input matrix + * @param[out] pD output matrix + * @param[in] fin_op the final element-wise epilogue lambda + * @param[in] stream cuda stream to launch work + * @param[in] isRowMajor whether the input and output matrices are row major */ template diff --git a/cpp/include/raft/distance/chebyshev.cuh b/cpp/include/raft/distance/chebyshev.cuh index fe50501708..8d53408cf8 100644 --- a/cpp/include/raft/distance/chebyshev.cuh +++ b/cpp/include/raft/distance/chebyshev.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,20 +27,22 @@ namespace distance { * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type - + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. details in contractions.cuh * @tparam FinalLambda final lambda called on final distance value * @tparam isRowMajor true if input/output is row major, false for column major * @param[in] x input matrix * @param[in] y input matrix * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B + * @param[in] n number of rows of B and cols of C/D + * @param[in] k number of cols of A and B * @param[in] lda leading dimension of A * @param[in] ldb leading dimension of B * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix - * @param fin_op the final gemm epilogue lambda + * @param[out] dOutput output matrix + * @param[in] fin_op the final gemm epilogue lambda + * @param[in] stream cuda stream to launch work */ template @@ -122,15 +124,15 @@ void chebyshev(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, * @tparam OutType output data-type (for C and D matrices) * @tparam FinalLambda user-defined epilogue lamba * @tparam Index_ Index type - * @param m number of rows of A and C/D - * @param n number of columns of B and C/D - * @param k number of cols of A and rows of B - * @param pA input matrix - * @param pB input matrix - * @param pD output matrix - * @param fin_op the final element-wise epilogue lambda - * @param stream cuda stream where to launch work - * @param isRowMajor whether the input and output matrices are row major + * @param[in] m number of rows of A and C/D + * @param[in] n number of rows of B and cols of C/D + * @param[in] k number of cols of A and B + * @param[in] pA input matrix + * @param[in] pB input matrix + * @param[out] pD output matrix + * @param[in] fin_op the final element-wise epilogue lambda + * @param[in] stream cuda stream to launch work + * @param[in] isRowMajor whether the input and output matrices are row major */ template diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh index 372d91fbf6..c76fcd0e95 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/hellinger.cuh @@ -28,20 +28,22 @@ namespace distance { * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type - + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. details in contractions.cuh * @tparam FinalLambda final lambda called on final distance value * @tparam isRowMajor true if input/output is row major, false for column major * @param[in] x input matrix * @param[in] y input matrix * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B + * @param[in] n number of rows of B and C/D + * @param[in] k number of cols of A and B * @param[in] lda leading dimension of A * @param[in] ldb leading dimension of B * @param[in] ldd leading dimension of C/D - * @param[output] pD output matrix - * @param fin_op the final gemm epilogue lambda + * @param[output] dOutput output matrix + * @param[in] fin_op the final gemm epilogue lambda + * @param[in] stream cuda stream to launch work */ template diff --git a/cpp/include/raft/distance/minkowski.cuh b/cpp/include/raft/distance/minkowski.cuh index f41e6fe931..803f5fc78a 100644 --- a/cpp/include/raft/distance/minkowski.cuh +++ b/cpp/include/raft/distance/minkowski.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,19 +27,21 @@ namespace distance { * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. details in contractions.cuh * @tparam FinalLambda final lambda called on final distance value * * @param[in] x input matrix * @param[in] y input matrix * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B + * @param[in] n number of rows of B and cols of C/D + * @param[in] k number of cols of A and B * @param[in] lda leading dimension of A * @param[in] ldb leading dimension of B * @param[in] ldd leading dimension of C/D * @param[output] pD output matrix - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream where to launch work + * @param[in] fin_op the final gemm epilogue lambda + * @param[in] stream cuda stream to launch work * @param[in] the value of `p` for Minkowski (l-p) distances. */ template From 7af5e328ace67be77e50201944480c9f597c3325 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 16 Jun 2021 19:30:08 +0530 Subject: [PATCH 26/29] reduce sqrt in hellinger usage by overwriting input matrices by sqrt and reverting in post completion. also add unrolls to ldg arrays in contractions --- cpp/include/raft/distance/hellinger.cuh | 31 +++++++++++++++++++++++- cpp/include/raft/linalg/contractions.cuh | 4 +++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh index c76fcd0e95..d167f7fc3e 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/hellinger.cuh @@ -16,6 +16,7 @@ #pragma once #include +#include namespace raft { namespace distance { @@ -24,6 +25,11 @@ namespace distance { * @brief the Hellinger distance matrix using the expanded form: * It computes the following equation: cij = sqrt(1 - sum(sqrt(x_k * y_k))) + * This distance computation modifies A and B by computing a sqrt + * and then performing a `pow(x, 2)` to convert it back. Because of this, + * it is possible that the values in A and B might differ slightly + * after this is invoked. + * * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) @@ -58,9 +64,19 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, dim3 blk(KPolicy::Nthreads); + // First sqrt x and y + raft::linalg::unaryOp( + (DataT*) x, (DataT*) x, m * k, + [=] __device__(DataT input) { return raft::mySqrt(input); }, stream); + raft::linalg::unaryOp( + (DataT*) y, (DataT*) y, n * k, + [=] __device__(DataT input) { return raft::mySqrt(input); }, stream); + + // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto product = raft::mySqrt(x * y); + // This is sqrt(x) * sqrt(y). + const auto product = x * y; acc += product; }; @@ -104,6 +120,14 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, epilog_lambda, fin_op); } + // Revert sqrt of x and y + raft::linalg::unaryOp( + (DataT*) x, (DataT*) x, m * k, + [=] __device__(DataT input) { return input * input; }, stream); + raft::linalg::unaryOp( + (DataT*) y, (DataT*) y, n * k, + [=] __device__(DataT input) { return input * input; }, stream); + CUDA_CHECK(cudaGetLastError()); } @@ -132,6 +156,11 @@ void hellinger(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, * @brief the Hellinger distance matrix calculation * It computes the following equation: sqrt(1 - sum(sqrt(x_k * y_k)) + * This distance computation modifies A and B by computing a sqrt + * and then performing a `pow(x, 2)` to convert it back. Because of this, + * it is possible that the values in A and B might differ slightly + * after this is invoked. + * * @tparam InType input data-type (for A and B matrices) * @tparam AccType accumulation data-type * @tparam OutType output data-type (for C and D matrices) diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh index c590abb142..aa711a9140 100644 --- a/cpp/include/raft/linalg/contractions.cuh +++ b/cpp/include/raft/linalg/contractions.cuh @@ -338,6 +338,7 @@ struct Contractions_NT { if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; +#pragma unroll for (int i = 0; i < P::LdgPerThX; ++i) { if (koffset < lda && (xrowid + i * P::LdgRowsX) < numRows) { ldg(ldgDataX[i], x + i * P::LdgRowsX * lda + koffset); @@ -351,6 +352,7 @@ struct Contractions_NT { } else { const auto numRows = k; auto koffset = scolid; +#pragma unroll for (int i = 0; i < P::LdgPerThX; ++i) { if ((koffset + xrowid) < lda && (srowid + kidx + i * P::LdgRowsX) < numRows) { @@ -369,6 +371,7 @@ struct Contractions_NT { if (isRowMajor) { auto numRows = n; auto koffset = kidx + scolid; +#pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { if (koffset < ldb && (yrowid + i * P::LdgRowsY) < numRows) { ldg(ldgDataY[i], y + i * P::LdgRowsY * ldb + koffset); @@ -382,6 +385,7 @@ struct Contractions_NT { } else { auto numRows = k; auto koffset = scolid; +#pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { if ((koffset + yrowid) < ldb && (srowid + kidx + i * P::LdgRowsY) < numRows) { From 0e991136fc4b2d2c4374ac3186808e814ba06029 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 16 Jun 2021 19:32:31 +0530 Subject: [PATCH 27/29] fix clang format issues --- cpp/include/raft/distance/hellinger.cuh | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh index d167f7fc3e..e651c76478 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/hellinger.cuh @@ -66,12 +66,11 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // First sqrt x and y raft::linalg::unaryOp( - (DataT*) x, (DataT*) x, m * k, - [=] __device__(DataT input) { return raft::mySqrt(input); }, stream); + (DataT *)x, (DataT *)x, m * k, + [=] __device__(DataT input) { return raft::mySqrt(input); }, stream); raft::linalg::unaryOp( - (DataT*) y, (DataT*) y, n * k, - [=] __device__(DataT input) { return raft::mySqrt(input); }, stream); - + (DataT *)y, (DataT *)y, n * k, + [=] __device__(DataT input) { return raft::mySqrt(input); }, stream); // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { @@ -122,11 +121,11 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // Revert sqrt of x and y raft::linalg::unaryOp( - (DataT*) x, (DataT*) x, m * k, - [=] __device__(DataT input) { return input * input; }, stream); + (DataT *)x, (DataT *)x, m * k, + [=] __device__(DataT input) { return input * input; }, stream); raft::linalg::unaryOp( - (DataT*) y, (DataT*) y, n * k, - [=] __device__(DataT input) { return input * input; }, stream); + (DataT *)y, (DataT *)y, n * k, + [=] __device__(DataT input) { return input * input; }, stream); CUDA_CHECK(cudaGetLastError()); } From f4b8d33304b136f22ea0748bc179efee9d675737 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 17 Jun 2021 18:09:48 +0530 Subject: [PATCH 28/29] hellinger: only sqrt inputs when x & y are not same. --- cpp/include/raft/distance/hellinger.cuh | 28 ++++++++++++++----------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh index e651c76478..5ee8e65d9b 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/hellinger.cuh @@ -64,13 +64,17 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, dim3 blk(KPolicy::Nthreads); + auto unaryOp_lambda = [] __device__(DataT input) { + return raft::mySqrt(input); + }; // First sqrt x and y - raft::linalg::unaryOp( - (DataT *)x, (DataT *)x, m * k, - [=] __device__(DataT input) { return raft::mySqrt(input); }, stream); - raft::linalg::unaryOp( - (DataT *)y, (DataT *)y, n * k, - [=] __device__(DataT input) { return raft::mySqrt(input); }, stream); + raft::linalg::unaryOp( + (DataT*) x, x, m * k, unaryOp_lambda, stream); + + if (x != y) { + raft::linalg::unaryOp( + (DataT*) y, y, n * k, unaryOp_lambda, stream); + } // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { @@ -120,12 +124,12 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, } // Revert sqrt of x and y - raft::linalg::unaryOp( - (DataT *)x, (DataT *)x, m * k, - [=] __device__(DataT input) { return input * input; }, stream); - raft::linalg::unaryOp( - (DataT *)y, (DataT *)y, n * k, - [=] __device__(DataT input) { return input * input; }, stream); + raft::linalg::unaryOp( + (DataT*) x, x, m * k, unaryOp_lambda, stream); + if (x != y) { + raft::linalg::unaryOp( + (DataT*) y, y, n * k, unaryOp_lambda, stream); + } CUDA_CHECK(cudaGetLastError()); } From a71e5202455637f73d9786b94bb1b0bf06508fe3 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 17 Jun 2021 18:32:18 +0530 Subject: [PATCH 29/29] fix clang format issues --- cpp/include/raft/distance/hellinger.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/hellinger.cuh index 5ee8e65d9b..f7ad3ed1ba 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/hellinger.cuh @@ -69,11 +69,11 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, }; // First sqrt x and y raft::linalg::unaryOp( - (DataT*) x, x, m * k, unaryOp_lambda, stream); + (DataT *)x, x, m * k, unaryOp_lambda, stream); if (x != y) { raft::linalg::unaryOp( - (DataT*) y, y, n * k, unaryOp_lambda, stream); + (DataT *)y, y, n * k, unaryOp_lambda, stream); } // Accumulation operation lambda @@ -125,10 +125,10 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // Revert sqrt of x and y raft::linalg::unaryOp( - (DataT*) x, x, m * k, unaryOp_lambda, stream); + (DataT *)x, x, m * k, unaryOp_lambda, stream); if (x != y) { raft::linalg::unaryOp( - (DataT*) y, y, n * k, unaryOp_lambda, stream); + (DataT *)y, y, n * k, unaryOp_lambda, stream); } CUDA_CHECK(cudaGetLastError());