-
Notifications
You must be signed in to change notification settings - Fork 199
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEA] Dice Distance for Dense Inputs (#2359)
Adds support for the `DistanceType::DiceExpanded` for dense inputs. 1. Naive Kernel Implementation (unexpanded form) 2. Expanded form for dice distance that follows ground truth from `scipy.spatial.distance.dice` 3. Gtests in `cpp/test/distance/dist-dice.cu` Related to rapidsai/cuml#5129 Authors: - Anupam (https://github.com/aamijar) Approvers: - Divye Gala (https://github.com/divyegala) URL: #2359
- Loading branch information
Showing
14 changed files
with
540 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* | ||
* Copyright (c) 2023-2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/util/cuda_dev_essentials.cuh> // DI | ||
|
||
namespace raft::distance::detail::ops { | ||
|
||
// Epilogue operator for CUTLASS based kernel | ||
template <typename DataT, typename AccT> | ||
struct dice_cutlass_op { | ||
__device__ dice_cutlass_op() noexcept {} | ||
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept | ||
{ | ||
return static_cast<AccT>(1.0) - static_cast<AccT>(2 * accVal / (aNorm + bNorm)); | ||
} | ||
__device__ AccT operator()(DataT aData) const noexcept { return aData; } | ||
}; | ||
|
||
/** | ||
* @brief the expanded dice distance matrix calculation | ||
* | ||
* It computes the following equation: | ||
* | ||
* d(x, y) = 1 - 2*(x ⋅ y) / ( Σ(x) + Σ(y) ) | ||
*/ | ||
template <typename DataType, typename AccType, typename IdxType> | ||
struct dice_distance_op { | ||
using DataT = DataType; | ||
using AccT = AccType; | ||
using IdxT = IdxType; | ||
|
||
// Load norms of input data | ||
static constexpr bool use_norms = true; | ||
// Whether the core function requires so many instructions that it makes sense | ||
// to reduce loop unrolling, etc. We do this to keep compile times in check. | ||
static constexpr bool expensive_inner_loop = false; | ||
|
||
// Size of shared memory. This is normally decided by the kernel policy, but | ||
// some ops such as correlation_distance_op use more. | ||
template <typename Policy> | ||
static constexpr size_t shared_mem_size() | ||
{ | ||
return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); | ||
} | ||
|
||
DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; | ||
|
||
template <typename Policy> | ||
DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], | ||
DataT* regxn, | ||
DataT* regyn, | ||
IdxT gridStrideX, | ||
IdxT gridStrideY) const | ||
{ | ||
#pragma unroll | ||
for (int i = 0; i < Policy::AccRowsPerTh; ++i) { | ||
#pragma unroll | ||
for (int j = 0; j < Policy::AccColsPerTh; ++j) { | ||
acc[i][j] = 1.0 - (2 * acc[i][j] / (regxn[i] + regyn[j])); | ||
} | ||
} | ||
} | ||
|
||
constexpr dice_cutlass_op<DataT, AccT> get_cutlass_op() const | ||
{ | ||
return dice_cutlass_op<DataT, AccT>(); | ||
} | ||
}; | ||
|
||
} // namespace raft::distance::detail::ops |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
51 changes: 51 additions & 0 deletions
51
cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright (c) 2021-2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
/* | ||
* NOTE: this file is generated by dispatch_00_generate.py | ||
* | ||
* Make changes there and run in this directory: | ||
* | ||
* > python dispatch_00_generate.py | ||
* | ||
*/ | ||
|
||
#include <raft/core/operators.hpp> // raft::identity_op | ||
#include <raft/distance/detail/distance_ops/all_ops.cuh> // ops::* | ||
#include <raft/distance/detail/pairwise_matrix/dispatch-inl.cuh> // dispatch | ||
#include <raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh> | ||
#include <raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh> | ||
#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ | ||
OpT, DataT, AccT, OutT, FinOpT, IdxT) \ | ||
template void raft::distance::detail:: \ | ||
pairwise_matrix_dispatch<OpT<DataT, AccT, IdxT>, DataT, AccT, OutT, FinOpT, IdxT>( \ | ||
OpT<DataT, AccT, IdxT> distance_op, \ | ||
IdxT m, \ | ||
IdxT n, \ | ||
IdxT k, \ | ||
const DataT* x, \ | ||
const DataT* y, \ | ||
const DataT* x_norm, \ | ||
const DataT* y_norm, \ | ||
OutT* out, \ | ||
FinOpT fin_op, \ | ||
cudaStream_t stream, \ | ||
bool is_row_major) | ||
|
||
instantiate_raft_distance_detail_pairwise_matrix_dispatch( | ||
raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int); | ||
|
||
#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch |
Oops, something went wrong.