-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[REVIEW] Add chebyshev, canberra, minkowski and hellinger distance metrics #276
Changes from 37 commits
3a4ec66
76f9a72
af89085
9c71c4a
4d76b57
da2d768
4ada29e
2e804c2
3408a40
69b316d
6a64b7a
9a30a87
14a2673
21577a4
969c65a
9c4d5a0
df4ce55
4fb00e6
5346232
b5b3c51
0a0f964
2fd7f4c
0f2c03d
484b082
04f656f
753f612
f73471c
8007d7a
45dc556
d0b8947
1c6ee73
1c65ab7
8323a3c
a33546d
d4925b4
7673074
476ed99
59e78e8
7af5e32
0e99113
f4b8d33
a71e520
842fcd0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
#include <raft/distance/pairwise_distance_base.cuh> | ||
|
||
namespace raft { | ||
namespace distance { | ||
|
||
/**
 * @brief the canberra distance matrix calculation implementer
 *  It computes the following equation:
 *    cij = sum_k( |x_ik - y_kj| / (|x_ik| + |y_kj|) )
 * @tparam DataT input data-type (for A and B matrices)
 * @tparam AccT accumulation data-type
 * @tparam OutT output data-type (for C and D matrices)
 * @tparam IdxT index data-type
 * @tparam VecLen number of k-elements loaded per vectorized access
 * @tparam FinalLambda final lambda called on final distance value
 * @tparam isRowMajor true if input/output is row major,
 *                    false for column major
 * @param[in] x input matrix
 * @param[in] y input matrix
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of columns of B and C/D
 * @param[in] k number of cols of A and rows of B
 * @param[in] lda leading dimension of A
 * @param[in] ldb leading dimension of B
 * @param[in] ldd leading dimension of C/D
 * @param[out] dOutput output matrix
 * @param[in] fin_op the final gemm epilogue lambda
 * @param[in] stream cuda stream where to launch work
 */
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          int VecLen, typename FinalLambda, bool isRowMajor>
static void canberraImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
                         IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput,
                         FinalLambda fin_op, cudaStream_t stream) {
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;

  // select the tiling policy that matches the memory layout
  typedef
    typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;

  dim3 blk(KPolicy::Nthreads);

  // Accumulation operation lambda: adds one canberra term |x-y| / (|x|+|y|)
  auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) {
    const auto diff = raft::L1Op<AccT, IdxT>()(x - y);
    const auto add = raft::myAbs(x) + raft::myAbs(y);
    // deal with potential for 0 in denominator by
    // forcing 0/1 instead (numerator is zeroed and denominator set to 1
    // whenever add == 0)
    acc += ((add != 0) * diff / (add + (add == 0)));
  };

  // epilogue operation lambda for final value calculation
  // (no-op: the accumulated sum already is the canberra distance)
  auto epilog_lambda = [] __device__(
                         AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
                         DataT * regxn, DataT * regyn, IdxT gridStrideX,
                         IdxT gridStrideY) { return; };

  if (isRowMajor) {
    auto canberraRowMajor =
      pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
                                decltype(core_lambda), decltype(epilog_lambda),
                                FinalLambda, true>;
    dim3 grid =
      launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize, canberraRowMajor);

    canberraRowMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
      x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
      epilog_lambda, fin_op);
  } else {
    auto canberraColMajor =
      pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
                                decltype(core_lambda), decltype(epilog_lambda),
                                FinalLambda, false>;
    dim3 grid =
      launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize, canberraColMajor);
    canberraColMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
      x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
      epilog_lambda, fin_op);
  }

  // catches launch-configuration errors; async execution errors surface at
  // the next synchronizing call
  CUDA_CHECK(cudaGetLastError());
}
|
||
// Dispatches canberraImpl with the widest vector load (16, 8 or
// sizeof(DataT) bytes per access) that both the element size and the byte
// strides of the two input matrices allow.
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          typename FinalLambda, bool isRowMajor>
void canberra(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
              const DataT *x, const DataT *y, OutT *dOutput, FinalLambda fin_op,
              cudaStream_t stream) {
  // row byte-strides decide whether vectorized loads stay aligned
  size_t bytesA = sizeof(DataT) * lda;
  size_t bytesB = sizeof(DataT) * ldb;
  if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) {
    canberraImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), FinalLambda,
                 isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op,
                             stream);
  } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) {
    canberraImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), FinalLambda,
                 isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op,
                             stream);
  } else {
    // fallback: scalar (one element per access) loads
    canberraImpl<DataT, AccT, OutT, IdxT, 1, FinalLambda, isRowMajor>(
      x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream);
  }
}
|
||
/**
 * @brief the canberra distance matrix calculation
 *  It computes the following equation:
 *    cij = sum_k( |a_ik - b_kj| / (|a_ik| + |b_kj|) )
 * @tparam InType input data-type (for A and B matrices)
 * @tparam AccType accumulation data-type
 * @tparam OutType output data-type (for C and D matrices)
 * @tparam FinalLambda user-defined epilogue lamba
 * @tparam Index_ Index type
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of columns of B and C/D
 * @param[in] k number of cols of A and rows of B
 * @param[in] pA input matrix
 * @param[in] pB input matrix
 * @param[out] pD output matrix
 * @param[in] fin_op the final element-wise epilogue lambda
 * @param[in] stream cuda stream where to launch work
 * @param[in] isRowMajor whether the input and output matrices are row major
 */
template <typename InType, typename AccType, typename OutType,
          typename FinalLambda, typename Index_ = int>
void canberraImpl(int m, int n, int k, const InType *pA, const InType *pB,
                  OutType *pD, FinalLambda fin_op, cudaStream_t stream,
                  bool isRowMajor) {
  // bool outputs are written as-is; any other OutType accumulates and
  // stores in AccType precision
  typedef std::is_same<OutType, bool> is_bool;
  typedef typename std::conditional<is_bool::value, OutType, AccType>::type
    canberraOutType;
  Index_ lda, ldb, ldd;
  canberraOutType *pDcast = reinterpret_cast<canberraOutType *>(pD);
  if (isRowMajor) {
    lda = k, ldb = k, ldd = n;
    canberra<InType, AccType, canberraOutType, Index_, FinalLambda, true>(
      m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream);
  } else {
    // column-major path: extents and the A/B operands are swapped before
    // invoking the kernel launcher
    lda = n, ldb = m, ldd = m;
    canberra<InType, AccType, canberraOutType, Index_, FinalLambda, false>(
      n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream);
  }
}
} // namespace distance | ||
} // namespace raft |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done. |
||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
#include <raft/distance/pairwise_distance_base.cuh> | ||
|
||
namespace raft { | ||
namespace distance { | ||
|
||
/**
 * @brief the Chebyshev distance matrix calculation implementer
 *  It computes the following equation:
 *    cij = max_k( |x_ik - y_kj| )
 * @tparam DataT input data-type (for A and B matrices)
 * @tparam AccT accumulation data-type
 * @tparam OutT output data-type (for C and D matrices)
 * @tparam IdxT index data-type
 * @tparam VecLen number of k-elements loaded per vectorized access
 * @tparam FinalLambda final lambda called on final distance value
 * @tparam isRowMajor true if input/output is row major,
 *                    false for column major
 * @param[in] x input matrix
 * @param[in] y input matrix
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of columns of B and C/D
 * @param[in] k number of cols of A and rows of B
 * @param[in] lda leading dimension of A
 * @param[in] ldb leading dimension of B
 * @param[in] ldd leading dimension of C/D
 * @param[out] dOutput output matrix
 * @param[in] fin_op the final gemm epilogue lambda
 * @param[in] stream cuda stream where to launch work
 */
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          int VecLen, typename FinalLambda, bool isRowMajor>
static void chebyshevImpl(const DataT *x, const DataT *y, IdxT m, IdxT n,
                          IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput,
                          FinalLambda fin_op, cudaStream_t stream) {
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;

  // select the tiling policy that matches the memory layout
  typedef
    typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;

  dim3 blk(KPolicy::Nthreads);

  // Accumulation operation lambda: running max of |x - y|
  auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) {
    const auto diff = raft::L1Op<AccT, IdxT>()(x - y);
    acc = raft::myMax(acc, diff);
  };

