From 4eb6ba21dae86a9c342b47f19bae87e24273d6f3 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 13 Feb 2023 11:13:59 -0800 Subject: [PATCH] Support innerproduct distance in the pairwise_distance API Fixes for supporting InnerProduct distance in the pairwise_distance api - required to handle the changes in rapidsai/raft#1226 --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- cpp/src/hdbscan/detail/soft_clustering.cuh | 6 ++--- cpp/src/metrics/pairwise_distance_canberra.cu | 4 +-- .../metrics/pairwise_distance_chebyshev.cu | 4 +-- .../metrics/pairwise_distance_correlation.cu | 4 +-- cpp/src/metrics/pairwise_distance_cosine.cu | 4 +-- .../metrics/pairwise_distance_euclidean.cu | 16 ++++++------ cpp/src/metrics/pairwise_distance_hamming.cu | 4 +-- .../metrics/pairwise_distance_hellinger.cu | 4 +-- .../pairwise_distance_jensen_shannon.cu | 4 +-- .../pairwise_distance_kl_divergence.cu | 4 +-- cpp/src/metrics/pairwise_distance_l1.cu | 4 +-- .../metrics/pairwise_distance_minkowski.cu | 4 +-- .../metrics/pairwise_distance_russell_rao.cu | 4 +-- cpp/test/prims/distance_base.cuh | 26 +++++++++++-------- 15 files changed, 49 insertions(+), 45 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 3306e0a7ce..a890a5e72e 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -93,7 +93,7 @@ endfunction() # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft} FORK rapidsai - PINNED_TAG branch-${CUML_BRANCH_VERSION_raft} + PINNED_TAG branch-${CUML_BRANCH_VERSION_raft} EXCLUDE_FROM_ALL ${CUML_EXCLUDE_RAFT_FROM_ALL} # When PINNED_TAG above doesn't match cuml, # force local raft clone in build directory diff --git a/cpp/src/hdbscan/detail/soft_clustering.cuh b/cpp/src/hdbscan/detail/soft_clustering.cuh index 2139ba67ad..bb79378c0c 100644 --- a/cpp/src/hdbscan/detail/soft_clustering.cuh +++ b/cpp/src/hdbscan/detail/soft_clustering.cuh @@ -91,16 +91,16 @@ void all_points_dist_membership_vector(const raft::handle_t& handle, case raft::distance::DistanceType::L2SqrtExpanded: raft::distance:: distance( - X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true); + handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true); break; case raft::distance::DistanceType::L1: raft::distance::distance( - X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true); + handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true); break; case raft::distance::DistanceType::CosineExpanded: raft::distance:: distance( - X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true); + handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true); break; default: ASSERT(false, "Incorrect metric passed!"); } diff --git a/cpp/src/metrics/pairwise_distance_canberra.cu b/cpp/src/metrics/pairwise_distance_canberra.cu index 757b2f8dab..ac1b6e6c96 100644 --- a/cpp/src/metrics/pairwise_distance_canberra.cu +++ b/cpp/src/metrics/pairwise_distance_canberra.cu @@ -35,7 +35,7 @@ void pairwise_distance_canberra(const raft::handle_t& handle, { // Call the distance function raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_canberra(const raft::handle_t& handle, @@ -50,7 +50,7 @@ void pairwise_distance_canberra(const raft::handle_t& handle, { // Call the distance function raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cu b/cpp/src/metrics/pairwise_distance_chebyshev.cu index 0dbf77914d..ac763fd95c 100644 --- a/cpp/src/metrics/pairwise_distance_chebyshev.cu +++ b/cpp/src/metrics/pairwise_distance_chebyshev.cu @@ -34,7 +34,7 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, { // Call the distance function raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_chebyshev(const raft::handle_t& handle, @@ -49,7 +49,7 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, { // Call the distance function raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_correlation.cu b/cpp/src/metrics/pairwise_distance_correlation.cu index 630282d510..9874904be4 100644 --- a/cpp/src/metrics/pairwise_distance_correlation.cu +++ b/cpp/src/metrics/pairwise_distance_correlation.cu @@ -36,7 +36,7 @@ void pairwise_distance_correlation(const raft::handle_t& handle, // Call the distance function raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_correlation(const raft::handle_t& handle, @@ -52,7 +52,7 @@ void pairwise_distance_correlation(const raft::handle_t& handle, // Call the distance function raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_cosine.cu b/cpp/src/metrics/pairwise_distance_cosine.cu index 3e4f9c6e81..59fbd835fe 100644 --- a/cpp/src/metrics/pairwise_distance_cosine.cu +++ b/cpp/src/metrics/pairwise_distance_cosine.cu @@ -36,7 +36,7 @@ void pairwise_distance_cosine(const raft::handle_t& handle, // Call the distance function raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_cosine(const raft::handle_t& handle, @@ -51,7 +51,7 @@ void pairwise_distance_cosine(const raft::handle_t& handle, { // Call the distance function raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_euclidean.cu b/cpp/src/metrics/pairwise_distance_euclidean.cu index 0a73883ba9..3a32439e75 100644 --- a/cpp/src/metrics/pairwise_distance_euclidean.cu +++ b/cpp/src/metrics/pairwise_distance_euclidean.cu @@ -39,22 +39,22 @@ void pairwise_distance_euclidean(const raft::handle_t& handle, case raft::distance::DistanceType::L2Expanded: raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); break; case raft::distance::DistanceType::L2SqrtExpanded: raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); break; case raft::distance::DistanceType::L2Unexpanded: raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } @@ -75,22 +75,22 @@ void pairwise_distance_euclidean(const raft::handle_t& handle, switch (metric) { case raft::distance::DistanceType::L2Expanded: raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); break; case raft::distance::DistanceType::L2SqrtExpanded: raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); break; case raft::distance::DistanceType::L2Unexpanded: raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } diff --git a/cpp/src/metrics/pairwise_distance_hamming.cu b/cpp/src/metrics/pairwise_distance_hamming.cu index 26ac60c368..fecd594244 100644 --- a/cpp/src/metrics/pairwise_distance_hamming.cu +++ b/cpp/src/metrics/pairwise_distance_hamming.cu @@ -36,7 +36,7 @@ void pairwise_distance_hamming(const raft::handle_t& handle, // Call the distance function raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_hamming(const raft::handle_t& handle, @@ -52,7 +52,7 @@ void pairwise_distance_hamming(const raft::handle_t& handle, // Call the distance function raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cu b/cpp/src/metrics/pairwise_distance_hellinger.cu index e6afeedcc7..cb9795c21c 100644 --- a/cpp/src/metrics/pairwise_distance_hellinger.cu +++ b/cpp/src/metrics/pairwise_distance_hellinger.cu @@ -36,7 +36,7 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, // Call the distance function raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_hellinger(const raft::handle_t& handle, @@ -51,7 +51,7 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, { raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu index 3de240da98..7a72cd0c13 100644 --- a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu @@ -35,7 +35,7 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, { raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_jensen_shannon(const raft::handle_t& handle, @@ -49,7 +49,7 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, float metric_arg) { raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cu b/cpp/src/metrics/pairwise_distance_kl_divergence.cu index fd5146a295..bfa1382937 100644 --- a/cpp/src/metrics/pairwise_distance_kl_divergence.cu +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cu @@ -34,7 +34,7 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, double metric_arg) { raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_kl_divergence(const raft::handle_t& handle, @@ -48,7 +48,7 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, float metric_arg) { raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_l1.cu b/cpp/src/metrics/pairwise_distance_l1.cu index 5c2edd6c66..c704998ae5 100644 --- a/cpp/src/metrics/pairwise_distance_l1.cu +++ b/cpp/src/metrics/pairwise_distance_l1.cu @@ -34,7 +34,7 @@ void pairwise_distance_l1(const raft::handle_t& handle, double metric_arg) { raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_l1(const raft::handle_t& handle, @@ -48,7 +48,7 @@ void pairwise_distance_l1(const raft::handle_t& handle, float metric_arg) { raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cu b/cpp/src/metrics/pairwise_distance_minkowski.cu index d7dfcc96d0..370621e7fc 100644 --- a/cpp/src/metrics/pairwise_distance_minkowski.cu +++ b/cpp/src/metrics/pairwise_distance_minkowski.cu @@ -34,7 +34,7 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, double metric_arg) { raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, isRowMajor, metric_arg); } void pairwise_distance_minkowski(const raft::handle_t& handle, @@ -48,7 +48,7 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, float metric_arg) { raft::distance::distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, isRowMajor, metric_arg); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cu b/cpp/src/metrics/pairwise_distance_russell_rao.cu index ab11531219..293e42457c 100644 --- a/cpp/src/metrics/pairwise_distance_russell_rao.cu +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cu @@ -35,7 +35,7 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, { raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } void pairwise_distance_russell_rao(const raft::handle_t& handle, @@ -50,7 +50,7 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, { raft::distance:: distance( - x, y, dist, m, n, k, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, isRowMajor); } } // namespace Metrics diff --git a/cpp/test/prims/distance_base.cuh b/cpp/test/prims/distance_base.cuh index 59991de0c4..10b8ed72ae 100644 --- a/cpp/test/prims/distance_base.cuh +++ b/cpp/test/prims/distance_base.cuh @@ -17,6 +17,8 @@ #include "test_utils.h" #include #include +#include +#include #include #include #include @@ -147,7 +149,8 @@ template } template -void distanceLauncher(DataType* x, +void distanceLauncher(raft::resources const& handle, + DataType* x, DataType* y, DataType* dist, DataType* dist2, @@ -165,8 +168,9 @@ void distanceLauncher(DataType* x, dist2[g_d_idx] = (d_val < threshold) ? 0.f : d_val; return d_val; }; + distance( - x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor); + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor); } template @@ -181,12 +185,13 @@ class DistanceTest : public ::testing::TestWithParam> { { params = ::testing::TestWithParam < DistanceInputs::GetParam(); raft::random::Rng r(params.seed); - int m = params.m; - int n = params.n; - int k = params.k; - bool isRowMajor = params.isRowMajor; - cudaStream_t stream = 0; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + int m = params.m; + int n = params.n; + int k = params.k; + bool isRowMajor = params.isRowMajor; + + raft::resources handle; + auto stream = raft::resource::get_cuda_stream(handle); x.resize(m * k, stream); y.resize(n * k, stream); dist_ref.resize(m * n, stream); @@ -199,7 +204,8 @@ class DistanceTest : public ::testing::TestWithParam> { rmm::device_uvector workspace(worksize); DataType threshold = -10000.f; - distanceLauncher(x.data(), + distanceLauncher(handle, + x.data(), y.data(), dist.data(), dist2.data(), @@ -210,9 +216,7 @@ class DistanceTest : public ::testing::TestWithParam> { threshold, workspace.data(), worksize, - stream, isRowMajor); - RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } protected: