Skip to content

Commit

Permalink
Support innerproduct distance in the pairwise_distance API
Browse files Browse the repository at this point in the history
Fixes for supporting InnerProduct distance in the pairwise_distance api - required to handle the changes in rapidsai/raft#1226
  • Loading branch information
benfred committed Feb 13, 2023
1 parent 20d2690 commit 4eb6ba2
Show file tree
Hide file tree
Showing 15 changed files with 49 additions and 45 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ endfunction()
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${CUML_BRANCH_VERSION_raft}
PINNED_TAG branch-${CUML_BRANCH_VERSION_raft}
EXCLUDE_FROM_ALL ${CUML_EXCLUDE_RAFT_FROM_ALL}
# When PINNED_TAG above doesn't match cuml,
# force local raft clone in build directory
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/hdbscan/detail/soft_clustering.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ void all_points_dist_membership_vector(const raft::handle_t& handle,
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtExpanded, value_t, value_t, value_t, int>(
X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true);
handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true);
break;
case raft::distance::DistanceType::L1:
raft::distance::distance<raft::distance::DistanceType::L1, value_t, value_t, value_t, int>(
X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true);
handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true);
break;
case raft::distance::DistanceType::CosineExpanded:
raft::distance::
distance<raft::distance::DistanceType::CosineExpanded, value_t, value_t, value_t, int>(
X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true);
handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true);
break;
default: ASSERT(false, "Incorrect metric passed!");
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_canberra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::Canberra, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_canberra(const raft::handle_t& handle,
Expand All @@ -50,7 +50,7 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::Canberra, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_chebyshev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::Linf, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_chebyshev(const raft::handle_t& handle,
Expand All @@ -49,7 +49,7 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::Linf, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_correlation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void pairwise_distance_correlation(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::CorrelationExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_correlation(const raft::handle_t& handle,
Expand All @@ -52,7 +52,7 @@ void pairwise_distance_correlation(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::CorrelationExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_cosine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void pairwise_distance_cosine(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::CosineExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_cosine(const raft::handle_t& handle,
Expand All @@ -51,7 +51,7 @@ void pairwise_distance_cosine(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::CosineExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
16 changes: 8 additions & 8 deletions cpp/src/metrics/pairwise_distance_euclidean.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ void pairwise_distance_euclidean(const raft::handle_t& handle,
case raft::distance::DistanceType::L2Expanded:
raft::distance::
distance<raft::distance::DistanceType::L2Expanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2Unexpanded:
raft::distance::
distance<raft::distance::DistanceType::L2Unexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
Expand All @@ -75,22 +75,22 @@ void pairwise_distance_euclidean(const raft::handle_t& handle,
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
raft::distance::distance<raft::distance::DistanceType::L2Expanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2Unexpanded:
raft::distance::
distance<raft::distance::DistanceType::L2Unexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_hamming.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void pairwise_distance_hamming(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::HammingUnexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_hamming(const raft::handle_t& handle,
Expand All @@ -52,7 +52,7 @@ void pairwise_distance_hamming(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::HammingUnexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_hellinger.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void pairwise_distance_hellinger(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::HellingerExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_hellinger(const raft::handle_t& handle,
Expand All @@ -51,7 +51,7 @@ void pairwise_distance_hellinger(const raft::handle_t& handle,
{
raft::distance::
distance<raft::distance::DistanceType::HellingerExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_jensen_shannon.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle,
{
raft::distance::
distance<raft::distance::DistanceType::JensenShannon, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_jensen_shannon(const raft::handle_t& handle,
Expand All @@ -49,7 +49,7 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle,
float metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::JensenShannon, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_kl_divergence.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle,
double metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::KLDivergence, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_kl_divergence(const raft::handle_t& handle,
Expand All @@ -48,7 +48,7 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle,
float metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::KLDivergence, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_l1.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void pairwise_distance_l1(const raft::handle_t& handle,
double metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::L1, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_l1(const raft::handle_t& handle,
Expand All @@ -48,7 +48,7 @@ void pairwise_distance_l1(const raft::handle_t& handle,
float metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::L1, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_minkowski.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void pairwise_distance_minkowski(const raft::handle_t& handle,
double metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::LpUnexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor, metric_arg);
handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
}

void pairwise_distance_minkowski(const raft::handle_t& handle,
Expand All @@ -48,7 +48,7 @@ void pairwise_distance_minkowski(const raft::handle_t& handle,
float metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::LpUnexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor, metric_arg);
handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_russell_rao.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle,
{
raft::distance::
distance<raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_russell_rao(const raft::handle_t& handle,
Expand All @@ -50,7 +50,7 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle,
{
raft::distance::
distance<raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
26 changes: 15 additions & 11 deletions cpp/test/prims/distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "test_utils.h"
#include <distance/distance.cuh>
#include <gtest/gtest.h>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
Expand Down Expand Up @@ -147,7 +149,8 @@ template <typename DataType>
}

template <raft::distance::DistanceType distanceType, typename DataType>
void distanceLauncher(DataType* x,
void distanceLauncher(raft::resources const& handle,
DataType* x,
DataType* y,
DataType* dist,
DataType* dist2,
Expand All @@ -165,8 +168,9 @@ void distanceLauncher(DataType* x,
dist2[g_d_idx] = (d_val < threshold) ? 0.f : d_val;
return d_val;
};

distance<distanceType, DataType, DataType, DataType>(
x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor);
handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor);
}

template <raft::distance::DistanceType distanceType, typename DataType>
Expand All @@ -181,12 +185,13 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
{
params = ::testing::TestWithParam < DistanceInputs<DataType>::GetParam();
raft::random::Rng r(params.seed);
int m = params.m;
int n = params.n;
int k = params.k;
bool isRowMajor = params.isRowMajor;
cudaStream_t stream = 0;
RAFT_CUDA_TRY(cudaStreamCreate(&stream));
int m = params.m;
int n = params.n;
int k = params.k;
bool isRowMajor = params.isRowMajor;

raft::resources handle;
auto stream = raft::resource::get_cuda_stream(handle);
x.resize(m * k, stream);
y.resize(n * k, stream);
dist_ref.resize(m * n, stream);
Expand All @@ -199,7 +204,8 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
rmm::device_uvector<char> workspace(worksize);

DataType threshold = -10000.f;
distanceLauncher<distanceType, DataType>(x.data(),
distanceLauncher<distanceType, DataType>(handle,
x.data(),
y.data(),
dist.data(),
dist2.data(),
Expand All @@ -210,9 +216,7 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
threshold,
workspace.data(),
worksize,
stream,
isRowMajor);
RAFT_CUDA_TRY(cudaStreamDestroy(stream));
}

protected:
Expand Down

0 comments on commit 4eb6ba2

Please sign in to comment.