Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hamming, jensen-shannon, kl-divergence, correlation and russellrao distance metrics #4155

Merged
merged 23 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
dfc28a4
Add interfaces in cuml for hamming, correlation, jensen-shannon, kl-d…
mdoijade Aug 6, 2021
ad67053
fix clang formatting issues
mdoijade Aug 6, 2021
38729dc
add all new distances to main API and fix function name in correlatio…
mdoijade Aug 6, 2021
0a734fd
add python interfaces for all new dist metrics, with tests for all wo…
mdoijade Aug 10, 2021
aa072ac
add test support for kl-divergence dist metric
mdoijade Aug 11, 2021
38e1067
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Aug 11, 2021
6dd14bf
pin mdoijade raft fork for testing change
mdoijade Aug 11, 2021
911eef8
fix flake formatting issues in test_metrics
mdoijade Aug 11, 2021
ae9f7a2
temp commit to trigger ci
mdoijade Aug 12, 2021
5fc77f7
temp commit to trigger ci
mdoijade Aug 12, 2021
90c8f16
temp commit to trigger ci to check updated raft changes
mdoijade Aug 12, 2021
1d5c966
temp commit to trigger ci to check updated raft changes
mdoijade Aug 12, 2021
8225ee0
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Aug 13, 2021
3f6de71
temp commit to test new raft commits
mdoijade Aug 23, 2021
19d44db
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Aug 23, 2021
b71866a
revert raft mdoijade fork as raft PR is merged now
mdoijade Aug 26, 2021
3ef11bc
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Aug 26, 2021
1a1dd71
remove redundant metric arg and switch based on it on the APIs which …
mdoijade Aug 30, 2021
96c0240
merge branch-21.10
mdoijade Aug 31, 2021
675b661
Add device_uvector changes to new distances
mdoijade Aug 31, 2021
ac4f067
fix clang format issues
mdoijade Aug 31, 2021
8e2f01d
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Sep 1, 2021
6427584
Merge branch 'branch-21.10' into additionalDistPrims
mdoijade Sep 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,16 @@ if(BUILD_CUML_CPP_LIBRARY)
src/metrics/pairwise_distance.cu
src/metrics/pairwise_distance_canberra.cu
src/metrics/pairwise_distance_chebyshev.cu
src/metrics/pairwise_distance_correlation.cu
src/metrics/pairwise_distance_cosine.cu
src/metrics/pairwise_distance_euclidean.cu
src/metrics/pairwise_distance_hamming.cu
src/metrics/pairwise_distance_hellinger.cu
src/metrics/pairwise_distance_jensen_shannon.cu
src/metrics/pairwise_distance_kl_divergence.cu
src/metrics/pairwise_distance_l1.cu
src/metrics/pairwise_distance_minkowski.cu
src/metrics/pairwise_distance_russell_rao.cu
src/metrics/r2_score.cu
src/metrics/rand_index.cu
src/metrics/silhouette_score.cu
Expand Down
59 changes: 47 additions & 12 deletions cpp/src/metrics/pairwise_distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@
#include <raft/sparse/distance/distance.cuh>
#include "pairwise_distance_canberra.cuh"
#include "pairwise_distance_chebyshev.cuh"
#include "pairwise_distance_correlation.cuh"
#include "pairwise_distance_cosine.cuh"
#include "pairwise_distance_euclidean.cuh"
#include "pairwise_distance_hamming.cuh"
#include "pairwise_distance_hellinger.cuh"
#include "pairwise_distance_jensen_shannon.cuh"
#include "pairwise_distance_kl_divergence.cuh"
#include "pairwise_distance_l1.cuh"
#include "pairwise_distance_minkowski.cuh"
#include "pairwise_distance_russell_rao.cuh"

