Skip to content

Commit

Permalink
[FEA] Dice Distance for Dense Inputs (#2359)
Browse files Browse the repository at this point in the history
Adds support for the `DistanceType::DiceExpanded` for dense inputs.
1. Naive Kernel Implementation (unexpanded form)
2. Expanded form for dice distance that follows ground truth from `scipy.spatial.distance.dice`
3. Gtests in `cpp/test/distance/dist-dice.cu` 

Related to rapidsai/cuml#5129

Authors:
  - Anupam (https://github.com/aamijar)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #2359
  • Loading branch information
aamijar authored Jun 24, 2024
1 parent b86a5f9 commit b863f18
Show file tree
Hide file tree
Showing 14 changed files with 540 additions and 8 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu
Expand Down
82 changes: 79 additions & 3 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ using distance_tag = std::integral_constant<DistanceType, d>;
* - DistanceType::Canberra:
* - DistanceType::CorrelationExpanded:
* - DistanceType::CosineExpanded:
* - DistanceType::DiceExpanded:
* - DistanceType::HammingUnexpanded:
* - DistanceType::HellingerExpanded:
* - DistanceType::JensenShannon:
Expand Down Expand Up @@ -238,6 +239,79 @@ void distance_impl(raft::resources const& handle,
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

template <typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT = int>
void distance_impl(raft::resources const& handle,
distance_tag<DistanceType::DiceExpanded> distance_type,
const DataT* x,
const DataT* y,
OutT* out,
IdxT m,
IdxT n,
IdxT k,
AccT* workspace,
size_t worksize,
FinOpT fin_op,
bool is_row_major,
DataT) // unused
{
// raft distance support inputs as float/double and output as uint8_t/float/double.
static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))),
"OutT can be uint8_t, float, double,"
"if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT).");

ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

DataT* x_norm = workspace;
DataT* y_norm = workspace;
// TODO: Column major case looks to have lower accuracy for X == Y,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if (x == y && is_row_major) {
raft::linalg::reduce(x_norm,
x,
k,
std::max(m, n),
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
} else {
y_norm += m;
raft::linalg::reduce(x_norm,
x,
k,
m,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
raft::linalg::reduce(y_norm,
y,
k,
n,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
}

ops::dice_distance_op<DataT, AccT, IdxT> distance_op{};
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

template <typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT = int>
void distance_impl(raft::resources const& handle,
distance_tag<DistanceType::HammingUnexpanded> distance_type,
Expand Down Expand Up @@ -794,9 +868,11 @@ template <raft::distance::DistanceType distanceType,
typename Index_ = int>
size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k)
{
size_t worksize = 0;
constexpr bool is_allocated = (distanceType <= raft::distance::DistanceType::CosineExpanded) ||
(distanceType == raft::distance::DistanceType::CorrelationExpanded);
size_t worksize = 0;
constexpr bool is_allocated =
(distanceType <= raft::distance::DistanceType::CosineExpanded) ||
(distanceType == raft::distance::DistanceType::CorrelationExpanded) ||
(distanceType == raft::distance::DistanceType::DiceExpanded);
constexpr int numOfBuffers =
(distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1;

Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/distance/detail/distance_ops/all_ops.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
#include <raft/distance/detail/distance_ops/canberra.cuh>
#include <raft/distance/detail/distance_ops/correlation.cuh>
#include <raft/distance/detail/distance_ops/cosine.cuh>
#include <raft/distance/detail/distance_ops/dice.cuh>
#include <raft/distance/detail/distance_ops/hamming.cuh>
#include <raft/distance/detail/distance_ops/hellinger.cuh>
#include <raft/distance/detail/distance_ops/jensen_shannon.cuh>
Expand Down
85 changes: 85 additions & 0 deletions cpp/include/raft/distance/detail/distance_ops/dice.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace raft::distance::detail::ops {

// Epilogue operator for CUTLASS based kernel
template <typename DataT, typename AccT>
struct dice_cutlass_op {
__device__ dice_cutlass_op() noexcept {}
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
{
return static_cast<AccT>(1.0) - static_cast<AccT>(2 * accVal / (aNorm + bNorm));
}
__device__ AccT operator()(DataT aData) const noexcept { return aData; }
};

/**
* @brief the expanded dice distance matrix calculation
*
* It computes the following equation:
*
* d(x, y) = 1 - 2*(x ⋅ y) / ( Σ(x) + Σ(y) )
*/
template <typename DataType, typename AccType, typename IdxType>
struct dice_distance_op {
using DataT = DataType;
using AccT = AccType;
using IdxT = IdxType;

// Load norms of input data
static constexpr bool use_norms = true;
// Whether the core function requires so many instructions that it makes sense
// to reduce loop unrolling, etc. We do this to keep compile times in check.
static constexpr bool expensive_inner_loop = false;

// Size of shared memory. This is normally decided by the kernel policy, but
// some ops such as correlation_distance_op use more.
template <typename Policy>
static constexpr size_t shared_mem_size()
{
return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT));
}

DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; };

template <typename Policy>
DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh],
DataT* regxn,
DataT* regyn,
IdxT gridStrideX,
IdxT gridStrideY) const
{
#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < Policy::AccColsPerTh; ++j) {
acc[i][j] = 1.0 - (2 * acc[i][j] / (regxn[i] + regyn[j]));
}
}
}

