From b863f1810fdb0679f7777b656486dbfc96470d2b Mon Sep 17 00:00:00 2001 From: Anupam <54245698+aamijar@users.noreply.github.com> Date: Mon, 24 Jun 2024 19:42:24 -0400 Subject: [PATCH] [FEA] Dice Distance for Dense Inputs (#2359) Adds support for the `DistanceType::DiceExpanded` for dense inputs. 1. Naive Kernel Implementation (unexpanded form) 2. Expanded form for dice distance that follows ground truth from `scipy.spatial.distance.dice` 3. Gtests in `cpp/test/distance/dist-dice.cu` Related to https://github.com/rapidsai/cuml/issues/5129 Authors: - Anupam (https://github.com/aamijar) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/2359 --- cpp/CMakeLists.txt | 2 + cpp/include/raft/distance/detail/distance.cuh | 82 ++++++++++++- .../distance/detail/distance_ops/all_ops.cuh | 3 +- .../distance/detail/distance_ops/dice.cuh | 85 +++++++++++++ .../detail/pairwise_matrix/dispatch-ext.cuh | 6 +- cpp/include/raft/distance/distance-ext.cuh | 48 ++++++++ cpp/include/raft/distance/distance-inl.cuh | 3 + .../dispatch_dice_double_double_double_int.cu | 51 ++++++++ .../dispatch_dice_float_float_float_int.cu | 51 ++++++++ cpp/src/distance/distance.cu | 50 +++++++- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_dice.cu | 112 ++++++++++++++++++ cpp/test/distance/distance_base.cuh | 37 +++++- cpp/test/test_utils.h | 17 +++ 14 files changed, 540 insertions(+), 8 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance_ops/dice.cuh create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu create mode 100644 cpp/test/distance/dist_dice.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 39472cae67..fe9132b223 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -302,6 +302,8 @@ if(RAFT_COMPILE_LIBRARY) src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index a5c8c0ef4b..b708360074 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -55,6 +55,7 @@ using distance_tag = std::integral_constant; * - DistanceType::Canberra: * - DistanceType::CorrelationExpanded: * - DistanceType::CosineExpanded: + * - DistanceType::DiceExpanded: * - DistanceType::HammingUnexpanded: * - DistanceType::HellingerExpanded: * - DistanceType::JensenShannon: @@ -238,6 +239,79 @@ void distance_impl(raft::resources const& handle, distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + bool is_row_major, + DataT) // unused +{ + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), + "OutT can be uint8_t, float, double," + "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); + + ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + DataT* x_norm = workspace; + DataT* y_norm = workspace; + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. + if (x == y && is_row_major) { + raft::linalg::reduce(x_norm, + x, + k, + std::max(m, n), + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + } else { + y_norm += m; + raft::linalg::reduce(x_norm, + x, + k, + m, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + raft::linalg::reduce(y_norm, + y, + k, + n, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + } + + ops::dice_distance_op distance_op{}; + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + template void distance_impl(raft::resources const& handle, distance_tag distance_type, @@ -794,9 +868,11 @@ template size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) { - size_t worksize = 0; - constexpr bool is_allocated = (distanceType <= raft::distance::DistanceType::CosineExpanded) || - (distanceType == raft::distance::DistanceType::CorrelationExpanded); + size_t worksize = 0; + constexpr bool is_allocated = + (distanceType <= raft::distance::DistanceType::CosineExpanded) || + (distanceType == raft::distance::DistanceType::CorrelationExpanded) || + (distanceType == raft::distance::DistanceType::DiceExpanded); constexpr int numOfBuffers = (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1; diff --git a/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh index 3e8f4e86fb..84eb3c705b 100644 --- a/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/distance/detail/distance_ops/dice.cuh b/cpp/include/raft/distance/detail/distance_ops/dice.cuh new file mode 100644 index 0000000000..edd7e42a8d --- /dev/null +++ b/cpp/include/raft/distance/detail/distance_ops/dice.cuh @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // DI + +namespace raft::distance::detail::ops { + +// Epilogue operator for CUTLASS based kernel +template +struct dice_cutlass_op { + __device__ dice_cutlass_op() noexcept {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + return static_cast(1.0) - static_cast(2 * accVal / (aNorm + bNorm)); + } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + +/** + * @brief the expanded dice distance matrix calculation + * + * It computes the following equation: + * + * d(x, y) = 1 - 2*(x ⋅ y) / ( Σ(x) + Σ(y) ) + */ +template +struct dice_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + // Load norms of input data + static constexpr bool use_norms = true; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = 1.0 - (2 * acc[i][j] / (regxn[i] + regyn[j])); + } + } + } + + constexpr dice_cutlass_op get_cutlass_op() const + { + return dice_cutlass_op(); + } +}; + +} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh index e1dc6f9b37..bced721ec8 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -120,6 +120,10 @@ instantiate_raft_distance_detail_pairwise_matrix_dispatch( raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); instantiate_raft_distance_detail_pairwise_matrix_dispatch( raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::dice_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int); instantiate_raft_distance_detail_pairwise_matrix_dispatch( raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); instantiate_raft_distance_detail_pairwise_matrix_dispatch( diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh index a634e8c995..2d41e029fe 100644 --- a/cpp/include/raft/distance/distance-ext.cuh +++ b/cpp/include/raft/distance/distance-ext.cuh @@ -204,6 +204,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, raft::identity_op, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); instantiate_raft_distance_distance( @@ -286,6 +290,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_distance( @@ -362,6 +370,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_distance( @@ -429,6 +441,10 @@ instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_getWorkspaceSize( @@ -547,6 +563,22 @@ instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineE double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, float, float, @@ -822,6 +854,22 @@ instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, float, float, diff --git a/cpp/include/raft/distance/distance-inl.cuh b/cpp/include/raft/distance/distance-inl.cuh index 647c5b2908..13c9d57efd 100644 --- a/cpp/include/raft/distance/distance-inl.cuh +++ b/cpp/include/raft/distance/distance-inl.cuh @@ -306,6 +306,9 @@ void pairwise_distance(raft::resources const& handle, case DistanceType::RusselRaoExpanded: dispatch(std::integral_constant{}); break; + case DistanceType::DiceExpanded: + dispatch(std::integral_constant{}); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; } diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu new file mode 100644 index 0000000000..a259f8b3b0 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu new file mode 100644 index 0000000000..e89f8b422c --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::dice_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu index 8c94608311..8fe0bf2007 100644 --- a/cpp/src/distance/distance.cu +++ b/cpp/src/distance/distance.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,6 +72,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, raft::identity_op, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); instantiate_raft_distance_distance( @@ -154,6 +158,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_distance( @@ -230,6 +238,10 @@ instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_distance( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_distance( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_distance( @@ -297,6 +309,10 @@ instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::CosineExpanded, float, float, float, int); instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, double, double, double, int); instantiate_raft_distance_getWorkspaceSize( raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); instantiate_raft_distance_getWorkspaceSize( @@ -415,6 +431,22 @@ instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineE double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, float, float, @@ -690,6 +722,22 @@ instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, float, float, diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index ff0518a4d0..ed923fb1db 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -156,6 +156,7 @@ if(BUILD_TESTS) distance/dist_canberra.cu distance/dist_correlation.cu distance/dist_cos.cu + distance/dist_dice.cu distance/dist_hamming.cu distance/dist_hellinger.cu distance/dist_inner_product.cu diff --git a/cpp/test/distance/dist_dice.cu b/cpp/test/distance/dist_dice.cu new file mode 100644 index 0000000000..e127659dc6 --- /dev/null +++ b/cpp/test/distance/dist_dice.cu @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceExpDice : public DistanceTest { +}; + +template +class DistanceExpDiceXequalY + : public DistanceTestSameBuffer {}; + +const std::vector> inputsf = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; + +const std::vector> inputsXeqYf = { + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, +}; + +typedef DistanceExpDice DistanceExpDiceF; +TEST_P(DistanceExpDiceF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApproxNaN(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceF, ::testing::ValuesIn(inputsf)); + +typedef DistanceExpDiceXequalY DistanceExpDiceXequalYF; +TEST_P(DistanceExpDiceXequalYF, Result) +{ + int m = params.m; + int n = params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + n, + raft::CompareApproxNaN(params.tolerance), + stream)); + n = params.isRowMajor ? m : m / 2; + m = params.isRowMajor ? m / 2 : m; + + ASSERT_TRUE(raft::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m, + n, + raft::CompareApproxNaN(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceXequalYF, ::testing::ValuesIn(inputsXeqYf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceExpDice DistanceExpDiceD; +TEST_P(DistanceExpDiceD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApproxNaN(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpDiceD, ::testing::ValuesIn(inputsd)); + +class BigMatrixDice : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixDice, Result) {} + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 2854a8f3df..f44fb18519 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -96,6 +96,34 @@ RAFT_KERNEL naiveL1_Linf_CanberraDistanceKernel(DataType* dist, dist[outidx] = acc; } +template +RAFT_KERNEL naiveDiceDistanceKernel( + DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) +{ + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { return; } + + DataType acc_a = DataType(0); + DataType acc_b = DataType(0); + DataType acc_ab = DataType(0); + + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc_a += a; + acc_b += b; + acc_ab += a * b; + } + + int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + + // Use 1.0 - (dice dissimilarity) to calc the distance + dist[outidx] = (DataType)1.0 - (2 * acc_ab / ((acc_a) + (acc_b))); +} + template RAFT_KERNEL naiveCosineDistanceKernel( DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) @@ -391,6 +419,9 @@ void naiveDistance(DataType* dist, naiveCorrelationDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::DiceExpanded: + naiveDiceDistanceKernel<<>>(dist, x, y, m, n, k, isRowMajor); + break; default: FAIL() << "should be here\n"; } RAFT_CUDA_TRY(cudaPeekAtLastError()); @@ -482,7 +513,8 @@ class DistanceTest : public ::testing::TestWithParam> { // Hellinger works only on positive numbers uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); - } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded) { + } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded || + distanceType == raft::distance::DistanceType::DiceExpanded) { uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); // Russel rao works on boolean values. @@ -571,7 +603,8 @@ class DistanceTestSameBuffer : public ::testing::TestWithParam +struct CompareApproxNaN { + CompareApproxNaN(T eps_) : eps(eps_) {} + bool operator()(const T& a, const T& b) const + { + T diff = std::abs(a - b); + T m = std::max(std::abs(a), std::abs(b)); + T ratio = diff > eps ? diff / m : diff; + + if (std::isnan(a) && std::isnan(b)) { return true; } + return (ratio <= eps); + } + + private: + T eps; +}; + template ::std::ostream& operator<<(::std::ostream& os, const raft::KeyValuePair& kv) {