Skip to content

Commit

Permalink
Add hamming, jensen-shannon, kl-divergence, correlation and russellra…
Browse files Browse the repository at this point in the history
…o distance metrics (#4155)

-- This PR depends on RAFT PR - rapidsai/raft#306
-- Adds cpp & python interfaces for these distance metrics with pytest support for each of them.
-- also remove redundant commented code in canberra distance metric

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #4155
  • Loading branch information
mdoijade authored Sep 14, 2021
1 parent b459603 commit 55ddc8b
Show file tree
Hide file tree
Showing 27 changed files with 675 additions and 144 deletions.
5 changes: 5 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,16 @@ if(BUILD_CUML_CPP_LIBRARY)
src/metrics/pairwise_distance.cu
src/metrics/pairwise_distance_canberra.cu
src/metrics/pairwise_distance_chebyshev.cu
src/metrics/pairwise_distance_correlation.cu
src/metrics/pairwise_distance_cosine.cu
src/metrics/pairwise_distance_euclidean.cu
src/metrics/pairwise_distance_hamming.cu
src/metrics/pairwise_distance_hellinger.cu
src/metrics/pairwise_distance_jensen_shannon.cu
src/metrics/pairwise_distance_kl_divergence.cu
src/metrics/pairwise_distance_l1.cu
src/metrics/pairwise_distance_minkowski.cu
src/metrics/pairwise_distance_russell_rao.cu
src/metrics/r2_score.cu
src/metrics/rand_index.cu
src/metrics/silhouette_score.cu
Expand Down
59 changes: 47 additions & 12 deletions cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@
#include <raft/sparse/distance/distance.cuh>
#include "pairwise_distance_canberra.cuh"
#include "pairwise_distance_chebyshev.cuh"
#include "pairwise_distance_correlation.cuh"
#include "pairwise_distance_cosine.cuh"
#include "pairwise_distance_euclidean.cuh"
#include "pairwise_distance_hamming.cuh"
#include "pairwise_distance_hellinger.cuh"
#include "pairwise_distance_jensen_shannon.cuh"
#include "pairwise_distance_kl_divergence.cuh"
#include "pairwise_distance_l1.cuh"
#include "pairwise_distance_minkowski.cuh"
#include "pairwise_distance_russell_rao.cuh"

namespace ML {

Expand All @@ -50,22 +55,37 @@ void pairwise_distance(const raft::handle_t& handle,
pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CosineExpanded:
pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::L1:
pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Linf:
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HellingerExpanded:
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::LpUnexpanded:
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CorrelationExpanded:
pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HammingUnexpanded:
pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::JensenShannon:
pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::KLDivergence:
pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
Expand All @@ -90,22 +110,37 @@ void pairwise_distance(const raft::handle_t& handle,
pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CosineExpanded:
pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::L1:
pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Linf:
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HellingerExpanded:
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::LpUnexpanded:
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CorrelationExpanded:
pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HammingUnexpanded:
pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::JensenShannon:
pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::KLDivergence:
pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
Expand Down
30 changes: 5 additions & 25 deletions cpp/src/metrics/pairwise_distance_canberra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
* limitations under the License.
*/

//#include <cuml/metrics/metrics.hpp>
#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_canberra.cuh"

namespace ML {

Expand All @@ -30,25 +30,15 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
/* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor,
metric_arg);*/

switch (metric) {
case raft::distance::DistanceType::Canberra:
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

void pairwise_distance_canberra(const raft::handle_t& handle,
Expand All @@ -58,25 +48,15 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
/* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor,
metric_arg);*/

switch (metric) {
case raft::distance::DistanceType::Canberra:
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/metrics/pairwise_distance_canberra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg);

Expand All @@ -41,7 +40,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg);

Expand Down
20 changes: 4 additions & 16 deletions cpp/src/metrics/pairwise_distance_chebyshev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,14 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());
// Call the distance function
switch (metric) {
case raft::distance::DistanceType::Linf:
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

void pairwise_distance_chebyshev(const raft::handle_t& handle,
Expand All @@ -52,20 +46,14 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());
// Call the distance function
switch (metric) {
case raft::distance::DistanceType::Linf:
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/metrics/pairwise_distance_chebyshev.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg);

Expand All @@ -39,7 +38,6 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg);

Expand Down
65 changes: 65 additions & 0 deletions cpp/src/metrics/pairwise_distance_correlation.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_correlation.cuh"

namespace ML {

namespace Metrics {
void pairwise_distance_correlation(const raft::handle_t& handle,
const double* x,
const double* y,
double* dist,
int m,
int n,
int k,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::
pairwise_distance_impl<double, int, raft::distance::DistanceType::CorrelationExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

void pairwise_distance_correlation(const raft::handle_t& handle,
const float* x,
const float* y,
float* dist,
int m,
int n,
int k,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
raft::distance::
pairwise_distance_impl<float, int, raft::distance::DistanceType::CorrelationExpanded>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
} // namespace ML
47 changes: 47 additions & 0 deletions cpp/src/metrics/pairwise_distance_correlation.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>

namespace ML {

namespace Metrics {
void pairwise_distance_correlation(const raft::handle_t& handle,
const double* x,
const double* y,
double* dist,
int m,
int n,
int k,
bool isRowMajor,
double metric_arg);

void pairwise_distance_correlation(const raft::handle_t& handle,
const float* x,
const float* y,
float* dist,
int m,
int n,
int k,
bool isRowMajor,
float metric_arg);

} // namespace Metrics
} // namespace ML
Loading

0 comments on commit 55ddc8b

Please sign in to comment.