  // epilogue operation lambda for final value calculation
  // (no-op: the accumulated max already is the chebyshev distance)
  auto epilog_lambda = [] __device__(
                         AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
                         DataT * regxn, DataT * regyn, IdxT gridStrideX,
                         IdxT gridStrideY) { return; };

  if (isRowMajor) {
    auto chebyshevRowMajor =
      pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
                                decltype(core_lambda), decltype(epilog_lambda),
                                FinalLambda, true>;
    dim3 grid = launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize,
                                               chebyshevRowMajor);

    chebyshevRowMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
      x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
      epilog_lambda, fin_op);
  } else {
    auto chebyshevColMajor =
      pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
                                decltype(core_lambda), decltype(epilog_lambda),
                                FinalLambda, false>;
    dim3 grid = launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize,
                                               chebyshevColMajor);
    chebyshevColMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
      x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
      epilog_lambda, fin_op);
  }

  // catches launch-configuration errors; async execution errors surface at
  // the next synchronizing call
  CUDA_CHECK(cudaGetLastError());
}
|
||
// Dispatches chebyshevImpl with the widest vector load (16, 8 or
// sizeof(DataT) bytes per access) that both the element size and the byte
// strides of the two input matrices allow.
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          typename FinalLambda, bool isRowMajor>
void chebyshev(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
               const DataT *x, const DataT *y, OutT *dOutput,
               FinalLambda fin_op, cudaStream_t stream) {
  // row byte-strides decide whether vectorized loads stay aligned
  size_t bytesA = sizeof(DataT) * lda;
  size_t bytesB = sizeof(DataT) * ldb;
  if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) {
    chebyshevImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), FinalLambda,
                  isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op,
                              stream);
  } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) {
    chebyshevImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), FinalLambda,
                  isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op,
                              stream);
  } else {
    // fallback: scalar (one element per access) loads
    chebyshevImpl<DataT, AccT, OutT, IdxT, 1, FinalLambda, isRowMajor>(
      x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream);
  }
}
|
||
/**
 * @brief the chebyshev distance matrix calculation
 *  It computes the following equation:
 *    cij = max_k( |a_ik - b_kj| )
 * @tparam InType input data-type (for A and B matrices)
 * @tparam AccType accumulation data-type
 * @tparam OutType output data-type (for C and D matrices)
 * @tparam FinalLambda user-defined epilogue lamba
 * @tparam Index_ Index type
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of columns of B and C/D
 * @param[in] k number of cols of A and rows of B
 * @param[in] pA input matrix
 * @param[in] pB input matrix
 * @param[out] pD output matrix
 * @param[in] fin_op the final element-wise epilogue lambda
 * @param[in] stream cuda stream where to launch work
 * @param[in] isRowMajor whether the input and output matrices are row major
 */
template <typename InType, typename AccType, typename OutType,
          typename FinalLambda, typename Index_ = int>
void chebyshevImpl(int m, int n, int k, const InType *pA, const InType *pB,
                   OutType *pD, FinalLambda fin_op, cudaStream_t stream,
                   bool isRowMajor) {
  // bool outputs are written as-is; any other OutType accumulates and
  // stores in AccType precision
  typedef std::is_same<OutType, bool> is_bool;
  typedef typename std::conditional<is_bool::value, OutType, AccType>::type
    chebyshevOutType;
  Index_ lda, ldb, ldd;
  chebyshevOutType *pDcast = reinterpret_cast<chebyshevOutType *>(pD);
  if (isRowMajor) {
    lda = k, ldb = k, ldd = n;
    chebyshev<InType, AccType, chebyshevOutType, Index_, FinalLambda, true>(
      m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream);
  } else {
    // column-major path: extents and the A/B operands are swapped before
    // invoking the kernel launcher
    lda = n, ldb = m, ldd = m;
    chebyshev<InType, AccType, chebyshevOutType, Index_, FinalLambda, false>(
      n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream);
  }
}
} // namespace distance | ||
} // namespace raft |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is a new file, we should just put
2021
here.