Skip to content

Commit

Permalink
Add Grid stride pairwise dist and fused L2 NN kernels (#232)
Browse files Browse the repository at this point in the history
This PR addresses issues mentioned in #221
-- Adds grid stride based fusedL2NN kernel, this gives approx 1.85x speed up over previous version of this kernel.
-- Adds support in pairwise dist base class to work for any input size by adding support for grid stride based work distribution.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Thejaswi. N. S (https://github.com/teju85)
  - Divye Gala (https://github.com/divyegala)
  - Alex Fender (https://github.com/afender)

URL: #232
  • Loading branch information
mdoijade authored Jun 2, 2021
1 parent 00c0401 commit e8f1862
Show file tree
Hide file tree
Showing 7 changed files with 359 additions and 328 deletions.
35 changes: 20 additions & 15 deletions cpp/include/raft/distance/cosine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn,
typedef
typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;

dim3 grid(raft::ceildiv<int>(m, KPolicy::Mblk),
raft::ceildiv<int>(n, KPolicy::Nblk));
dim3 blk(KPolicy::Nthreads);

// Accumulation operation lambda
Expand All @@ -73,7 +71,8 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn,
// epilogue operation lambda for final value calculation
auto epilog_lambda = [] __device__(
AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
DataT * regxn, DataT * regyn) {
DataT * regxn, DataT * regyn, IdxT gridStrideX,
IdxT gridStrideY) {
#pragma unroll
for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
#pragma unroll
Expand All @@ -83,20 +82,26 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn,
}
};

constexpr size_t shmemSize =
KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT));
if (isRowMajor) {
pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, true>
<<<grid, blk, KPolicy::SmemSize, stream>>>(x, y, xn, yn, m, n, k, lda,
ldb, ldd, dOutput, core_lambda,
epilog_lambda, fin_op);
auto cosineRowMajor =
pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, true>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineRowMajor);
cosineRowMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda,
fin_op);
} else {
pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, false>
<<<grid, blk, KPolicy::SmemSize, stream>>>(x, y, xn, yn, m, n, k, lda,
ldb, ldd, dOutput, core_lambda,
epilog_lambda, fin_op);
auto cosineColMajor =
pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, false>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, shmemSize, cosineColMajor);
cosineColMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda,
fin_op);
}

CUDA_CHECK(cudaGetLastError());
Expand Down
77 changes: 47 additions & 30 deletions cpp/include/raft/distance/euclidean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn,
typedef
typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;

dim3 grid(raft::ceildiv<int>(m, KPolicy::Mblk),
raft::ceildiv<int>(n, KPolicy::Nblk));
dim3 blk(KPolicy::Nthreads);

// Accumulation operation lambda
Expand All @@ -72,7 +70,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn,
// epilogue operation lambda for final value calculation
auto epilog_lambda = [sqrt] __device__(
AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
DataT * regxn, DataT * regyn) {
DataT * regxn, DataT * regyn, IdxT gridStrideX,
IdxT gridStrideY) {
#pragma unroll
for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
#pragma unroll
Expand All @@ -91,20 +90,29 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn,
}
};

constexpr size_t shmemSize =
KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT));
if (isRowMajor) {
pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, true>
<<<grid, blk, KPolicy::SmemSize, stream>>>(x, y, xn, yn, m, n, k, lda,
ldb, ldd, dOutput, core_lambda,
epilog_lambda, fin_op);
auto euclideanExpRowMajor =
pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, true>;
dim3 grid =
launchConfigGenerator<KPolicy>(m, n, shmemSize, euclideanExpRowMajor);

euclideanExpRowMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda,
fin_op);
} else {
pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, false>
<<<grid, blk, KPolicy::SmemSize, stream>>>(x, y, xn, yn, m, n, k, lda,
ldb, ldd, dOutput, core_lambda,
epilog_lambda, fin_op);
auto euclideanExpColMajor =
pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, false>;
dim3 grid =
launchConfigGenerator<KPolicy>(m, n, shmemSize, euclideanExpColMajor);
euclideanExpColMajor<<<grid, blk, shmemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda,
fin_op);
}

CUDA_CHECK(cudaGetLastError());
Expand Down Expand Up @@ -229,8 +237,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,

typedef
typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;
dim3 grid(raft::ceildiv<int>(m, KPolicy::Mblk),
raft::ceildiv<int>(n, KPolicy::Nblk));

dim3 blk(KPolicy::Nthreads);

// Accumulation operation lambda
Expand All @@ -242,7 +249,8 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
// epilogue operation lambda for final value calculation
auto epilog_lambda = [sqrt] __device__(
AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
DataT * regxn, DataT * regyn) {
DataT * regxn, DataT * regyn, IdxT gridStrideX,
IdxT gridStrideY) {
if (sqrt) {
#pragma unroll
for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
Expand All @@ -255,19 +263,28 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
};

if (isRowMajor) {
pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda>
<<<grid, blk, KPolicy::SmemSize, stream>>>(
x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
epilog_lambda, fin_op);
auto euclideanUnExpRowMajor =
pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, true>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize,
euclideanUnExpRowMajor);

euclideanUnExpRowMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
epilog_lambda, fin_op);

} else {
pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, isRowMajor>
<<<grid, blk, KPolicy::SmemSize, stream>>>(
x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
epilog_lambda, fin_op);
auto euclideanUnExpColMajor =
pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
decltype(core_lambda), decltype(epilog_lambda),
FinalLambda, false>;
dim3 grid = launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize,
euclideanUnExpColMajor);

euclideanUnExpColMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
epilog_lambda, fin_op);
}

CUDA_CHECK(cudaGetLastError());
Expand Down
Loading

0 comments on commit e8f1862

Please sign in to comment.