Add support in column major distance metrics to use contractions_nt instead of cutlass #3691

Merged: 25 commits, Apr 9, 2021

Commits (25)
c0ceb93  Remove cutlass usage in row major input for euclidean exp/unexp dista… (mdoijade, Mar 5, 2021)
c3fbbac  Remove cutlass usage in row major input for cosine distance matrix ca… (mdoijade, Mar 9, 2021)
707fdb0  Remove cutlass usage in row major input for L1 distance matrix calcul… (mdoijade, Mar 10, 2021)
b3a56a3  fix errors reported by clang-format and copyright checker in cosine a… (mdoijade, Mar 16, 2021)
3ff0707  add review changes - move kernel to base class, move common epilog co… (mdoijade, Mar 18, 2021)
755da05  Merge branch 'branch-0.19' into branch-0.19 (mdoijade, Mar 19, 2021)
a33fc96  add useNorms template arg instead of distanceType to select if norms … (mdoijade, Mar 23, 2021)
1e841ae  Merge branch 'branch-0.19' of https://github.com/mdoijade/cuml into b… (mdoijade, Mar 23, 2021)
559a418  fix clang-format reported errors and make RAFT git tag to be ToT of c… (mdoijade, Mar 23, 2021)
03426d7  Merge branch 'branch-0.19' of https://github.com/rapidsai/cuml into b… (mdoijade, Mar 24, 2021)
707ed1a  Add column major contraction_nt kernels usage to all pairwise distanc… (mdoijade, Mar 30, 2021)
cc2b9b6  fix incorrect veclen for column major by setting bytes based on lda, … (mdoijade, Mar 31, 2021)
8361492  Remove cutlass outputile from distance metrics template arg, grammatr… (mdoijade, Mar 31, 2021)
b4af5f3  Remove cutlass outputile from prims benchmark and sg tests which got … (mdoijade, Mar 31, 2021)
2d3dae9  Add column/row major input support to distance prims benchmark (mdoijade, Mar 31, 2021)
4c39580  Merge branch 'branch-0.19' of https://github.com/rapidsai/cuml into b… (mdoijade, Apr 1, 2021)
4d1c035  Merge branch-0.19 latest changes (mdoijade, Apr 1, 2021)
cd55b23  fix clang format issues with distnace_common.cuh (mdoijade, Apr 1, 2021)
af99d76  fix copyright year issues on modified files (mdoijade, Apr 1, 2021)
5168bb8  Add contractions_nt col major RAFT commit in dependencies (mdoijade, Apr 5, 2021)
6d173f3  Remove redundant new line in dependencies.cmake (mdoijade, Apr 5, 2021)
81d5557  Merge branch-0.19 into column_major_distance_prims (mdoijade, Apr 5, 2021)
dbdbaa3  Merge branch 'branch-0.20' of https://github.com/rapidsai/cuml into c… (mdoijade, Apr 6, 2021)
09185ea  Merge branch 'branch-0.20' into column_major_distance_prims (mdoijade, Apr 7, 2021)
ff8c923  Merge branch 'branch-0.20' into column_major_distance_prims (mdoijade, Apr 8, 2021)
34 changes: 23 additions & 11 deletions cpp/bench/prims/distance_common.cuh

@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2021, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ namespace Distance {

 struct Params {
   int m, n, k;
+  bool isRowMajor;
 };  // struct Params

 template <typename T, raft::distance::DistanceType DType>
@@ -54,11 +55,10 @@ struct Distance : public Fixture {
   }

   void runBenchmark(::benchmark::State& state) override {
-    typedef cutlass::Shape<8, 128, 128> OutputTile_t;
     loopOnState(state, [this]() {
-      MLCommon::Distance::distance<DType, T, T, T, OutputTile_t>(
+      MLCommon::Distance::distance<DType, T, T, T>(
         x, y, out, params.m, params.n, params.k, (void*)workspace, worksize,
-        stream);
+        stream, params.isRowMajor);
     });
   }

@@ -71,13 +71,25 @@

   static std::vector<Params> getInputs() {
     return {
-      {32, 16384, 16384}, {64, 16384, 16384}, {128, 16384, 16384},
-      {256, 16384, 16384}, {512, 16384, 16384}, {1024, 16384, 16384},
-      {16384, 32, 16384}, {16384, 64, 16384}, {16384, 128, 16384},
-      {16384, 256, 16384}, {16384, 512, 16384}, {16384, 1024, 16384},
-      {16384, 16384, 32}, {16384, 16384, 64}, {16384, 16384, 128},
-      {16384, 16384, 256}, {16384, 16384, 512}, {16384, 16384, 1024},
-      {16384, 16384, 16384},
+      {32, 16384, 16384, true},     {64, 16384, 16384, true},
+      {128, 16384, 16384, true},    {256, 16384, 16384, true},
+      {512, 16384, 16384, true},    {1024, 16384, 16384, true},
+      {16384, 32, 16384, true},     {16384, 64, 16384, true},
+      {16384, 128, 16384, true},    {16384, 256, 16384, true},
+      {16384, 512, 16384, true},    {16384, 1024, 16384, true},
+      {16384, 16384, 32, true},     {16384, 16384, 64, true},
+      {16384, 16384, 128, true},    {16384, 16384, 256, true},
+      {16384, 16384, 512, true},    {16384, 16384, 1024, true},
+      {16384, 16384, 16384, true},  {32, 16384, 16384, false},
+      {64, 16384, 16384, false},    {128, 16384, 16384, false},
+      {256, 16384, 16384, false},   {512, 16384, 16384, false},
+      {1024, 16384, 16384, false},  {16384, 32, 16384, false},
+      {16384, 64, 16384, false},    {16384, 128, 16384, false},
+      {16384, 256, 16384, false},   {16384, 512, 16384, false},
+      {16384, 1024, 16384, false},  {16384, 16384, 32, false},
+      {16384, 16384, 64, false},    {16384, 16384, 128, false},
+      {16384, 16384, 256, false},   {16384, 16384, 512, false},
+      {16384, 16384, 1024, false},  {16384, 16384, 16384, false},
     };
   }
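The expanded getInputs() table is exactly the previous shape list crossed with both layouts. As an illustration only (this is not code from the PR), the same grid can be generated rather than enumerated; the standalone Params mirror and the helper name makeInputs below are hypothetical:

#include <vector>

// Mirrors the benchmark's Params after this PR: problem shape plus layout flag.
struct Params {
  int m, n, k;
  bool isRowMajor;
};

// Hypothetical helper: build the same benchmark grid as getInputs() by
// crossing the size list with both layouts instead of listing 38 literals.
static std::vector<Params> makeInputs() {
  const int sizes[][3] = {
    {32, 16384, 16384},  {64, 16384, 16384},  {128, 16384, 16384},
    {256, 16384, 16384}, {512, 16384, 16384}, {1024, 16384, 16384},
    {16384, 32, 16384},  {16384, 64, 16384},  {16384, 128, 16384},
    {16384, 256, 16384}, {16384, 512, 16384}, {16384, 1024, 16384},
    {16384, 16384, 32},  {16384, 16384, 64},  {16384, 16384, 128},
    {16384, 16384, 256}, {16384, 16384, 512}, {16384, 16384, 1024},
    {16384, 16384, 16384},
  };
  std::vector<Params> out;
  for (bool rowMajor : {true, false}) {          // row-major cases first
    for (const auto& s : sizes) {
      out.push_back({s[0], s[1], s[2], rowMajor});
    }
  }
  return out;
}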
172 changes: 66 additions & 106 deletions cpp/src_prims/distance/cosine.cuh

@@ -15,21 +15,9 @@
 */

 #pragma once
-#include <linalg/eltwise2d.cuh>
-#include "distance_fragment_multiply_add.cuh"
+#include "pairwise_distance_base.cuh"

-#include <linalg/cutlass_gemm.cuh>
 #include <raft/linalg/norm.cuh>
-#include "distance_epilogue.cuh"
-#include "distance_epilogue_functor.cuh"
-#include "distance_epilogue_traits.h"
-
-#include <cutlass/gemm/gemm_epilogue_traits.h>
-#include <cutlass/gemm/thread_multiply_add.h>
-#include <cutlass/shape.h>
-
-#include <type_traits>
-#include "pairwise_distance_base.cuh"

 namespace MLCommon {
 namespace Distance {
@@ -45,26 +33,37 @@ namespace Distance {
 * @tparam Veclen number of k-elements loaded by each thread for every LDG call
 * it makes. check contractions.cuh for details.
 * @tparam FinalLambda the final lambda called on final distance value
+ * @tparam isRowMajor true if input/output is row major,
+                      false for column major
 * @param[in] x input matrix
 * @param[in] y input matrix
 * @param[in] xn row norms of input matrix A.
 * @param[in] yn row norms of input matrix B.
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of columns of B and C/D
 * @param[in] k number of cols of A and rows of B
+ * @param[in] lda leading dimension of A
+ * @param[in] ldb leading dimension of B
+ * @param[in] ldd leading dimension of C/D
 * @param[output] pD output matrix
 * @param fin_op the final gemm epilogue lambda
 * @param stream cuda stream to launch cuda operations.
 */
 template <typename DataT, typename AccT, typename OutT, typename IdxT,
-          int VecLen, typename FinalLambda>
+          int VecLen, typename FinalLambda, bool isRowMajor>
 void cosineImpl(const DataT *x, const DataT *y, const DataT *xn,
-                const DataT *yn, IdxT m, IdxT n, IdxT k, OutT *dOutput,
-                FinalLambda fin_op, cudaStream_t stream) {
-  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy Policy;
-  dim3 grid(raft::ceildiv<int>(m, Policy::Mblk),
-            raft::ceildiv<int>(n, Policy::Nblk));
-  dim3 blk(Policy::Nthreads);
+                const DataT *yn, IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb,
+                IdxT ldd, OutT *dOutput, FinalLambda fin_op,
+                cudaStream_t stream) {
+  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
+  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;
+
+  typedef
+    typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;
+
+  dim3 grid(raft::ceildiv<int>(m, KPolicy::Mblk),
+            raft::ceildiv<int>(n, KPolicy::Nblk));
+  dim3 blk(KPolicy::Nthreads);

   // Accumulation operation lambda
   auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) {
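The layout switch above costs nothing at run time: std::conditional resolves the tiling policy type at compile time, so each instantiation of cosineImpl carries only its own row- or column-major configuration. A minimal sketch of the pattern, with stand-in policy structs rather than RAFT's actual Policy4x4 (the Mblk/Nblk numbers are made up):

#include <cstdio>
#include <type_traits>

// Stand-in tiling policies; RAFT's real Policy4x4 carries many more members.
struct RowPolicy { static constexpr int Mblk = 32, Nblk = 128; };
struct ColPolicy { static constexpr int Mblk = 128, Nblk = 32; };

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

template <bool isRowMajor>
void printGrid(int m, int n) {
  // Compile-time selection: each instantiation contains only one policy.
  using KPolicy =
    typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type;
  std::printf("grid = %d x %d\n", ceildiv(m, KPolicy::Mblk),
              ceildiv(n, KPolicy::Nblk));
}

int main() {
  printGrid<true>(1000, 1000);   // row-major tiling
  printGrid<false>(1000, 1000);  // column-major tiling
}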
@@ -73,41 +72,54 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn,

   // epilogue operation lambda for final value calculation
   auto epilog_lambda = [] __device__(
-                         AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh],
+                         AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
                          DataT * regxn, DataT * regyn) {
 #pragma unroll
-    for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
+    for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) {
 #pragma unroll
-      for (int j = 0; j < Policy::AccColsPerTh; ++j) {
+      for (int j = 0; j < KPolicy::AccColsPerTh; ++j) {
         acc[i][j] = acc[i][j] / (regxn[i] * regyn[j]);
       }
     }
   };

-  pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, Policy,
-                            decltype(core_lambda), decltype(epilog_lambda),
-                            FinalLambda>
-    <<<grid, blk, Policy::SmemSize, stream>>>(
-      x, y, xn, yn, m, n, k, dOutput, core_lambda, epilog_lambda, fin_op);
+  if (isRowMajor) {
+    pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
+                              decltype(core_lambda), decltype(epilog_lambda),
+                              FinalLambda, true>
+      <<<grid, blk, KPolicy::SmemSize, stream>>>(x, y, xn, yn, m, n, k, lda,
+                                                 ldb, ldd, dOutput, core_lambda,
+                                                 epilog_lambda, fin_op);
+  } else {
+    pairwiseDistanceMatKernel<true, DataT, AccT, OutT, IdxT, KPolicy,
+                              decltype(core_lambda), decltype(epilog_lambda),
+                              FinalLambda, false>
+      <<<grid, blk, KPolicy::SmemSize, stream>>>(x, y, xn, yn, m, n, k, lda,
+                                                 ldb, ldd, dOutput, core_lambda,
+                                                 epilog_lambda, fin_op);
+  }

   CUDA_CHECK(cudaGetLastError());
 }
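The epilogue lambda normalizes the accumulated dot products in registers, using the row norms precomputed into regxn and regyn. For row i of X and row j of Y this computes

\[
  \mathrm{acc}_{ij} \;\leftarrow\;
  \frac{\sum_{t=1}^{k} x_{it}\, y_{jt}}
       {\lVert x_i \rVert_2 \, \lVert y_j \rVert_2}
  \;=\; \cos\theta(x_i, y_j),
\]

i.e. the cosine similarity; any further mapping (for instance 1 - cos(theta) to obtain a distance) is left to fin_op, the user-supplied final lambda.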

 template <typename DataT, typename AccT, typename OutT, typename IdxT,
-          typename FinalLambda>
-void cosine(IdxT m, IdxT n, IdxT k, const DataT *x, const DataT *y,
-            const DataT *xn, const DataT *yn, OutT *dOutput, FinalLambda fin_op,
-            cudaStream_t stream) {
-  size_t bytes = sizeof(DataT) * k;
-  if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) {
-    cosineImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), FinalLambda>(
-      x, y, xn, yn, m, n, k, dOutput, fin_op, stream);
-  } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) {
-    cosineImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), FinalLambda>(
-      x, y, xn, yn, m, n, k, dOutput, fin_op, stream);
+          typename FinalLambda, bool isRowMajor>
+void cosine(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
+            const DataT *x, const DataT *y, const DataT *xn, const DataT *yn,
+            OutT *dOutput, FinalLambda fin_op, cudaStream_t stream) {
+  size_t bytesA = sizeof(DataT) * lda;
+  size_t bytesB = sizeof(DataT) * ldb;
+  if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) {
+    cosineImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), FinalLambda,
+               isRowMajor>(x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput,
+                           fin_op, stream);
+  } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) {
+    cosineImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), FinalLambda,
+               isRowMajor>(x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput,
+                           fin_op, stream);
   } else {
-    cosineImpl<DataT, AccT, OutT, IdxT, 1, FinalLambda>(
-      x, y, xn, yn, m, n, k, dOutput, fin_op, stream);
+    cosineImpl<DataT, AccT, OutT, IdxT, 1, FinalLambda, isRowMajor>(
+      x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream);
   }
 }
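This is where the commit "fix incorrect veclen for column major by setting bytes based on lda, …" lands: the vectorized-load width is now validated against the leading dimensions of both inputs rather than against k, since each new row (or column) starts lda (or ldb) elements after the previous one, and an unaligned leading dimension would break 16- or 8-byte loads mid-matrix. A self-contained sketch of the selection rule (pickVecLen is an illustrative name, not a cuml function):

#include <cstddef>
#include <cstdio>

// Illustrative helper: choose how many elements each LDG should load so that
// 16- or 8-byte vectorized accesses stay aligned for BOTH inputs. Alignment
// is checked against the leading dimensions, because every row (or column)
// begins lda (or ldb) elements after the previous one.
template <typename DataT>
int pickVecLen(std::size_t lda, std::size_t ldb) {
  std::size_t bytesA = sizeof(DataT) * lda;
  std::size_t bytesB = sizeof(DataT) * ldb;
  if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0)
    return static_cast<int>(16 / sizeof(DataT));  // e.g. float4 loads
  if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0)
    return static_cast<int>(8 / sizeof(DataT));   // e.g. float2 loads
  return 1;                                       // scalar fallback
}

int main() {
  // float, lda = 1024: 4096 bytes per line, float4 loads are safe.
  std::printf("veclen = %d\n", pickVecLen<float>(1024, 1024));
  // float, lda = 1023: misaligned lines, falls back to scalar loads.
  std::printf("veclen = %d\n", pickVecLen<float>(1023, 1023));
}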

@@ -135,7 +147,7 @@
 * @param isRowMajor whether the input and output matrices are row major
 */
 template <typename InType, typename AccType, typename OutType,
-          typename OutputTile_, typename FinalLambda, typename Index_ = int>
+          typename FinalLambda, typename Index_ = int>
 void cosineAlgo1(Index_ m, Index_ n, Index_ k, const InType *pA,
                  const InType *pB, OutType *pD, AccType *workspace,
                  size_t worksize, FinalLambda fin_op, cudaStream_t stream,
@@ -148,12 +160,16 @@ void cosineAlgo1(Index_ m, Index_ n, Index_ k, const InType *pA,
   };

   typedef std::is_same<OutType, bool> is_bool;
+  typedef typename std::conditional<is_bool::value, OutType, AccType>::type
+    CosOutType;
+  CosOutType *pDcast = reinterpret_cast<CosOutType *>(pD);

   ASSERT(!(((pA != pB) && (worksize < (m + n) * sizeof(AccType))) ||
            (worksize < m * sizeof(AccType))),
          "workspace size error");
   ASSERT(workspace != nullptr, "workspace is null");

+  Index_ lda, ldb, ldd;
   InType *col_vec = workspace;
   InType *row_vec = workspace;
   if (pA != pB) {
@@ -168,71 +184,15 @@ void cosineAlgo1(Index_ m, Index_ n, Index_ k, const InType *pA,
   }

   if (isRowMajor) {
-    typedef typename std::conditional<is_bool::value, OutType, AccType>::type
-      CosOutType;
-
-    cosine<InType, AccType, CosOutType, Index_, decltype(wrapped_fin_op)>(
-      m, n, k, pA, pB, col_vec, row_vec, reinterpret_cast<CosOutType *>(pD),
-      wrapped_fin_op, stream);
+    lda = k, ldb = k, ldd = n;
+    cosine<InType, AccType, CosOutType, Index_, decltype(wrapped_fin_op), true>(
+      m, n, k, lda, ldb, ldd, pA, pB, col_vec, row_vec, pDcast, wrapped_fin_op,
+      stream);
   } else {
-    typedef ExpandedDistanceFragmentMultiplyAdd<CosFusedDistance>
-      FragmentMultiplyAdd_;
-    typedef typename std::conditional<is_bool::value, AccType, OutType>::type
-      EffOutType;
-    EffOutType *pDCast =
-      reinterpret_cast<EffOutType *>(pD);  // Pretend to be EffOutType;
-    typedef typename cutlass::Shape<8, 8, 8> AccumulatorsPerThread_;
-    typedef cutlass::gemm::ThreadMultiplyAdd<
-      AccumulatorsPerThread_, cutlass::Shape<1, 4, 8>, InType, InType, AccType>
-      MainLoopFunctor_;
-    typedef LinAlg::CustomGemmConfig<InType, AccType, EffOutType, OutputTile_,
-                                     AccumulatorsPerThread_, MainLoopFunctor_>
-      GemmConfig_;
-
-    typedef ExpandedDistanceEpilogueFunctor<InType, AccType, GemmConfig_,
-                                            FragmentMultiplyAdd_>
-      EpilogueFunctor_;
-
-    typedef typename std::conditional<
-      is_bool::value,
-      BoolEpilogueTraitsHelper<GemmConfig_, EpilogueFunctor_, Index_>,
-      cutlass::gemm::GemmEpilogueTraitsHelper<
-        GemmConfig_, EpilogueFunctor_, Index_>>::type EpilogueTraitsHelper_;
-
-    typedef typename cutlass::gemm::SimplifiedGemmEpilogueTraits<
-      GemmConfig_, EpilogueFunctor_, Index_, EpilogueTraitsHelper_>
-      GemmEpilogueTraits_;
-    typedef ExpandedDistanceGemmEpilogue<GemmEpilogueTraits_> GemmEpilogue_;
-    typedef typename EpilogueFunctor_::Params EpiParams;
-
-    cublasOperation_t transa, transb;
-    const InType *aPtr, *bPtr;
-    Index_ lda, ldb, ldd;
-    Index_ gemm_m, gemm_n;
-    InType *rvec, *cvec;
-
-    transa = CUBLAS_OP_N;
-    transb = CUBLAS_OP_T;
-    aPtr = pA;
-    bPtr = pB;
-    lda = m;
-    ldb = n;
-    ldd = m;
-    gemm_m = m;
-    gemm_n = n;
-    cvec = row_vec;
-    rvec = col_vec;
-
-    LinAlg::gemm<InType, AccType, EffOutType, OutputTile_,
-                 AccumulatorsPerThread_, MainLoopFunctor_, Index_, GemmConfig_,
-                 EpilogueFunctor_, GemmEpilogueTraits_, GemmEpilogue_>(
-      transa, transb, gemm_m, gemm_n, k, (EffOutType)1, aPtr, lda, bPtr, ldb,
-      (EffOutType)0, nullptr, ldd, pDCast,
-      [cvec, rvec] HD(EpiParams & p) {
-        int err = p.initializeExtra(cvec, rvec, false);
-        return err;
-      },
-      wrapped_fin_op, stream);
+    lda = n, ldb = m, ldd = m;
+    cosine<InType, AccType, CosOutType, Index_, decltype(wrapped_fin_op),
+           false>(n, m, k, lda, ldb, ldd, pB, pA, row_vec, col_vec, pDcast,
+                  wrapped_fin_op, stream);
   }
 }
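The column-major branch shrinks from roughly sixty lines of CUTLASS GEMM plumbing to a second call into the same contraction kernels: it instantiates cosine with isRowMajor = false, swaps pA/pB along with the two norm vectors, and poses the transposed n x m problem with ldd = m. This works because an m x n matrix stored column-major occupies exactly the same buffer as its n x m transpose stored row-major, as this small standalone sketch demonstrates:

#include <cstdio>

// An m x n matrix stored column-major with leading dimension m occupies the
// same buffer as its n x m transpose stored row-major: element (i, j) of the
// column-major view and element (j, i) of the row-major view share the
// offset i + j * m.
int main() {
  const int m = 3, n = 2;
  float buf[m * n];

  // Writer's view: row-major n x m matrix T with T(j, i) = 10*j + i.
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) buf[j * m + i] = 10.0f * j + i;

  // Reader's view: column-major m x n matrix D, D(i, j) = buf[i + j * m].
  // Each entry reads back as T(j, i), so D == T^T with no copy or transpose.
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) std::printf("%5.1f", buf[i + j * m]);
    std::printf("\n");
  }
  return 0;
}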
