[REVIEW] Add chebyshev, canberra, minkowski and hellinger distance metrics #276
Merged
Commits (43):
3a4ec66  mdoijade  Refactor fusedL2NN to use pairwiseDistance class. invert block y/x di…
76f9a72  mdoijade  -- add grid stride support to pairwise distance based cosine, l2, l1 …
af89085  mdoijade  --Add grid stride based fusedL2NN kernel, this gives approx 1.67x spe…
9c71c4a  mdoijade  Add note on reason to use thread 0 from each warp to write final redu…
4d76b57  mdoijade  fix clangformat and copyright year
da2d768  mdoijade  Merge branch 'branch-21.06' into gridStridedDist
4ada29e  mdoijade  --Add additional Mblk + Nblk shmem for storing norms, and reuse xNorm…
2e804c2  mdoijade  Use cudaOccupancyMaxActiveBlocksPerSM instead of hard-coded launch bo…
3408a40  mdoijade  Merge branch 'branch-21.06' into gridStridedDist
69b316d  mdoijade  initialize regx and regy during each prolog call
6a64b7a  mdoijade  Add chebyshev distance metric support
9a30a87  mdoijade  initialize ldgX, ldgY in prolog
14a2673  mdoijade  Merge branch 'gridStridedDist' into chebyshevDist
21577a4  mdoijade  Add hellinger distance metric support
969c65a  mdoijade  Merge branch 'branch-21.08' into gridStridedDist
9c4d5a0  mdoijade  add syncthreads post epilog calc for non-norm distance metrics to mak…
df4ce55  mdoijade  Merge branch 'gridStridedDist' into chebyshevDist
4fb00e6  mdoijade  remove syncthreads in epilog and instead use ping-pong buffers in nex…
5346232  mdoijade  Add minkowski distance metric
b5b3c51  mdoijade  use ping-pong buffers for safely grid striding
0a0f964  mdoijade  Merge branch 'gridStridedDist' into chebyshevDist
2fd7f4c  mdoijade  Add canberra distance metric support
0f2c03d  mdoijade  fix build failure of mst and knn test by adding cuda stream arg to rm…
484b082  mdoijade  temp commit for test rerun
04f656f  mdoijade  use ucx-py version 0.21 to temp resolve ci build failures
753f612  mdoijade  Merge branch 'fix_mst_knn_test' into gridStridedDist
f73471c  mdoijade  merge branch-21.08
8007d7a  mdoijade  Merge branch 'fix_mst_knn_test' into gridStridedDist
45dc556  mdoijade  Merge branch 'branch-21.08' into gridStridedDist
d0b8947  mdoijade  Merge branch 'branch-21.08' into gridStridedDist
1c6ee73  mdoijade  Merge branch 'gridStridedDist' into chebyshevDist
1c65ab7  mdoijade  remove redundant metric_arg parameter from canberra function launch
8323a3c  mdoijade  reduce sqrt in hellinger my merging prod in sqrt
a33546d  mdoijade  rename minkowksi to be similar to other functions, fix documentation …
d4925b4  mdoijade  Merge branch 'branch-21.08' into chebyshevDist
7673074  mdoijade  fix clang format issue
476ed99  mdoijade  fix hellinger inputs to be only in range of 0 to 1 as hellinger is ex…
59e78e8  mdoijade  fix doc issues in all dist functions, also fix copyright year to be o…
7af5e32  mdoijade  reduce sqrt in hellinger usage by overwriting input matrices by sqrt …
0e99113  mdoijade  fix clang format issues
f4b8d33  mdoijade  hellinger: only sqrt inputs when x & y are not same.
a71e520  mdoijade  fix clang format issues
842fcd0  mdoijade  Merge branch 'branch-21.08' into chebyshevDist
New file: canberra distance header (+161 lines)
/*
 * Copyright (c) 2021, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <raft/distance/pairwise_distance_base.cuh>

namespace raft {
namespace distance {
/**
 * @brief the canberra distance matrix calculation implementer
 *  It computes the following equation:
 *    cij = sum over k of |a_ik - b_jk| / (|a_ik| + |b_jk|)
 * @tparam DataT input data-type (for A and B matrices)
 * @tparam AccT accumulation data-type
 * @tparam OutT output data-type (for C and D matrices)
 * @tparam IdxT index data-type
 * @tparam VecLen number of k-elements loaded by each thread
 *         for every LDG call. details in contractions.cuh
 * @tparam FinalLambda final lambda called on final distance value
 * @tparam isRowMajor true if input/output is row major,
 *         false for column major
 * @param[in] x input matrix
 * @param[in] y input matrix
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of rows of B and cols of C/D
 * @param[in] k number of cols of A and B
 * @param[in] lda leading dimension of A
 * @param[in] ldb leading dimension of B
 * @param[in] ldd leading dimension of C/D
 * @param[out] dOutput output matrix
 * @param[in] fin_op the final gemm epilogue lambda
 * @param[in] stream cuda stream to launch work
 */
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          int VecLen, typename FinalLambda, bool isRowMajor>
static void canberraImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
                         IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput,
                         FinalLambda fin_op, cudaStream_t stream) {
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;

  typedef
    typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;

  dim3 blk(KPolicy::Nthreads);

  // Accumulation operation lambda
  auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) {
    const auto diff = raft::L1Op<AccT, IdxT>()(x - y);
    const auto add = raft::myAbs(x) + raft::myAbs(y);
    // deal with potential for 0 in denominator by forcing a 1/0 instead:
    // when add == 0 the (add != 0) factor zeroes the numerator and the
    // (add == 0) term bumps the denominator to 1, so the term contributes 0
    acc += ((add != 0) * diff / (add + (add == 0)));
  };

  // epilogue operation lambda for final value calculation (no-op for canberra)
  auto epilog_lambda = [] __device__(
                         AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
                         DataT * regxn, DataT * regyn, IdxT gridStrideX,
                         IdxT gridStrideY) { return; };

  if (isRowMajor) {
    auto canberraRowMajor =
      pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
                                decltype(core_lambda), decltype(epilog_lambda),
                                FinalLambda, true>;
    dim3 grid =
      launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize, canberraRowMajor);

    canberraRowMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
      x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
      epilog_lambda, fin_op);
  } else {
    auto canberraColMajor =
      pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
                                decltype(core_lambda), decltype(epilog_lambda),
                                FinalLambda, false>;
    dim3 grid =
      launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize, canberraColMajor);
    canberraColMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
      x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
      epilog_lambda, fin_op);
  }

  CUDA_CHECK(cudaGetLastError());
}
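The branch-free zero guard in the core lambda above is easy to verify on the host. The sketch below (illustrative only, not part of this diff; `canberra_term` is a hypothetical name) mirrors the same arithmetic in plain C++:

```cpp
#include <cassert>
#include <cmath>

// Host-side mirror of the canberra core lambda, for illustration only.
double canberra_term(double x, double y) {
  const double diff = std::abs(x - y);
  const double add = std::abs(x) + std::abs(y);
  // (add != 0) zeroes the numerator when both inputs are 0;
  // (add == 0) bumps the denominator to 1, so we never compute 0/0.
  return (add != 0) * diff / (add + (add == 0));
}

int main() {
  assert(canberra_term(0.0, 0.0) == 0.0);   // would be 0/0 (NaN) without the guard
  assert(canberra_term(1.0, -1.0) == 1.0);  // |1 - (-1)| / (|1| + |-1|)
  return 0;
}
```

Keeping the guard branch-free is worthwhile in the kernel because a data-dependent `if` inside the inner accumulation could serialize divergent lanes within a warp.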
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          typename FinalLambda, bool isRowMajor>
void canberra(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
              const DataT *x, const DataT *y, OutT *dOutput, FinalLambda fin_op,
              cudaStream_t stream) {
  size_t bytesA = sizeof(DataT) * lda;
  size_t bytesB = sizeof(DataT) * ldb;
  if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) {
    canberraImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), FinalLambda,
                 isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op,
                             stream);
  } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) {
    canberraImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), FinalLambda,
                 isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op,
                             stream);
  } else {
    canberraImpl<DataT, AccT, OutT, IdxT, 1, FinalLambda, isRowMajor>(
      x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream);
  }
}
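The dispatch above picks the widest vectorized load width (`VecLen`) whose 16- or 8-byte granularity evenly divides both row strides, falling back to scalar loads otherwise. A hypothetical host-side helper (`pick_veclen` is not a function in this PR) makes the selection rule concrete:

```cpp
#include <cstddef>
#include <cstdio>

// Illustrative only: mirrors the alignment test used by canberra()'s dispatch.
// Returns how many elements each vectorized load would fetch per thread.
template <typename DataT>
size_t pick_veclen(size_t lda, size_t ldb) {
  const size_t bytesA = sizeof(DataT) * lda;
  const size_t bytesB = sizeof(DataT) * ldb;
  if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0)
    return 16 / sizeof(DataT);  // e.g. 4 floats or 2 doubles per load
  if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0)
    return 8 / sizeof(DataT);   // e.g. 2 floats or 1 double per load
  return 1;                     // unaligned fallback: scalar loads
}

int main() {
  std::printf("float,  lda=ldb=128 -> VecLen %zu\n", pick_veclen<float>(128, 128));  // 4
  std::printf("float,  lda=ldb=33  -> VecLen %zu\n", pick_veclen<float>(33, 33));    // 1
  std::printf("double, lda=ldb=64  -> VecLen %zu\n", pick_veclen<double>(64, 64));   // 2
  return 0;
}
```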
/**
 * @brief the canberra distance matrix calculation
 *  It computes the following equation:
 *    cij = sum over k of |a_ik - b_jk| / (|a_ik| + |b_jk|)
 * @tparam InType input data-type (for A and B matrices)
 * @tparam AccType accumulation data-type
 * @tparam OutType output data-type (for C and D matrices)
 * @tparam FinalLambda user-defined epilogue lambda
 * @tparam Index_ Index type
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of rows of B and cols of C/D
 * @param[in] k number of cols of A and B
 * @param[in] pA input matrix
 * @param[in] pB input matrix
 * @param[out] pD output matrix
 * @param[in] fin_op the final element-wise epilogue lambda
 * @param[in] stream cuda stream to launch work
 * @param[in] isRowMajor whether the input and output matrices are row major
 */
template <typename InType, typename AccType, typename OutType,
          typename FinalLambda, typename Index_ = int>
void canberraImpl(int m, int n, int k, const InType *pA, const InType *pB,
                  OutType *pD, FinalLambda fin_op, cudaStream_t stream,
                  bool isRowMajor) {
  typedef std::is_same<OutType, bool> is_bool;
  typedef typename std::conditional<is_bool::value, OutType, AccType>::type
    canberraOutType;
  Index_ lda, ldb, ldd;
  canberraOutType *pDcast = reinterpret_cast<canberraOutType *>(pD);
  if (isRowMajor) {
    lda = k, ldb = k, ldd = n;
    canberra<InType, AccType, canberraOutType, Index_, FinalLambda, true>(
      m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream);
  } else {
    lda = n, ldb = m, ldd = m;
    canberra<InType, AccType, canberraOutType, Index_, FinalLambda, false>(
      n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream);
  }
}
}  // namespace distance
}  // namespace raft
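For validating the kernel end to end, a naive host reference is handy. This is a sketch under the assumption of row-major m×k and n×k inputs; `canberra_reference` is an illustrative name, not part of the PR:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Naive O(m*n*k) host reference for the Canberra pairwise distance matrix.
// x is m x k, y is n x k, out is m x n, all row major. Illustrative only.
void canberra_reference(const std::vector<float> &x, const std::vector<float> &y,
                        std::size_t m, std::size_t n, std::size_t k,
                        std::vector<float> &out) {
  out.assign(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t l = 0; l < k; ++l) {
        const float a = x[i * k + l], b = y[j * k + l];
        const float add = std::fabs(a) + std::fabs(b);
        if (add != 0.0f) acc += std::fabs(a - b) / add;  // skip 0/0 terms
      }
      out[i * n + j] = acc;
    }
  }
}
```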
New file: chebyshev distance header (+158 lines)
/*
 * Copyright (c) 2021, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <raft/distance/pairwise_distance_base.cuh>

namespace raft {
namespace distance {
/**
 * @brief the Chebyshev distance matrix calculation implementer
 *  It computes the following equation:
 *    cij = max over k of |a_ik - b_jk|
 * @tparam DataT input data-type (for A and B matrices)
 * @tparam AccT accumulation data-type
 * @tparam OutT output data-type (for C and D matrices)
 * @tparam IdxT index data-type
 * @tparam VecLen number of k-elements loaded by each thread
 *         for every LDG call. details in contractions.cuh
 * @tparam FinalLambda final lambda called on final distance value
 * @tparam isRowMajor true if input/output is row major,
 *         false for column major
 * @param[in] x input matrix
 * @param[in] y input matrix
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of rows of B and cols of C/D
 * @param[in] k number of cols of A and B
 * @param[in] lda leading dimension of A
 * @param[in] ldb leading dimension of B
 * @param[in] ldd leading dimension of C/D
 * @param[out] dOutput output matrix
 * @param[in] fin_op the final gemm epilogue lambda
 * @param[in] stream cuda stream to launch work
 */
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          int VecLen, typename FinalLambda, bool isRowMajor>
static void chebyshevImpl(const DataT *x, const DataT *y, IdxT m, IdxT n,
                          IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput,
                          FinalLambda fin_op, cudaStream_t stream) {
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::Policy RowPolicy;
  typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;

  typedef
    typename std::conditional<isRowMajor, RowPolicy, ColPolicy>::type KPolicy;

  dim3 blk(KPolicy::Nthreads);

  // Accumulation operation lambda: running maximum of |x - y|
  auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) {
    const auto diff = raft::L1Op<AccT, IdxT>()(x - y);
    acc = raft::myMax(acc, diff);
  };

  // epilogue operation lambda for final value calculation (no-op for chebyshev)
  auto epilog_lambda = [] __device__(
                         AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh],
                         DataT * regxn, DataT * regyn, IdxT gridStrideX,
                         IdxT gridStrideY) { return; };

  if (isRowMajor) {
    auto chebyshevRowMajor =
      pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
                                decltype(core_lambda), decltype(epilog_lambda),
                                FinalLambda, true>;
    dim3 grid = launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize,
                                               chebyshevRowMajor);

    chebyshevRowMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
      x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
      epilog_lambda, fin_op);
  } else {
    auto chebyshevColMajor =
      pairwiseDistanceMatKernel<false, DataT, AccT, OutT, IdxT, KPolicy,
                                decltype(core_lambda), decltype(epilog_lambda),
                                FinalLambda, false>;
    dim3 grid = launchConfigGenerator<KPolicy>(m, n, KPolicy::SmemSize,
                                               chebyshevColMajor);
    chebyshevColMajor<<<grid, blk, KPolicy::SmemSize, stream>>>(
      x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda,
      epilog_lambda, fin_op);
  }

  CUDA_CHECK(cudaGetLastError());
}
template <typename DataT, typename AccT, typename OutT, typename IdxT,
          typename FinalLambda, bool isRowMajor>
void chebyshev(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
               const DataT *x, const DataT *y, OutT *dOutput,
               FinalLambda fin_op, cudaStream_t stream) {
  size_t bytesA = sizeof(DataT) * lda;
  size_t bytesB = sizeof(DataT) * ldb;
  if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) {
    chebyshevImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), FinalLambda,
                  isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op,
                              stream);
  } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) {
    chebyshevImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), FinalLambda,
                  isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op,
                              stream);
  } else {
    chebyshevImpl<DataT, AccT, OutT, IdxT, 1, FinalLambda, isRowMajor>(
      x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream);
  }
}
/**
 * @brief the chebyshev distance matrix calculation
 *  It computes the following equation:
 *    cij = max over k of |a_ik - b_jk|
 * @tparam InType input data-type (for A and B matrices)
 * @tparam AccType accumulation data-type
 * @tparam OutType output data-type (for C and D matrices)
 * @tparam FinalLambda user-defined epilogue lambda
 * @tparam Index_ Index type
 * @param[in] m number of rows of A and C/D
 * @param[in] n number of rows of B and cols of C/D
 * @param[in] k number of cols of A and B
 * @param[in] pA input matrix
 * @param[in] pB input matrix
 * @param[out] pD output matrix
 * @param[in] fin_op the final element-wise epilogue lambda
 * @param[in] stream cuda stream to launch work
 * @param[in] isRowMajor whether the input and output matrices are row major
 */
template <typename InType, typename AccType, typename OutType,
          typename FinalLambda, typename Index_ = int>
void chebyshevImpl(int m, int n, int k, const InType *pA, const InType *pB,
                   OutType *pD, FinalLambda fin_op, cudaStream_t stream,
                   bool isRowMajor) {
  typedef std::is_same<OutType, bool> is_bool;
  typedef typename std::conditional<is_bool::value, OutType, AccType>::type
    chebyshevOutType;
  Index_ lda, ldb, ldd;
  chebyshevOutType *pDcast = reinterpret_cast<chebyshevOutType *>(pD);
  if (isRowMajor) {
    lda = k, ldb = k, ldd = n;
    chebyshev<InType, AccType, chebyshevOutType, Index_, FinalLambda, true>(
      m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream);
  } else {
    lda = n, ldb = m, ldd = m;
    chebyshev<InType, AccType, chebyshevOutType, Index_, FinalLambda, false>(
      n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream);
  }
}
}  // namespace distance
}  // namespace raft
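The same kind of host reference works for Chebyshev, whose accumulation is a running maximum of absolute differences rather than a sum. Again a sketch with the same row-major layout assumptions; `chebyshev_reference` is an illustrative name:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive host reference for the Chebyshev (L-infinity) pairwise distance matrix.
// x is m x k, y is n x k, out is m x n, all row major. Illustrative only.
void chebyshev_reference(const std::vector<float> &x, const std::vector<float> &y,
                         std::size_t m, std::size_t n, std::size_t k,
                         std::vector<float> &out) {
  out.assign(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;  // safe start: |x - y| is never negative
      for (std::size_t l = 0; l < k; ++l)
        acc = std::max(acc, std::fabs(x[i * k + l] - y[j * k + l]));
      out[i * n + j] = acc;  // max over k of |x_ik - y_jk|
    }
}
```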
Review comment: Same here, maybe add the ins/outs for consistency w/ the above docs (also apply to the distances below)
Author reply: done.