Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support innerproduct distance in the pairwise_distance API #5230

Merged
merged 4 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ endfunction()
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${CUML_BRANCH_VERSION_raft}
PINNED_TAG branch-${CUML_BRANCH_VERSION_raft}
EXCLUDE_FROM_ALL ${CUML_EXCLUDE_RAFT_FROM_ALL}
# When PINNED_TAG above doesn't match cuml,
# force local raft clone in build directory
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/hdbscan/detail/soft_clustering.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ void all_points_dist_membership_vector(const raft::handle_t& handle,
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtExpanded, value_t, value_t, value_t, int>(
X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true);
handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true);
break;
case raft::distance::DistanceType::L1:
raft::distance::distance<raft::distance::DistanceType::L1, value_t, value_t, value_t, int>(
X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true);
handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true);
break;
case raft::distance::DistanceType::CosineExpanded:
raft::distance::
distance<raft::distance::DistanceType::CosineExpanded, value_t, value_t, value_t, int>(
X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, stream, true);
handle, X, exemplars_dense.data(), dist.data(), m, n_exemplars, n, true);
break;
default: ASSERT(false, "Incorrect metric passed!");
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_canberra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::Canberra, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_canberra(const raft::handle_t& handle,
Expand All @@ -50,7 +50,7 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::Canberra, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_chebyshev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::Linf, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_chebyshev(const raft::handle_t& handle,
Expand All @@ -49,7 +49,7 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::Linf, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_correlation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void pairwise_distance_correlation(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::CorrelationExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_correlation(const raft::handle_t& handle,
Expand All @@ -52,7 +52,7 @@ void pairwise_distance_correlation(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::CorrelationExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_cosine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void pairwise_distance_cosine(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::CosineExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_cosine(const raft::handle_t& handle,
Expand All @@ -51,7 +51,7 @@ void pairwise_distance_cosine(const raft::handle_t& handle,
{
// Call the distance function
raft::distance::distance<raft::distance::DistanceType::CosineExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
16 changes: 8 additions & 8 deletions cpp/src/metrics/pairwise_distance_euclidean.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ void pairwise_distance_euclidean(const raft::handle_t& handle,
case raft::distance::DistanceType::L2Expanded:
raft::distance::
distance<raft::distance::DistanceType::L2Expanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2Unexpanded:
raft::distance::
distance<raft::distance::DistanceType::L2Unexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
Expand All @@ -75,22 +75,22 @@ void pairwise_distance_euclidean(const raft::handle_t& handle,
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
raft::distance::distance<raft::distance::DistanceType::L2Expanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2Unexpanded:
raft::distance::
distance<raft::distance::DistanceType::L2Unexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
raft::distance::
distance<raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_hamming.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void pairwise_distance_hamming(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::HammingUnexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_hamming(const raft::handle_t& handle,
Expand All @@ -52,7 +52,7 @@ void pairwise_distance_hamming(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::HammingUnexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_hellinger.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void pairwise_distance_hellinger(const raft::handle_t& handle,
// Call the distance function
raft::distance::
distance<raft::distance::DistanceType::HellingerExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_hellinger(const raft::handle_t& handle,
Expand All @@ -51,7 +51,7 @@ void pairwise_distance_hellinger(const raft::handle_t& handle,
{
raft::distance::
distance<raft::distance::DistanceType::HellingerExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_jensen_shannon.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle,
{
raft::distance::
distance<raft::distance::DistanceType::JensenShannon, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_jensen_shannon(const raft::handle_t& handle,
Expand All @@ -49,7 +49,7 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle,
float metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::JensenShannon, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_kl_divergence.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle,
double metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::KLDivergence, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_kl_divergence(const raft::handle_t& handle,
Expand All @@ -48,7 +48,7 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle,
float metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::KLDivergence, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_l1.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void pairwise_distance_l1(const raft::handle_t& handle,
double metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::L1, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_l1(const raft::handle_t& handle,
Expand All @@ -48,7 +48,7 @@ void pairwise_distance_l1(const raft::handle_t& handle,
float metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::L1, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_minkowski.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void pairwise_distance_minkowski(const raft::handle_t& handle,
double metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::LpUnexpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor, metric_arg);
handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
}

void pairwise_distance_minkowski(const raft::handle_t& handle,
Expand All @@ -48,7 +48,7 @@ void pairwise_distance_minkowski(const raft::handle_t& handle,
float metric_arg)
{
raft::distance::distance<raft::distance::DistanceType::LpUnexpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor, metric_arg);
handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
}

} // namespace Metrics
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/metrics/pairwise_distance_russell_rao.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle,
{
raft::distance::
distance<raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

void pairwise_distance_russell_rao(const raft::handle_t& handle,
Expand All @@ -50,7 +50,7 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle,
{
raft::distance::
distance<raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int>(
x, y, dist, m, n, k, handle.get_stream(), isRowMajor);
handle, x, y, dist, m, n, k, isRowMajor);
}

} // namespace Metrics
Expand Down
26 changes: 15 additions & 11 deletions cpp/test/prims/distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "test_utils.h"
#include <distance/distance.cuh>
#include <gtest/gtest.h>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
Expand Down Expand Up @@ -147,7 +149,8 @@ template <typename DataType>
}

template <raft::distance::DistanceType distanceType, typename DataType>
void distanceLauncher(DataType* x,
void distanceLauncher(raft::resources const& handle,
DataType* x,
DataType* y,
DataType* dist,
DataType* dist2,
Expand All @@ -165,8 +168,9 @@ void distanceLauncher(DataType* x,
dist2[g_d_idx] = (d_val < threshold) ? 0.f : d_val;
return d_val;
};

distance<distanceType, DataType, DataType, DataType>(
x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor);
handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor);
}

template <raft::distance::DistanceType distanceType, typename DataType>
Expand All @@ -181,12 +185,13 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
{
params = ::testing::TestWithParam < DistanceInputs<DataType>::GetParam();
raft::random::Rng r(params.seed);
int m = params.m;
int n = params.n;
int k = params.k;
bool isRowMajor = params.isRowMajor;
cudaStream_t stream = 0;
RAFT_CUDA_TRY(cudaStreamCreate(&stream));
int m = params.m;
int n = params.n;
int k = params.k;
bool isRowMajor = params.isRowMajor;

raft::resources handle;
auto stream = raft::resource::get_cuda_stream(handle);
x.resize(m * k, stream);
y.resize(n * k, stream);
dist_ref.resize(m * n, stream);
Expand All @@ -199,7 +204,8 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
rmm::device_uvector<char> workspace(worksize);

DataType threshold = -10000.f;
distanceLauncher<distanceType, DataType>(x.data(),
distanceLauncher<distanceType, DataType>(handle,
x.data(),
y.data(),
dist.data(),
dist2.data(),
Expand All @@ -210,9 +216,7 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
threshold,
workspace.data(),
worksize,
stream,
isRowMajor);
RAFT_CUDA_TRY(cudaStreamDestroy(stream));
}

protected:
Expand Down
2 changes: 0 additions & 2 deletions cpp/test/sg/umap_parametrizable_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@
using namespace ML;
using namespace ML::Metrics;

using namespace std;

using namespace MLCommon;
using namespace MLCommon::Datasets::Digits;

Expand Down
2 changes: 1 addition & 1 deletion python/cuml/neighbors/nearest_neighbors.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ if has_scipy():
import scipy.sparse


cdef extern from "raft/spatial/knn/ball_cover_common.h" \
cdef extern from "raft/spatial/knn/ball_cover_types.hpp" \
namespace "raft::spatial::knn":
cdef cppclass BallCoverIndex[int64_t, float, uint32_t]:
BallCoverIndex(const handle_t &handle,
Expand Down