namespace ML {

Expand All @@ -50,22 +55,37 @@ void pairwise_distance(const raft::handle_t& handle,
pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CosineExpanded:
pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::L1:
pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Linf:
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HellingerExpanded:
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::LpUnexpanded:
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CorrelationExpanded:
pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HammingUnexpanded:
pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::JensenShannon:
pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::KLDivergence:
pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
Expand All @@ -90,22 +110,37 @@ void pairwise_distance(const raft::handle_t& handle,
pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CosineExpanded:
pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::L1:
pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Linf:
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HellingerExpanded:
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::LpUnexpanded:
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::Canberra:
pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg);
pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::CorrelationExpanded:
pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::HammingUnexpanded:
pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::JensenShannon:
pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::KLDivergence:
pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
case raft::distance::DistanceType::RusselRaoExpanded:
pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
Expand Down
30 changes: 5 additions & 25 deletions cpp/src/metrics/pairwise_distance_canberra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
* limitations under the License.
*/

//#include <cuml/metrics/metrics.hpp>
#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_canberra.cuh"

namespace ML {

Expand All @@ -30,25 +30,15 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
/* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor,
metric_arg);*/

switch (metric) {
case raft::distance::DistanceType::Canberra:
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

void pairwise_distance_canberra(const raft::handle_t& handle,
Expand All @@ -58,25 +48,15 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());

// Call the distance function
/* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric,
handle.get_stream(), isRowMajor,
metric_arg);*/

switch (metric) {
case raft::distance::DistanceType::Canberra:
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Canberra>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/metrics/pairwise_distance_canberra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg);

Expand All @@ -41,7 +40,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg);

Expand Down
20 changes: 4 additions & 16 deletions cpp/src/metrics/pairwise_distance_chebyshev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,14 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());
// Call the distance function
switch (metric) {
case raft::distance::DistanceType::Linf:
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
raft::distance::pairwise_distance_impl<double, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

void pairwise_distance_chebyshev(const raft::handle_t& handle,
Expand All @@ -52,20 +46,14 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg)
{
// Allocate workspace
rmm::device_uvector<char> workspace(1, handle.get_stream());
// Call the distance function
switch (metric) {
case raft::distance::DistanceType::Linf:
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
}
raft::distance::pairwise_distance_impl<float, int, raft::distance::DistanceType::Linf>(
x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor);
}

} // namespace Metrics
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/metrics/pairwise_distance_chebyshev.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
double metric_arg);

Expand All @@ -39,7 +38,6 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle,
int m,
int n,
int k,
raft::distance::DistanceType metric,
bool isRowMajor,
float metric_arg);

Expand Down
65 changes: 65 additions & 0 deletions cpp/src/metrics/pairwise_distance_correlation.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>
#include <rmm/device_uvector.hpp>
#include "pairwise_distance_correlation.cuh"

