From e0855a51babdf84c846820bcb0cbfeed6b7ffa71 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Wed, 2 Jun 2021 14:51:09 -0500 Subject: [PATCH] Revert "Add Grid stride pairwise dist and fused L2 NN kernels (#232)" This reverts commit e8f1862e36072ff867f59ed3e38e8dcb7bb02fd3. --- cpp/include/raft/distance/cosine.cuh | 35 +- cpp/include/raft/distance/euclidean.cuh | 77 ++-- cpp/include/raft/distance/fused_l2_nn.cuh | 366 +++++++++++------- cpp/include/raft/distance/l1.cuh | 36 +- .../raft/distance/pairwise_distance_base.cuh | 163 ++------ cpp/include/raft/linalg/contractions.cuh | 8 +- cpp/test/distance/fused_l2_nn.cu | 2 + 7 files changed, 328 insertions(+), 359 deletions(-) diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/cosine.cuh index ed9bd28b7f..5a212ce64c 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/cosine.cuh @@ -61,6 +61,8 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; + dim3 grid(raft::ceildiv(m, KPolicy::Mblk), + raft::ceildiv(n, KPolicy::Nblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -71,8 +73,7 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn, IdxT gridStrideX, - IdxT gridStrideY) { + DataT * regxn, DataT * regyn) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll @@ -82,26 +83,20 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, } }; - constexpr size_t shmemSize = - KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { - auto cosineRowMajor = - pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); - cosineRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, - fin_op); + pairwiseDistanceMatKernel + <<>>(x, y, xn, yn, m, n, k, lda, + ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } else { - auto cosineColMajor = - pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); - cosineColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, - fin_op); + pairwiseDistanceMatKernel + <<>>(x, y, xn, yn, m, n, k, lda, + ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/euclidean.cuh index 484da0e5bf..f3f946ad7b 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/euclidean.cuh @@ -60,6 +60,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, typedef typename std::conditional::type KPolicy; + dim3 grid(raft::ceildiv(m, KPolicy::Mblk), + raft::ceildiv(n, KPolicy::Nblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -70,8 +72,7 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn, IdxT gridStrideX, - IdxT gridStrideY) { + DataT * regxn, DataT * regyn) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll @@ -90,29 +91,20 @@ void euclideanExpImpl(const DataT *x, 
const DataT *y, const DataT *xn, } }; - constexpr size_t shmemSize = - KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); if (isRowMajor) { - auto euclideanExpRowMajor = - pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); - - euclideanExpRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, - fin_op); + pairwiseDistanceMatKernel + <<>>(x, y, xn, yn, m, n, k, lda, + ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } else { - auto euclideanExpColMajor = - pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); - euclideanExpColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, - fin_op); + pairwiseDistanceMatKernel + <<>>(x, y, xn, yn, m, n, k, lda, + ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); @@ -237,7 +229,8 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; - + dim3 grid(raft::ceildiv(m, KPolicy::Mblk), + raft::ceildiv(n, KPolicy::Nblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -249,8 +242,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, // epilogue operation lambda for final value calculation auto epilog_lambda = [sqrt] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn, IdxT gridStrideX, - IdxT gridStrideY) { + DataT * regxn, DataT * regyn) { if (sqrt) { #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { @@ -263,28 +255,19 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, }; if (isRowMajor) { - auto euclideanUnExpRowMajor = - pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - euclideanUnExpRowMajor); - - euclideanUnExpRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); - + pairwiseDistanceMatKernel + <<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } else { - auto euclideanUnExpColMajor = - pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - euclideanUnExpColMajor); - - euclideanUnExpColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + pairwiseDistanceMatKernel + <<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index b96a536e38..000d856841 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -20,7 +20,6 @@ #include #include #include -#include #include namespace raft { @@ -69,81 +68,117 @@ struct MinReduceOp { DI void init(DataT* out, DataT maxVal) { *out = maxVal; } }; -template -__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { - redOp.init(min + tid, maxVal); - } -} +template > +struct FusedL2NN : public BaseClass { + private: + typedef Policy P; -// TODO: specialize this function for MinAndDistanceReduceOp -// with atomicCAS of 64 bit which will eliminate mutex and shfls -template -DI void updateReducedVal(int* 
mutex, OutT* min, KVPair* val, ReduceOpT red_op, - IdxT m, IdxT gridStrideY) { - const auto lid = threadIdx.x % raft::WarpSize; - const auto accrowid = threadIdx.x / P::AccThCols; - - // for now have first lane from each warp update a unique output row. This - // will resolve hang issues with pre-Volta architectures -#pragma unroll - for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { - if (lid == 0) { + const DataT* xn; + const DataT* yn; + OutT* min; + int* mutex; + + DataT *sxNorm, *syNorm; + cub::KeyValuePair* sRed; + + DataT maxVal; + + DataT acc[P::AccRowsPerTh][P::AccColsPerTh]; + + ReduceOpT redOp; + KVPReduceOpT pairRedOp; + +#if (ENABLE_MEMCPY_ASYNC == 1) + DataT zeros[P::Veclen]; + nvcuda::experimental::pipeline pipe; +#endif + + static const DataT Two = (DataT)2.0; + static constexpr size_t SizeAndAlign = P::Veclen * sizeof(DataT); + + public: + DI FusedL2NN(OutT* _min, const DataT* _x, const DataT* _y, const DataT* _xn, + const DataT* _yn, IdxT _m, IdxT _n, IdxT _k, char* _smem, + DataT _mv, int* _mut, ReduceOpT op, KVPReduceOpT pair_op) + : BaseClass(_x, _y, _m, _n, _k, _smem), + xn(_xn), + yn(_yn), + min(_min), + mutex(_mut), + sxNorm((DataT*)_smem), + syNorm(&(sxNorm[P::Mblk])), + sRed((cub::KeyValuePair*)_smem), + maxVal(_mv), + redOp(op), + pairRedOp(pair_op) { +#if (ENABLE_MEMCPY_ASYNC == 1) #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + j + i * P::AccThRows; - if (rid < m) { - auto value = val[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, value); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } - if (j < (raft::WarpSize / P::AccThCols) - 1) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto tmpkey = raft::shfl(val[i].key, (j + 1) * P::AccThCols); - auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols); - val[i] = {tmpkey, tmpvalue}; - } + for (int i = 0; i < P::Veclen; ++i) { + zeros[i] = BaseClass::Zero; } +#endif } -} -template -__global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( - OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, - IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, - KVPReduceOpT pairRedOp, CoreLambda core_op, FinalLambda fin_op) { - extern __shared__ char smem[]; + DI void run() { + prolog(); + loop(); + __syncthreads(); // so that we can safely reuse smem + epilog(); + } - typedef cub::KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; + private: + DI void prolog() { + this->ldgXY(0); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } + } + this->stsXY(); + __syncthreads(); + this->pageWr ^= 1; } - // epilogue operation lambda for final value calculation - auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, DataT * regyn, IdxT gridStrideX, - IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); + DI void loop() { + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(kidx); + accumulate(); // on the previous k-block + this->stsXY(); + __syncthreads(); + this->pageWr ^= 1; + this->pageRd ^= 1; + } + accumulate(); // last iteration + } + DI void epilog() { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = blockIdx.x * P::Mblk + i; + sxNorm[i] = 
idx < this->m ? xn[idx] : maxVal; + } + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = blockIdx.y * P::Nblk + i; + syNorm[i] = idx < this->n ? yn[idx] : maxVal; + } + __syncthreads(); + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + this->accrowid]; + } +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + this->acccolid]; + } #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + acc[i][j] = regxn[i] + regyn[j] - Two * acc[i][j]; } } if (Sqrt) { @@ -155,112 +190,175 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( } } } - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; + // reduce + cub::KeyValuePair val[P::AccRowsPerTh]; + auto lid = raft::laneId(); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) { + auto tmpkey = this->acccolid + j * P::AccThCols + blockIdx.y * P::Nblk; + cub::KeyValuePair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < this->n) val[i] = - pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } + pairRedOp(this->accrowid + i * P::AccThRows + blockIdx.x * P::Mblk, + tmp, val[i]); } - } - }; - - auto rowEpilog_lambda = [m, mutex, min, pairRedOp, redOp, &val, - maxVal] __device__(IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { + __syncthreads(); #pragma unroll for (int j = P::AccThCols / 2; j > 0; j >>= 1) { auto tmpkey = raft::shfl(val[i].key, lid + j); auto tmpvalue = raft::shfl(val[i].value, lid + j); - KVPair tmp = {tmpkey, tmpvalue}; + cub::KeyValuePair tmp = {tmpkey, tmpvalue}; val[i] = - pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + pairRedOp(this->accrowid + i * P::AccThRows + blockIdx.x * P::Mblk, + tmp, val[i]); + } + } + if (lid % P::AccThCols == 0) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + sRed[i * P::AccThCols + this->accrowid] = val[i]; } } + __syncthreads(); + updateResults(); + } - updateReducedVal(mutex, min, val, red_op, - m, gridStrideY); + /* + * todo: From Volta onwards see if "coalesced" atomicCAS approach as + * written below helps improve perf + * ``` + * auto tid = threadIdx.x; + * auto rid = IdxT(blockIdx.x) * P::Mblk + tid; + * if (rid < m) { + * auto val = sRed[i]; + * while (atomicCAS(mutex + rid, 0, 1) == 1) + * ; + * __threadfence(); + * redOp(rid, min + rid, val); + * __threadfence(); + * atomicCAS(mutex + rid, 1, 0); + * } + * ``` + */ + DI void updateResults() { + // for now have first lane from each warp update a unique output row. 
This + // will resolve hang issues with pre-Volta architectures + auto nWarps = blockDim.x / raft::WarpSize; + auto lid = raft::laneId(); + auto ridx = IdxT(blockIdx.x) * P::Mblk; + if (lid == 0) { + for (int i = threadIdx.x / raft::WarpSize; i < P::Mblk; i += nWarps) { + auto rid = ridx + i; + if (rid < this->m) { + auto val = sRed[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + redOp(rid, min + rid, val); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } - // reset the val array. + DI void accumulate() { #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; + for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { + this->ldsXY(ki); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { +#pragma unroll + for (int v = 0; v < P::Veclen; ++v) { + acc[i][j] += this->regx[i][v] * this->regy[j][v]; + } + } + } + } + } + +#if (ENABLE_MEMCPY_ASYNC == 1) + DI void ldgXY(IdxT kidx) { + auto koffset = kidx + this->scolid; + auto offset = + this->pageWr * P::SmemPage + this->srowid * P::SmemStride + this->scolid; + auto* saddrx = this->sx + offset; + for (int i = 0; i < P::LdgPerThX; ++i) { + auto* sax = saddrx + i * P::LdgRowsX * P::SmemStride; + auto* gax = this->x + i * P::LdgRowsX * this->k + koffset; + auto inside = + koffset < this->k && (this->xrowid + i * P::LdgRowsX) < this->m; + __pipeline_memcpy_async(sax, inside ? gax : nullptr, SizeAndAlign, + inside ? 0 : SizeAndAlign); + } + auto* saddry = this->sy + offset; + for (int i = 0; i < P::LdgPerThY; ++i) { + auto* say = saddry + i * P::LdgRowsY * P::SmemStride; + auto* gay = this->y + i * P::LdgRowsY * this->k + koffset; + auto inside = + koffset < this->k && (this->yrowid + i * P::LdgRowsY) < this->n; + __pipeline_memcpy_async(say, inside ? gay : nullptr, SizeAndAlign, + inside ? 
0 : SizeAndAlign); } - }; - - IdxT lda = k, ldb = k, ldd = n; - PairwiseDistances - obj(x, y, m, n, k, lda, ldb, ldd, xn, yn, nullptr, smem, core_op, - epilog_lambda, fin_op, rowEpilog_lambda); + pipe.commit(); + } + + DI void stsXY() { pipe.wait_prior<0>(); } +#endif // ENABLE_MEMCPY_ASYNC +}; // struct FusedL2NN + +template +__global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2NNkernel( + OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, + IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, + KVPReduceOpT pairRedOp) { + extern __shared__ char smem[]; + FusedL2NN obj( + min, x, y, xn, yn, m, n, k, smem, maxVal, mutex, redOp, pairRedOp); obj.run(); } +template +__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { + redOp.init(min + tid, maxVal); + } +} + template void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, bool initOutBuffer, cudaStream_t stream) { - typedef typename linalg::Policy4x4::Policy P; - - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef cub::KeyValuePair KVPair; - - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { - acc += x * y; - }; - + typedef typename linalg::Policy4x4::Policy Policy; + dim3 grid(raft::ceildiv(m, Policy::Mblk), + raft::ceildiv(n, Policy::Nblk)); + dim3 blk(Policy::Nthreads); + auto nblks = raft::ceildiv(m, Policy::Nthreads); + auto maxVal = std::numeric_limits::max(); CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel - <<>>(min, m, maxVal, redOp); + <<>>(min, m, maxVal, redOp); CUDA_CHECK(cudaGetLastError()); } - - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - - constexpr size_t shmemSize = - P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { - auto fusedL2NNSqrt = - fusedL2NNkernel; - dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, - core_lambda, fin_op); + fusedL2NNkernel + <<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp); } else { - auto fusedL2NN = - fusedL2NNkernel; - dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>(min, x, y, xn, yn, m, n, k, - maxVal, workspace, redOp, - pairRedOp, core_lambda, fin_op); + fusedL2NNkernel + <<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp); } - CUDA_CHECK(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/l1.cuh index 6ab084f041..ce4fbb33e3 100644 --- a/cpp/include/raft/distance/l1.cuh +++ b/cpp/include/raft/distance/l1.cuh @@ -53,6 +53,8 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, typedef typename std::conditional::type KPolicy; + dim3 grid(raft::ceildiv(m, KPolicy::Mblk), + raft::ceildiv(n, KPolicy::Nblk)); dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda @@ -64,30 +66,22 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, DataT * regyn, IdxT gridStrideX, - IdxT gridStrideY) { return; }; + DataT * regxn, DataT * regyn) { return; }; if (isRowMajor) { - auto l1RowMajor = - pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, KPolicy::SmemSize, l1RowMajor); - - l1RowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + pairwiseDistanceMatKernel + <<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } else { - auto l1ColMajor = - pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, KPolicy::SmemSize, l1ColMajor); - l1ColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + pairwiseDistanceMatKernel + <<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); } CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/pairwise_distance_base.cuh index 503397bac9..4e1605b887 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/pairwise_distance_base.cuh @@ -14,11 +14,8 @@ * limitations under the License. 
*/ #pragma once -#include -#include #include #include -#include namespace raft { namespace distance { @@ -56,12 +53,9 @@ namespace distance { * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda */ - template > struct PairwiseDistances : public BaseClass { @@ -69,13 +63,13 @@ struct PairwiseDistances : public BaseClass { typedef Policy P; const DataT* xn; const DataT* yn; - const DataT* const yBase; + DataT* sxNorm; + DataT* syNorm; OutT* dOutput; char* smem; CoreLambda core_op; EpilogueLambda epilog_op; FinalLambda fin_op; - rowEpilogueLambda rowEpilog_op; AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; @@ -85,95 +79,34 @@ struct PairwiseDistances : public BaseClass { IdxT _k, IdxT _lda, IdxT _ldb, IdxT _ldd, const DataT* _xn, const DataT* _yn, OutT* _dOutput, char* _smem, CoreLambda _core_op, - EpilogueLambda _epilog_op, FinalLambda _fin_op, - rowEpilogueLambda _rowEpilog_op) + EpilogueLambda _epilog_op, FinalLambda _fin_op) : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), + sxNorm((DataT*)_smem), + syNorm(&(sxNorm[P::Mblk])), xn(_xn), yn(_yn), - yBase(_y), dOutput(_dOutput), smem(_smem), core_op(_core_op), epilog_op(_epilog_op), - fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op) {} + fin_op(_fin_op) {} DI void run() { - for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m; - gridStrideY += P::Mblk * gridDim.y) { - for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { - prolog(gridStrideX, gridStrideY); - loop(); - epilog(gridStrideX, gridStrideY); - } - rowEpilog_op(gridStrideY); - } + prolog(); + loop(); + epilog(); } private: - DI void updateIndicesY() { - const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; - this->pageWr = 0; - this->pageRd = 0; - } - - DI void updateIndicesXY() { - const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; - this->pageWr = 0; - this->pageRd = 0; - } - - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) { - if (gridStrideX > blockIdx.x * P::Nblk) { - updateIndicesY(); - } else if (gridStrideY > blockIdx.y * P::Mblk) { - updateIndicesXY(); - } - - typedef TxN_t VecType; - VecType zeros; - zeros.fill(BaseClass::Zero); -#pragma unroll - for (int j = 0; j < P::LdgPerThX; ++j) { - zeros.store(&this->ldgDataX[j][0], 0); - } -#pragma unroll - for (int j = 0; j < P::LdgPerThY; ++j) { - zeros.store(&this->ldgDataY[j][0], 0); - } - + DI void prolog() { this->ldgXY(0); - #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - zeros.store(&this->regx[i][0], 0); #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { acc[i][j] = BaseClass::Zero; } } -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - zeros.store(&this->regy[j][0], 0); - } - this->stsXY(); __syncthreads(); this->pageWr ^= 1; @@ -208,24 +141,19 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT gridStrideX, IdxT gridStrideY) { + DI void epilog() { if (useNorms) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); + __syncthreads(); // so that we can safely reuse smem // Load x & y norms required by this 
threadblock in shmem buffer - if (gridStrideX == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = blockIdx.x * P::Mblk + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; } - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = gridStrideX + i; + auto idx = blockIdx.y * P::Nblk + i; syNorm[i] = idx < this->n ? yn[idx] : 0; } - __syncthreads(); DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; @@ -238,24 +166,21 @@ struct PairwiseDistances : public BaseClass { regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } - epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); + epilog_op(acc, regxn, regyn); } else { - epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + epilog_op(acc, nullptr, nullptr); } - if (writeOut) { - IdxT starty = gridStrideY + this->accrowid; - IdxT startx = gridStrideX + this->acccolid; - + IdxT startx = blockIdx.x * P::Mblk + this->accrowid; + IdxT starty = blockIdx.y * P::Nblk + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = startx + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - dOutput[rowId * this->n + colId] = fin_op(acc[i][j], 0); - } + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = starty + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + dOutput[rowId * this->n + colId] = fin_op(acc[i][j], 0); } } } @@ -292,11 +217,9 @@ struct PairwiseDistances : public BaseClass { * @param epilog_op the epilogue lambda * @param fin_op the final gemm epilogue lambda */ - template + typename EpilogueLambda, typename FinalLambda, bool isRowMajor = true> __global__ __launch_bounds__( Policy::Nthreads, 2) void pairwiseDistanceMatKernel(const DataT* x, const DataT* y, @@ -306,39 +229,13 @@ __global__ __launch_bounds__( EpilogueLambda epilog_op, FinalLambda fin_op) { extern __shared__ char smem[]; - auto rowEpilog = [] __device__(IdxT starty) { return; }; PairwiseDistances + EpilogueLambda, FinalLambda, isRowMajor> obj(x, y, m, n, k, lda, ldb, ldd, _xn, _yn, dOutput, smem, core_op, - epilog_op, fin_op, rowEpilog); + epilog_op, fin_op); obj.run(); } -template -dim3 launchConfigGenerator(IdxT m, IdxT n, size_t sMemSize, T func) { - const auto numSMs = raft::getMultiProcessorCount(); - int numBlocksPerSm = 0; - dim3 grid; - - CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &numBlocksPerSm, func, P::Nthreads, sMemSize)); - int minGridSize = numSMs * numBlocksPerSm; - int yChunks = raft::ceildiv(m, P::Mblk); - int xChunks = raft::ceildiv(n, P::Nblk); - grid.y = yChunks > minGridSize ? minGridSize : yChunks; - grid.x = (minGridSize - grid.y) <= 0 ? 1 : xChunks; - if (grid.x != 1) { - int i = 1; - while (grid.y * i < minGridSize) { - i++; - } - grid.x = i >= xChunks ? 
xChunks : i; - } - - return grid; -} - }; // namespace distance }; // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh index c590abb142..86d608ea87 100644 --- a/cpp/include/raft/linalg/contractions.cuh +++ b/cpp/include/raft/linalg/contractions.cuh @@ -293,13 +293,13 @@ struct Contractions_NT { pageWr(0), pageRd(0) { if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; + xrowid = IdxT(blockIdx.x) * P::Mblk + srowid; + yrowid = IdxT(blockIdx.y) * P::Nblk + srowid; x = _x + xrowid * lda; y = _y + yrowid * ldb; } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * P::Nblk; + xrowid = IdxT(blockIdx.x) * P::Mblk; + yrowid = IdxT(blockIdx.y) * P::Nblk; x = _x + xrowid + srowid * lda; y = _y + yrowid + srowid * ldb; } diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 4573a070b6..d4e39a0b5e 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -164,6 +164,7 @@ struct CompareApproxAbsKVP { typedef typename cub::KeyValuePair KVP; CompareApproxAbsKVP(T eps_) : eps(eps_) {} bool operator()(const KVP &a, const KVP &b) const { + if (a.key != b.key) return false; T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); T m = std::max(raft::abs(a.value), raft::abs(b.value)); T ratio = m >= eps ? diff / m : diff; @@ -178,6 +179,7 @@ template struct CompareExactKVP { typedef typename cub::KeyValuePair KVP; bool operator()(const KVP &a, const KVP &b) const { + if (a.key != b.key) return false; if (a.value != b.value) return false; return true; }
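
Note on the launch-configuration heuristic removed above: the deleted `launchConfigGenerator` sizes the grid from device occupancy rather than from the tile count, so the grid-stride kernels launch only as many blocks as keep the SMs busy and loop over the remaining tiles. A minimal standalone sketch of that heuristic follows; the `makeGridStrideGrid` name and the Mblk/Nblk/nThreads defaults are illustrative (the original pulls them from the policy type `P`):

```cuda
#include <cuda_runtime.h>

// Occupancy-driven grid sizing for a grid-stride tiled kernel, following the
// launchConfigGenerator logic removed by this revert. Mblk/Nblk/nThreads are
// illustrative defaults; the original takes them from the policy class P.
template <typename KernelT>
dim3 makeGridStrideGrid(int m, int n, size_t smemSize, KernelT kernel,
                        int Mblk = 64, int Nblk = 64, int nThreads = 256) {
  int devId = 0, numSMs = 0, blocksPerSm = 0;
  cudaGetDevice(&devId);
  cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSm, kernel,
                                                nThreads, smemSize);
  int minGridSize = numSMs * blocksPerSm;       // blocks needed to fill the GPU
  int yChunks = (m + Mblk - 1) / Mblk;          // row tiles
  int xChunks = (n + Nblk - 1) / Nblk;          // column tiles
  dim3 grid;
  grid.y = yChunks > minGridSize ? minGridSize : yChunks;
  grid.x = minGridSize <= (int)grid.y ? 1 : xChunks;
  if (grid.x != 1) {
    int i = 1;
    while ((int)grid.y * i < minGridSize) ++i;  // widen x only until occupancy is met
    grid.x = i >= xChunks ? xChunks : i;
  }
  return grid;  // the kernel itself grid-strides over any tiles beyond this grid
}
```

The two nested grid-stride loops in the removed `PairwiseDistances::run()` then walk `gridStrideY`/`gridStrideX` over the full tile space, so correctness does not depend on the launched grid covering every tile; the restored code instead launches one block per tile via `ceildiv(m, Mblk)` x `ceildiv(n, Nblk)`.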
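
The per-row commit in `updateResults` (and in the removed `updateReducedVal`) serializes concurrent writers with a spin lock built from `atomicCAS` and `__threadfence`. A self-contained device-side sketch of that locking pattern, using a hypothetical `commitRow` helper and a simple min reduction in place of the real `ReduceOpT`:

```cuda
#include <cuda_runtime.h>

// Illustrative stand-in for ReduceOpT: keep the smaller value per row.
__device__ void reduceMin(float* dst, float candidate) {
  if (candidate < *dst) *dst = candidate;
}

// Spin-lock protected per-row update, mirroring the atomicCAS/__threadfence
// pattern used by the fused L2 NN kernels. `minOut` holds one result per row
// and `rowLock` one zero-initialized int flag per row (hypothetical names).
__device__ void commitRow(int rid, float candidate, float* minOut, int* rowLock) {
  while (atomicCAS(rowLock + rid, 0, 1) == 1)  // acquire: spin until 0 -> 1 succeeds
    ;
  __threadfence();                    // order the lock acquire before the update
  reduceMin(minOut + rid, candidate); // apply the reduction while holding the lock
  __threadfence();                    // publish the update before releasing
  atomicCAS(rowLock + rid, 1, 0);     // release the lock
}
```

As the in-code todo notes, a coalesced variant in which every lane runs its own lock loop may help on Volta and later; both versions of the kernel keep the update on the first lane of each warp to avoid lock-step deadlocks on pre-Volta architectures.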