constexpr dice_cutlass_op<DataT, AccT> get_cutlass_op() const
{
return dice_cutlass_op<DataT, AccT>();
}
};

} // namespace raft::distance::detail::ops
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -120,6 +120,10 @@ instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::dice_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
Expand Down
48 changes: 48 additions & 0 deletions cpp/include/raft/distance/distance-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, raft::identity_op, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, double, double, double, raft::identity_op, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int);
instantiate_raft_distance_distance(
Expand Down Expand Up @@ -286,6 +290,10 @@ instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, float, float, float, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, double, double, double, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, double, double, double, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::HammingUnexpanded, float, float, float, int);
instantiate_raft_distance_distance(
Expand Down Expand Up @@ -362,6 +370,10 @@ instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, float, float, float, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, double, double, double, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, double, double, double, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::HammingUnexpanded, float, float, float, int);
instantiate_raft_distance_distance(
Expand Down Expand Up @@ -429,6 +441,10 @@ instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::CosineExpanded, float, float, float, int);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::CosineExpanded, double, double, double, int);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::DiceExpanded, float, float, float, int);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::DiceExpanded, double, double, double, int);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::HammingUnexpanded, float, float, float, int);
instantiate_raft_distance_getWorkspaceSize(
Expand Down Expand Up @@ -547,6 +563,22 @@ instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineE
double,
int,
raft::layout_f_contiguous);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_c_contiguous);
instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded,
double,
double,
double,
int,
raft::layout_c_contiguous);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_f_contiguous);
instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded,
double,
double,
double,
int,
raft::layout_f_contiguous);
instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded,
float,
float,
Expand Down Expand Up @@ -822,6 +854,22 @@ instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded,
double,
raft::layout_f_contiguous,
int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_c_contiguous, int);
instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded,
double,
double,
double,
raft::layout_c_contiguous,
int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_f_contiguous, int);
instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded,
double,
double,
double,
raft::layout_f_contiguous,
int);
instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded,
float,
float,
Expand Down
3 changes: 3 additions & 0 deletions cpp/include/raft/distance/distance-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ void pairwise_distance(raft::resources const& handle,
case DistanceType::RusselRaoExpanded:
dispatch(std::integral_constant<DistanceType, DistanceType::RusselRaoExpanded>{});
break;
case DistanceType::DiceExpanded:
dispatch(std::integral_constant<DistanceType, DistanceType::DiceExpanded>{});
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by dispatch_00_generate.py
*
* Make changes there and run in this directory:
*
* > python dispatch_00_generate.py
*
*/

#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/all_ops.cuh> // ops::*
#include <raft/distance/detail/pairwise_matrix/dispatch-inl.cuh> // dispatch
#include <raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh>
#include <raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh>
#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \
OpT, DataT, AccT, OutT, FinOpT, IdxT) \
template void raft::distance::detail:: \
pairwise_matrix_dispatch<OpT<DataT, AccT, IdxT>, DataT, AccT, OutT, FinOpT, IdxT>( \
OpT<DataT, AccT, IdxT> distance_op, \
IdxT m, \
IdxT n, \
IdxT k, \
const DataT* x, \
const DataT* y, \
const DataT* x_norm, \
const DataT* y_norm, \
OutT* out, \
FinOpT fin_op, \
cudaStream_t stream, \
bool is_row_major)

instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int);

#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch
Loading

0 comments on commit b863f18

Please sign in to comment.