namespace ML {

namespace Metrics {
/**
 * @brief Double-precision entry point for the correlation pairwise distance.
 *
 * Forwards its arguments to raft::distance::pairwise_distance_impl, instantiated
 * for DistanceType::CorrelationExpanded with double values and int indices.
 * The stream obtained from `handle` is used both for the scratch allocation
 * and as the stream argument of the raft call.
 *
 * @param handle     raft handle; only its stream is read here
 * @param x          first input, passed through to raft unchanged
 * @param y          second input, passed through to raft unchanged
 * @param dist       output pointer, passed through to raft unchanged
 * @param m          passed through to raft (presumably the row count of x -- confirm against raft docs)
 * @param n          passed through to raft (presumably the row count of y -- confirm against raft docs)
 * @param k          passed through to raft (presumably the shared feature count -- confirm against raft docs)
 * @param isRowMajor layout flag, passed through to raft unchanged
 * @param metric_arg not read by this overload
 */
void pairwise_distance_correlation(const raft::handle_t& handle,
                                   const double* x,
                                   const double* y,
                                   double* dist,
                                   int m,
                                   int n,
                                   int k,
                                   bool isRowMajor,
                                   double metric_arg)
{
  constexpr auto kMetric = raft::distance::DistanceType::CorrelationExpanded;
  auto stream            = handle.get_stream();

  // One-element scratch buffer; the raft call takes a workspace argument.
  rmm::device_uvector<char> scratch(1, stream);

  raft::distance::pairwise_distance_impl<double, int, kMetric>(
    x, y, dist, m, n, k, scratch, stream, isRowMajor);
}

/**
 * @brief Single-precision entry point for the correlation pairwise distance.
 *
 * Forwards its arguments to raft::distance::pairwise_distance_impl, instantiated
 * for DistanceType::CorrelationExpanded with float values and int indices.
 * The stream obtained from `handle` is used both for the scratch allocation
 * and as the stream argument of the raft call.
 *
 * @param handle     raft handle; only its stream is read here
 * @param x          first input, passed through to raft unchanged
 * @param y          second input, passed through to raft unchanged
 * @param dist       output pointer, passed through to raft unchanged
 * @param m          passed through to raft (presumably the row count of x -- confirm against raft docs)
 * @param n          passed through to raft (presumably the row count of y -- confirm against raft docs)
 * @param k          passed through to raft (presumably the shared feature count -- confirm against raft docs)
 * @param isRowMajor layout flag, passed through to raft unchanged
 * @param metric_arg not read by this overload
 */
void pairwise_distance_correlation(const raft::handle_t& handle,
                                   const float* x,
                                   const float* y,
                                   float* dist,
                                   int m,
                                   int n,
                                   int k,
                                   bool isRowMajor,
                                   float metric_arg)
{
  constexpr auto kMetric = raft::distance::DistanceType::CorrelationExpanded;
  auto stream            = handle.get_stream();

  // One-element scratch buffer; the raft call takes a workspace argument.
  rmm::device_uvector<char> scratch(1, stream);

  raft::distance::pairwise_distance_impl<float, int, kMetric>(
    x, y, dist, m, n, k, scratch, stream, isRowMajor);
}

} // namespace Metrics
} // namespace ML
47 changes: 47 additions & 0 deletions cpp/src/metrics/pairwise_distance_correlation.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/distance.cuh>
#include <raft/handle.hpp>

namespace ML {

namespace Metrics {
/**
 * @brief Correlation pairwise distance, double-precision overload.
 *
 * The definition in pairwise_distance_correlation.cu forwards to
 * raft::distance::pairwise_distance_impl with DistanceType::CorrelationExpanded.
 *
 * @param handle     raft handle whose stream is used for the computation
 * @param x          first input pointer (forwarded to raft)
 * @param y          second input pointer (forwarded to raft)
 * @param dist       output pointer (forwarded to raft)
 * @param m          forwarded to raft; presumably number of rows in x -- TODO confirm
 * @param n          forwarded to raft; presumably number of rows in y -- TODO confirm
 * @param k          forwarded to raft; presumably number of columns -- TODO confirm
 * @param isRowMajor layout flag forwarded to raft
 * @param metric_arg accepted for signature parity; the definition does not read it
 */
void pairwise_distance_correlation(const raft::handle_t& handle,
const double* x,
const double* y,
double* dist,
int m,
int n,
int k,
bool isRowMajor,
double metric_arg);

/**
 * @brief Correlation pairwise distance, single-precision overload.
 *
 * The definition in pairwise_distance_correlation.cu forwards to
 * raft::distance::pairwise_distance_impl with DistanceType::CorrelationExpanded.
 *
 * @param handle     raft handle whose stream is used for the computation
 * @param x          first input pointer (forwarded to raft)
 * @param y          second input pointer (forwarded to raft)
 * @param dist       output pointer (forwarded to raft)
 * @param m          forwarded to raft; presumably number of rows in x -- TODO confirm
 * @param n          forwarded to raft; presumably number of rows in y -- TODO confirm
 * @param k          forwarded to raft; presumably number of columns -- TODO confirm
 * @param isRowMajor layout flag forwarded to raft
 * @param metric_arg accepted for signature parity; the definition does not read it
 */
void pairwise_distance_correlation(const raft::handle_t& handle,
const float* x,
const float* y,
float* dist,
int m,
int n,
int k,
bool isRowMajor,
float metric_arg);

} // namespace Metrics
} // namespace ML
Loading